Add v2ray WebSocket transport

This commit is contained in:
世界 2022-08-22 20:20:56 +08:00
parent 082872b2f3
commit 77c98fd042
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 475 additions and 53 deletions

View file

@ -1,5 +1,6 @@
package constant package constant
const ( const (
V2RayTransportTypeGRPC = "grpc" V2RayTransportTypeGRPC = "grpc"
V2RayTransportTypeWebsocket = "ws"
) )

View file

@ -63,7 +63,7 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg
} }
} }
if options.Transport != nil { if options.Transport != nil {
inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig.Config(), adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newTransportConnection, nil, nil)) inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig.Config(), adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newTransportConnection, nil, nil), inbound)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,28 +6,31 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
type _V2RayInboundTransportOptions struct { type _V2RayTransportOptions struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
GRPCOptions V2RayGRPCOptions `json:"-"` GRPCOptions V2RayGRPCOptions `json:"-"`
WebsocketOptions V2RayWebsocketOptions `json:"-"`
} }
type V2RayInboundTransportOptions _V2RayOutboundTransportOptions type V2RayTransportOptions _V2RayTransportOptions
func (o V2RayInboundTransportOptions) MarshalJSON() ([]byte, error) { func (o V2RayTransportOptions) MarshalJSON() ([]byte, error) {
var v any var v any
switch o.Type { switch o.Type {
case "": case "":
return nil, nil return nil, nil
case C.V2RayTransportTypeGRPC: case C.V2RayTransportTypeGRPC:
v = o.GRPCOptions v = o.GRPCOptions
case C.V2RayTransportTypeWebsocket:
v = o.WebsocketOptions
default: default:
return nil, E.New("unknown transport type: " + o.Type) return nil, E.New("unknown transport type: " + o.Type)
} }
return MarshallObjects((_V2RayOutboundTransportOptions)(o), v) return MarshallObjects((_V2RayTransportOptions)(o), v)
} }
func (o *V2RayInboundTransportOptions) UnmarshalJSON(bytes []byte) error { func (o *V2RayTransportOptions) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_V2RayOutboundTransportOptions)(o)) err := json.Unmarshal(bytes, (*_V2RayTransportOptions)(o))
if err != nil { if err != nil {
return err return err
} }
@ -38,16 +41,17 @@ func (o *V2RayInboundTransportOptions) UnmarshalJSON(bytes []byte) error {
default: default:
return E.New("unknown transport type: " + o.Type) return E.New("unknown transport type: " + o.Type)
} }
err = UnmarshallExcluded(bytes, (*_V2RayOutboundTransportOptions)(o), v) err = UnmarshallExcluded(bytes, (*_V2RayTransportOptions)(o), v)
if err != nil { if err != nil {
return E.Cause(err, "vmess transport options") return E.Cause(err, "vmess transport options")
} }
return nil return nil
} }
type _V2RayOutboundTransportOptions struct { /*type _V2RayOutboundTransportOptions struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
GRPCOptions V2RayGRPCOptions `json:"-"` GRPCOptions V2RayGRPCOptions `json:"-"`
WebsocketOptions V2RayWebsocketOptions `json:"-"`
} }
type V2RayOutboundTransportOptions _V2RayOutboundTransportOptions type V2RayOutboundTransportOptions _V2RayOutboundTransportOptions
@ -59,6 +63,8 @@ func (o V2RayOutboundTransportOptions) MarshalJSON() ([]byte, error) {
return nil, nil return nil, nil
case C.V2RayTransportTypeGRPC: case C.V2RayTransportTypeGRPC:
v = o.GRPCOptions v = o.GRPCOptions
case C.V2RayTransportTypeWebsocket:
v = o.WebsocketOptions
default: default:
return nil, E.New("unknown transport type: " + o.Type) return nil, E.New("unknown transport type: " + o.Type)
} }
@ -74,6 +80,8 @@ func (o *V2RayOutboundTransportOptions) UnmarshalJSON(bytes []byte) error {
switch o.Type { switch o.Type {
case C.V2RayTransportTypeGRPC: case C.V2RayTransportTypeGRPC:
v = &o.GRPCOptions v = &o.GRPCOptions
case C.V2RayTransportTypeWebsocket:
v = &o.WebsocketOptions
default: default:
return E.New("unknown transport type: " + o.Type) return E.New("unknown transport type: " + o.Type)
} }
@ -82,8 +90,15 @@ func (o *V2RayOutboundTransportOptions) UnmarshalJSON(bytes []byte) error {
return E.Cause(err, "vmess transport options") return E.Cause(err, "vmess transport options")
} }
return nil return nil
} }*/
type V2RayGRPCOptions struct { type V2RayGRPCOptions struct {
ServiceName string `json:"service_name,omitempty"` ServiceName string `json:"service_name,omitempty"`
} }
type V2RayWebsocketOptions struct {
Path string `json:"path,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
MaxEarlyData uint32 `json:"max_early_data,omitempty"`
EarlyDataHeaderName string `json:"early_data_header_name,omitempty"`
}

View file

@ -2,9 +2,9 @@ package option
type VMessInboundOptions struct { type VMessInboundOptions struct {
ListenOptions ListenOptions
Users []VMessUser `json:"users,omitempty"` Users []VMessUser `json:"users,omitempty"`
TLS *InboundTLSOptions `json:"tls,omitempty"` TLS *InboundTLSOptions `json:"tls,omitempty"`
Transport *V2RayInboundTransportOptions `json:"transport,omitempty"` Transport *V2RayTransportOptions `json:"transport,omitempty"`
} }
type VMessUser struct { type VMessUser struct {
@ -16,13 +16,13 @@ type VMessUser struct {
type VMessOutboundOptions struct { type VMessOutboundOptions struct {
OutboundDialerOptions OutboundDialerOptions
ServerOptions ServerOptions
UUID string `json:"uuid"` UUID string `json:"uuid"`
Security string `json:"security"` Security string `json:"security"`
AlterId int `json:"alter_id,omitempty"` AlterId int `json:"alter_id,omitempty"`
GlobalPadding bool `json:"global_padding,omitempty"` GlobalPadding bool `json:"global_padding,omitempty"`
AuthenticatedLength bool `json:"authenticated_length,omitempty"` AuthenticatedLength bool `json:"authenticated_length,omitempty"`
Network NetworkList `json:"network,omitempty"` Network NetworkList `json:"network,omitempty"`
TLS *OutboundTLSOptions `json:"tls,omitempty"` TLS *OutboundTLSOptions `json:"tls,omitempty"`
Multiplex *MultiplexOptions `json:"multiplex,omitempty"` Multiplex *MultiplexOptions `json:"multiplex,omitempty"`
Transport *V2RayOutboundTransportOptions `json:"transport,omitempty"` Transport *V2RayTransportOptions `json:"transport,omitempty"`
} }

View file

@ -12,6 +12,40 @@ import (
) )
func TestVMessGRPCSelf(t *testing.T) { func TestVMessGRPCSelf(t *testing.T) {
testVMessWebscoketSelf(t, &option.V2RayTransportOptions{
Type: C.V2RayTransportTypeGRPC,
GRPCOptions: option.V2RayGRPCOptions{
ServiceName: "TunService",
},
})
}
func TestVMessWebscoketSelf(t *testing.T) {
t.Run("basic", func(t *testing.T) {
testVMessWebscoketSelf(t, &option.V2RayTransportOptions{
Type: C.V2RayTransportTypeWebsocket,
})
})
t.Run("v2ray early data", func(t *testing.T) {
testVMessWebscoketSelf(t, &option.V2RayTransportOptions{
Type: C.V2RayTransportTypeWebsocket,
WebsocketOptions: option.V2RayWebsocketOptions{
MaxEarlyData: 2048,
},
})
})
t.Run("xray early data", func(t *testing.T) {
testVMessWebscoketSelf(t, &option.V2RayTransportOptions{
Type: C.V2RayTransportTypeWebsocket,
WebsocketOptions: option.V2RayWebsocketOptions{
MaxEarlyData: 2048,
EarlyDataHeaderName: "Sec-WebSocket-Protocol",
},
})
})
}
func testVMessWebscoketSelf(t *testing.T, transport *option.V2RayTransportOptions) {
user, err := uuid.DefaultGenerator.NewV4() user, err := uuid.DefaultGenerator.NewV4()
require.NoError(t, err) require.NoError(t, err)
_, certPem, keyPem := createSelfSignedCertificate(t, "example.org") _, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
@ -50,12 +84,7 @@ func TestVMessGRPCSelf(t *testing.T) {
CertificatePath: certPem, CertificatePath: certPem,
KeyPath: keyPem, KeyPath: keyPem,
}, },
Transport: &option.V2RayInboundTransportOptions{ Transport: transport,
Type: C.V2RayTransportTypeGRPC,
GRPCOptions: option.V2RayGRPCOptions{
ServiceName: "TunService",
},
},
}, },
}, },
}, },
@ -78,12 +107,7 @@ func TestVMessGRPCSelf(t *testing.T) {
ServerName: "example.org", ServerName: "example.org",
CertificatePath: certPem, CertificatePath: certPem,
}, },
Transport: &option.V2RayOutboundTransportOptions{ Transport: transport,
Type: C.V2RayTransportTypeGRPC,
GRPCOptions: option.V2RayGRPCOptions{
ServiceName: "TunService",
},
},
}, },
}, },
}, },

View file

@ -7,15 +7,16 @@ import (
"crypto/tls" "crypto/tls"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/v2raygrpc" "github.com/sagernet/sing-box/transport/v2raygrpc"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
func NewGRPCServer(ctx context.Context, serviceName string, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) {
return v2raygrpc.NewServer(ctx, serviceName, tlsConfig, handler), nil return v2raygrpc.NewServer(ctx, options, tlsConfig, handler), nil
} }
func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, serviceName string, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) {
return v2raygrpc.NewClient(ctx, dialer, serverAddr, serviceName, tlsConfig), nil return v2raygrpc.NewClient(ctx, dialer, serverAddr, options, tlsConfig), nil
} }

View file

@ -7,6 +7,7 @@ import (
"crypto/tls" "crypto/tls"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -14,10 +15,10 @@ import (
var errGRPCNotIncluded = E.New("gRPC is not included in this build, rebuild with -tags with_grpc") var errGRPCNotIncluded = E.New("gRPC is not included in this build, rebuild with -tags with_grpc")
func NewGRPCServer(ctx context.Context, serviceName string, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) {
return nil, errGRPCNotIncluded return nil, errGRPCNotIncluded
} }
func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, serviceName string, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) {
return nil, errGRPCNotIncluded return nil, errGRPCNotIncluded
} }

View file

@ -7,30 +7,35 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/v2raywebsocket"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
func NewServerTransport(ctx context.Context, options option.V2RayInboundTransportOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { func NewServerTransport(ctx context.Context, options option.V2RayTransportOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) {
if options.Type == "" { if options.Type == "" {
return nil, nil return nil, nil
} }
switch options.Type { switch options.Type {
case C.V2RayTransportTypeGRPC: case C.V2RayTransportTypeGRPC:
return NewGRPCServer(ctx, options.GRPCOptions.ServiceName, tlsConfig, handler) return NewGRPCServer(ctx, options.GRPCOptions, tlsConfig, handler)
case C.V2RayTransportTypeWebsocket:
return v2raywebsocket.NewServer(ctx, options.WebsocketOptions, tlsConfig, handler, errorHandler), nil
default: default:
return nil, E.New("unknown transport type: " + options.Type) return nil, E.New("unknown transport type: " + options.Type)
} }
} }
func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayOutboundTransportOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayTransportOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) {
if options.Type == "" { if options.Type == "" {
return nil, nil return nil, nil
} }
switch options.Type { switch options.Type {
case C.V2RayTransportTypeGRPC: case C.V2RayTransportTypeGRPC:
return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions.ServiceName, tlsConfig) return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions, tlsConfig)
case C.V2RayTransportTypeWebsocket:
return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig), nil
default: default:
return nil, E.New("unknown transport type: " + options.Type) return nil, E.New("unknown transport type: " + options.Type)
} }

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -31,7 +32,7 @@ type Client struct {
connAccess sync.Mutex connAccess sync.Mutex
} }
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, serviceName string, tlsConfig *tls.Config) adapter.V2RayClientTransport { func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) adapter.V2RayClientTransport {
var dialOptions []grpc.DialOption var dialOptions []grpc.DialOption
if tlsConfig != nil { if tlsConfig != nil {
dialOptions = append(dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) dialOptions = append(dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
@ -55,7 +56,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, ser
ctx: ctx, ctx: ctx,
dialer: dialer, dialer: dialer,
serverAddr: serverAddr.String(), serverAddr: serverAddr.String(),
serviceName: serviceName, serviceName: options.ServiceName,
dialOptions: dialOptions, dialOptions: dialOptions,
} }
} }

View file

@ -6,6 +6,7 @@ import (
"net" "net"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -21,14 +22,14 @@ type Server struct {
server *grpc.Server server *grpc.Server
} }
func NewServer(ctx context.Context, serviceName string, tlsConfig *tls.Config, handler N.TCPConnectionHandler) *Server { func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) *Server {
var serverOptions []grpc.ServerOption var serverOptions []grpc.ServerOption
if tlsConfig != nil { if tlsConfig != nil {
tlsConfig.NextProtos = []string{"h2"} tlsConfig.NextProtos = []string{"h2"}
serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig))) serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
} }
server := &Server{ctx, handler, grpc.NewServer(serverOptions...)} server := &Server{ctx, handler, grpc.NewServer(serverOptions...)}
RegisterGunServiceCustomNameServer(server.server, server, serviceName) RegisterGunServiceCustomNameServer(server.server, server, options.ServiceName)
return server return server
} }

View file

@ -0,0 +1,82 @@
package v2raywebsocket
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/gorilla/websocket"
)
var _ adapter.V2RayClientTransport = (*Client)(nil)
type Client struct {
dialer *websocket.Dialer
uri string
headers http.Header
maxEarlyData uint32
earlyDataHeaderName string
}
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig *tls.Config) adapter.V2RayClientTransport {
wsDialer := &websocket.Dialer{
NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: tlsConfig,
ReadBufferSize: 4 * 1024,
WriteBufferSize: 4 * 1024,
HandshakeTimeout: time.Second * 8,
}
var uri url.URL
if tlsConfig == nil {
uri.Scheme = "ws"
} else {
uri.Scheme = "wss"
}
uri.Host = serverAddr.String()
uri.Path = options.Path
if !strings.HasPrefix(uri.Path, "/") {
uri.Path = "/" + uri.Path
}
headers := make(http.Header)
for key, value := range options.Headers {
headers.Set(key, value)
}
return &Client{
wsDialer,
uri.String(),
headers,
options.MaxEarlyData,
options.EarlyDataHeaderName,
}
}
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
if c.maxEarlyData <= 0 {
conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
if err == nil {
return &WebsocketConn{Conn: conn}, nil
}
return nil, wrapDialError(response, err)
} else {
return &EarlyWebsocketConn{Client: c, create: make(chan struct{})}, nil
}
}
func wrapDialError(response *http.Response, err error) error {
if response == nil {
return err
}
return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
}

View file

@ -0,0 +1,153 @@
package v2raywebsocket
import (
"encoding/base64"
"io"
"net"
"net/http"
"os"
"time"
E "github.com/sagernet/sing/common/exceptions"
"github.com/gorilla/websocket"
)
type WebsocketConn struct {
*websocket.Conn
remoteAddr net.Addr
reader io.Reader
}
func (c *WebsocketConn) Read(b []byte) (n int, err error) {
for {
if c.reader == nil {
_, c.reader, err = c.NextReader()
if err != nil {
return
}
}
n, err = c.reader.Read(b)
if E.IsMulti(err, io.EOF) {
c.reader = nil
continue
}
return
}
}
func (c *WebsocketConn) Write(b []byte) (n int, err error) {
err = c.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
return
}
return len(b), nil
}
func (c *WebsocketConn) RemoteAddr() net.Addr {
if c.remoteAddr != nil {
return c.remoteAddr
}
return c.Conn.RemoteAddr()
}
func (c *WebsocketConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
type EarlyWebsocketConn struct {
*Client
conn *WebsocketConn
create chan struct{}
}
func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
if c.conn == nil {
<-c.create
}
return c.conn.Read(b)
}
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
if c.conn != nil {
return c.conn.Write(b)
}
var (
earlyData []byte
lateData []byte
conn *websocket.Conn
response *http.Response
)
if len(earlyData) > int(c.maxEarlyData) {
earlyData = earlyData[:c.maxEarlyData]
lateData = lateData[c.maxEarlyData:]
} else {
earlyData = b
}
if len(earlyData) > 0 {
earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
if c.earlyDataHeaderName == "" {
conn, response, err = c.dialer.Dial(c.uri+earlyDataString, c.headers)
} else {
headers := c.headers.Clone()
headers.Set(c.earlyDataHeaderName, earlyDataString)
conn, response, err = c.dialer.Dial(c.uri, headers)
}
} else {
conn, response, err = c.dialer.Dial(c.uri, c.headers)
}
if err != nil {
return 0, wrapDialError(response, err)
}
c.conn = &WebsocketConn{Conn: conn}
close(c.create)
if len(lateData) > 0 {
_, err = c.conn.Write(lateData)
}
if err != nil {
return
}
return len(b), nil
}
func (c *EarlyWebsocketConn) Close() error {
if c.conn == nil {
return nil
}
return c.conn.Close()
}
func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
if c.conn == nil {
return nil
}
return c.conn.LocalAddr()
}
func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
if c.conn == nil {
return nil
}
return c.conn.RemoteAddr()
}
func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
if c.conn == nil {
return os.ErrInvalid
}
return c.conn.SetDeadline(t)
}
func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
if c.conn == nil {
return os.ErrInvalid
}
return c.conn.SetReadDeadline(t)
}
func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
if c.conn == nil {
return os.ErrInvalid
}
return c.conn.SetWriteDeadline(t)
}

View file

@ -0,0 +1,138 @@
package v2raywebsocket
import (
"context"
"crypto/tls"
"encoding/base64"
"net"
"net/http"
"net/netip"
"strings"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/gorilla/websocket"
)
var _ adapter.V2RayServerTransport = (*Server)(nil)
type Server struct {
ctx context.Context
handler N.TCPConnectionHandler
errorHandler E.Handler
httpServer *http.Server
path string
maxEarlyData uint32
earlyDataHeaderName string
}
func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) *Server {
server := &Server{
ctx: ctx,
handler: handler,
errorHandler: errorHandler,
path: options.Path,
maxEarlyData: options.MaxEarlyData,
earlyDataHeaderName: options.EarlyDataHeaderName,
}
if !strings.HasPrefix(server.path, "/") {
server.path = "/" + server.path
}
server.httpServer = &http.Server{
Handler: server,
ReadHeaderTimeout: C.TCPTimeout,
MaxHeaderBytes: http.DefaultMaxHeaderBytes,
TLSConfig: tlsConfig,
}
return server
}
var upgrader = websocket.Upgrader{
HandshakeTimeout: C.TCPTimeout,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" {
if request.URL.Path != s.path {
writer.WriteHeader(http.StatusNotFound)
s.badRequest(request, E.New("bad path: ", request.URL.Path))
return
}
}
var (
earlyData []byte
err error
conn net.Conn
)
if s.earlyDataHeaderName == "" {
if strings.HasPrefix(request.URL.RequestURI(), s.path) {
earlyDataStr := request.URL.RequestURI()[len(s.path):]
earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr)
} else {
writer.WriteHeader(http.StatusNotFound)
s.badRequest(request, E.New("bad path: ", request.URL.Path))
return
}
} else {
earlyDataStr := request.Header.Get(s.earlyDataHeaderName)
if earlyDataStr != "" {
earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr)
}
}
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
s.badRequest(request, E.Cause(err, "decode early data"))
return
}
wsConn, err := upgrader.Upgrade(writer, request, nil)
if err != nil {
s.badRequest(request, E.Cause(err, "upgrade websocket connection"))
return
}
var remoteAddr net.Addr
forwardFrom := request.Header.Get("X-Forwarded-For")
if forwardFrom != "" {
for _, from := range strings.Split(forwardFrom, ",") {
originAddr, err := netip.ParseAddr(from)
if err == nil {
remoteAddr = M.SocksaddrFrom(originAddr, 0).TCPAddr()
break
}
}
}
conn = &WebsocketConn{
Conn: wsConn,
remoteAddr: remoteAddr,
}
if len(earlyData) > 0 {
conn = bufio.NewCachedConn(conn, buf.As(earlyData))
}
s.handler.NewConnection(request.Context(), conn, M.Metadata{})
}
func (s *Server) badRequest(request *http.Request, err error) {
s.errorHandler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr))
}
func (s *Server) Serve(listener net.Listener) error {
if s.httpServer.TLSConfig == nil {
return s.httpServer.Serve(listener)
} else {
return s.httpServer.ServeTLS(listener, "", "")
}
}
func (s *Server) Close() error {
return common.Close(s.httpServer)
}