From 77c98fd0427fec28ba2c89f5c7da086a75fa9e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 22 Aug 2022 20:20:56 +0800 Subject: [PATCH] Add v2ray WebSocket transport --- constant/v2ray.go | 3 +- inbound/vmess.go | 2 +- option/v2ray_transport.go | 41 +++++--- option/vmess.go | 24 ++--- test/vmess_transport_test.go | 48 ++++++--- transport/v2ray/grpc.go | 9 +- transport/v2ray/grpc_stub.go | 5 +- transport/v2ray/transport.go | 13 ++- transport/v2raygrpc/client.go | 5 +- transport/v2raygrpc/server.go | 5 +- transport/v2raywebsocket/client.go | 82 ++++++++++++++++ transport/v2raywebsocket/conn.go | 153 +++++++++++++++++++++++++++++ transport/v2raywebsocket/server.go | 138 ++++++++++++++++++++++++++ 13 files changed, 475 insertions(+), 53 deletions(-) create mode 100644 transport/v2raywebsocket/client.go create mode 100644 transport/v2raywebsocket/conn.go create mode 100644 transport/v2raywebsocket/server.go diff --git a/constant/v2ray.go b/constant/v2ray.go index 1c0bc395..cfffbb95 100644 --- a/constant/v2ray.go +++ b/constant/v2ray.go @@ -1,5 +1,6 @@ package constant const ( - V2RayTransportTypeGRPC = "grpc" + V2RayTransportTypeGRPC = "grpc" + V2RayTransportTypeWebsocket = "ws" ) diff --git a/inbound/vmess.go b/inbound/vmess.go index 3bf31602..0ee493e2 100644 --- a/inbound/vmess.go +++ b/inbound/vmess.go @@ -63,7 +63,7 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg } } if options.Transport != nil { - inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig.Config(), adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newTransportConnection, nil, nil)) + inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig.Config(), adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newTransportConnection, nil, nil), inbound) if err != nil { return nil, err } diff --git a/option/v2ray_transport.go b/option/v2ray_transport.go index 28135d37..17a98a58 100644 --- a/option/v2ray_transport.go +++ b/option/v2ray_transport.go @@ -6,28 +6,31 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -type _V2RayInboundTransportOptions struct { - Type string `json:"type,omitempty"` - GRPCOptions V2RayGRPCOptions `json:"-"` +type _V2RayTransportOptions struct { + Type string `json:"type,omitempty"` + GRPCOptions V2RayGRPCOptions `json:"-"` + WebsocketOptions V2RayWebsocketOptions `json:"-"` } -type V2RayInboundTransportOptions _V2RayOutboundTransportOptions +type V2RayTransportOptions _V2RayTransportOptions -func (o V2RayInboundTransportOptions) MarshalJSON() ([]byte, error) { +func (o V2RayTransportOptions) MarshalJSON() ([]byte, error) { var v any switch o.Type { case "": return nil, nil case C.V2RayTransportTypeGRPC: v = o.GRPCOptions + case C.V2RayTransportTypeWebsocket: + v = o.WebsocketOptions default: return nil, E.New("unknown transport type: " + o.Type) } - return MarshallObjects((_V2RayOutboundTransportOptions)(o), v) + return MarshallObjects((_V2RayTransportOptions)(o), v) } -func (o *V2RayInboundTransportOptions) UnmarshalJSON(bytes []byte) error { - err := json.Unmarshal(bytes, (*_V2RayOutboundTransportOptions)(o)) +func (o *V2RayTransportOptions) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_V2RayTransportOptions)(o)) if err != nil { return err } @@ -38,16 +41,17 @@ func (o *V2RayInboundTransportOptions) UnmarshalJSON(bytes []byte) error { default: return E.New("unknown transport type: " + o.Type) } - err = UnmarshallExcluded(bytes, (*_V2RayOutboundTransportOptions)(o), v) + err = UnmarshallExcluded(bytes, (*_V2RayTransportOptions)(o), v) if err != nil { return E.Cause(err, "vmess transport options") } return nil } -type _V2RayOutboundTransportOptions struct { - Type string `json:"type,omitempty"` - GRPCOptions V2RayGRPCOptions `json:"-"` +/*type _V2RayOutboundTransportOptions struct { + Type string `json:"type,omitempty"` + GRPCOptions V2RayGRPCOptions `json:"-"` + WebsocketOptions V2RayWebsocketOptions `json:"-"` } type V2RayOutboundTransportOptions _V2RayOutboundTransportOptions @@ -59,6 +63,8 @@ func (o V2RayOutboundTransportOptions) MarshalJSON() ([]byte, error) { return nil, nil case C.V2RayTransportTypeGRPC: v = o.GRPCOptions + case C.V2RayTransportTypeWebsocket: + v = o.WebsocketOptions default: return nil, E.New("unknown transport type: " + o.Type) } @@ -74,6 +80,8 @@ func (o *V2RayOutboundTransportOptions) UnmarshalJSON(bytes []byte) error { switch o.Type { case C.V2RayTransportTypeGRPC: v = &o.GRPCOptions + case C.V2RayTransportTypeWebsocket: + v = &o.WebsocketOptions default: return E.New("unknown transport type: " + o.Type) } @@ -82,8 +90,15 @@ func (o *V2RayOutboundTransportOptions) UnmarshalJSON(bytes []byte) error { return E.Cause(err, "vmess transport options") } return nil -} +}*/ type V2RayGRPCOptions struct { ServiceName string `json:"service_name,omitempty"` } + +type V2RayWebsocketOptions struct { + Path string `json:"path,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + MaxEarlyData uint32 `json:"max_early_data,omitempty"` + EarlyDataHeaderName string `json:"early_data_header_name,omitempty"` +} diff --git a/option/vmess.go b/option/vmess.go index 3dcd4d84..a2430506 100644 --- a/option/vmess.go +++ b/option/vmess.go @@ -2,9 +2,9 @@ package option type VMessInboundOptions struct { ListenOptions - Users []VMessUser `json:"users,omitempty"` - TLS *InboundTLSOptions `json:"tls,omitempty"` - Transport *V2RayInboundTransportOptions `json:"transport,omitempty"` + Users []VMessUser `json:"users,omitempty"` + TLS *InboundTLSOptions `json:"tls,omitempty"` + Transport *V2RayTransportOptions `json:"transport,omitempty"` } type VMessUser struct { @@ -16,13 +16,13 @@ type VMessUser struct { type VMessOutboundOptions struct { OutboundDialerOptions ServerOptions - UUID string `json:"uuid"` - Security string `json:"security"` - AlterId int `json:"alter_id,omitempty"` - GlobalPadding bool `json:"global_padding,omitempty"` - AuthenticatedLength bool `json:"authenticated_length,omitempty"` - Network NetworkList `json:"network,omitempty"` - TLS *OutboundTLSOptions `json:"tls,omitempty"` - Multiplex *MultiplexOptions `json:"multiplex,omitempty"` - Transport *V2RayOutboundTransportOptions `json:"transport,omitempty"` + UUID string `json:"uuid"` + Security string `json:"security"` + AlterId int `json:"alter_id,omitempty"` + GlobalPadding bool `json:"global_padding,omitempty"` + AuthenticatedLength bool `json:"authenticated_length,omitempty"` + Network NetworkList `json:"network,omitempty"` + TLS *OutboundTLSOptions `json:"tls,omitempty"` + Multiplex *MultiplexOptions `json:"multiplex,omitempty"` + Transport *V2RayTransportOptions `json:"transport,omitempty"` } diff --git a/test/vmess_transport_test.go b/test/vmess_transport_test.go index 55680c89..0f3681dd 100644 --- a/test/vmess_transport_test.go +++ b/test/vmess_transport_test.go @@ -12,6 +12,40 @@ import ( ) func TestVMessGRPCSelf(t *testing.T) { + testVMessWebscoketSelf(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + }, + }) +} + +func TestVMessWebscoketSelf(t *testing.T) { + t.Run("basic", func(t *testing.T) { + testVMessWebscoketSelf(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeWebsocket, + }) + }) + t.Run("v2ray early data", func(t *testing.T) { + testVMessWebscoketSelf(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeWebsocket, + WebsocketOptions: option.V2RayWebsocketOptions{ + MaxEarlyData: 2048, + }, + }) + }) + t.Run("xray early data", func(t *testing.T) { + testVMessWebscoketSelf(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeWebsocket, + WebsocketOptions: option.V2RayWebsocketOptions{ + MaxEarlyData: 2048, + EarlyDataHeaderName: "Sec-WebSocket-Protocol", + }, + }) + }) +} + +func testVMessWebscoketSelf(t *testing.T, transport *option.V2RayTransportOptions) { user, err := uuid.DefaultGenerator.NewV4() require.NoError(t, err) _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") @@ -50,12 +84,7 @@ func TestVMessGRPCSelf(t *testing.T) { CertificatePath: certPem, KeyPath: keyPem, }, - Transport: &option.V2RayInboundTransportOptions{ - Type: C.V2RayTransportTypeGRPC, - GRPCOptions: option.V2RayGRPCOptions{ - ServiceName: "TunService", - }, - }, + Transport: transport, }, }, }, @@ -78,12 +107,7 @@ func TestVMessGRPCSelf(t *testing.T) { ServerName: "example.org", CertificatePath: certPem, }, - Transport: &option.V2RayOutboundTransportOptions{ - Type: C.V2RayTransportTypeGRPC, - GRPCOptions: option.V2RayGRPCOptions{ - ServiceName: "TunService", - }, - }, + Transport: transport, }, }, }, diff --git a/transport/v2ray/grpc.go b/transport/v2ray/grpc.go index 11661f63..41e70bbe 100644 --- a/transport/v2ray/grpc.go +++ b/transport/v2ray/grpc.go @@ -7,15 +7,16 @@ import ( "crypto/tls" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2raygrpc" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func NewGRPCServer(ctx context.Context, serviceName string, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { - return v2raygrpc.NewServer(ctx, serviceName, tlsConfig, handler), nil +func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { + return v2raygrpc.NewServer(ctx, options, tlsConfig, handler), nil } -func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, serviceName string, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { - return v2raygrpc.NewClient(ctx, dialer, serverAddr, serviceName, tlsConfig), nil +func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { + return v2raygrpc.NewClient(ctx, dialer, serverAddr, options, tlsConfig), nil } diff --git a/transport/v2ray/grpc_stub.go b/transport/v2ray/grpc_stub.go index 29b5c740..971492f5 100644 --- a/transport/v2ray/grpc_stub.go +++ b/transport/v2ray/grpc_stub.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -14,10 +15,10 @@ import ( var errGRPCNotIncluded = E.New("gRPC is not included in this build, rebuild with -tags with_grpc") -func NewGRPCServer(ctx context.Context, serviceName string, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { +func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { return nil, errGRPCNotIncluded } -func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, serviceName string, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { +func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { return nil, errGRPCNotIncluded } diff --git a/transport/v2ray/transport.go b/transport/v2ray/transport.go index 1fb5c192..fcdf9071 100644 --- a/transport/v2ray/transport.go +++ b/transport/v2ray/transport.go @@ -7,30 +7,35 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/v2raywebsocket" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func NewServerTransport(ctx context.Context, options option.V2RayInboundTransportOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { +func NewServerTransport(ctx context.Context, options option.V2RayTransportOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { if options.Type == "" { return nil, nil } switch options.Type { case C.V2RayTransportTypeGRPC: - return NewGRPCServer(ctx, options.GRPCOptions.ServiceName, tlsConfig, handler) + return NewGRPCServer(ctx, options.GRPCOptions, tlsConfig, handler) + case C.V2RayTransportTypeWebsocket: + return v2raywebsocket.NewServer(ctx, options.WebsocketOptions, tlsConfig, handler, errorHandler), nil default: return nil, E.New("unknown transport type: " + options.Type) } } -func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayOutboundTransportOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { +func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayTransportOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { if options.Type == "" { return nil, nil } switch options.Type { case C.V2RayTransportTypeGRPC: - return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions.ServiceName, tlsConfig) + return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions, tlsConfig) + case C.V2RayTransportTypeWebsocket: + return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig), nil default: return nil, E.New("unknown transport type: " + options.Type) } diff --git a/transport/v2raygrpc/client.go b/transport/v2raygrpc/client.go index 74269de5..f10c607b 100644 --- a/transport/v2raygrpc/client.go +++ b/transport/v2raygrpc/client.go @@ -8,6 +8,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -31,7 +32,7 @@ type Client struct { connAccess sync.Mutex } -func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, serviceName string, tlsConfig *tls.Config) adapter.V2RayClientTransport { +func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) adapter.V2RayClientTransport { var dialOptions []grpc.DialOption if tlsConfig != nil { dialOptions = append(dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) @@ -55,7 +56,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, ser ctx: ctx, dialer: dialer, serverAddr: serverAddr.String(), - serviceName: serviceName, + serviceName: options.ServiceName, dialOptions: dialOptions, } } diff --git a/transport/v2raygrpc/server.go b/transport/v2raygrpc/server.go index 041b5464..ddae67a0 100644 --- a/transport/v2raygrpc/server.go +++ b/transport/v2raygrpc/server.go @@ -6,6 +6,7 @@ import ( "net" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -21,14 +22,14 @@ type Server struct { server *grpc.Server } -func NewServer(ctx context.Context, serviceName string, tlsConfig *tls.Config, handler N.TCPConnectionHandler) *Server { +func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) *Server { var serverOptions []grpc.ServerOption if tlsConfig != nil { tlsConfig.NextProtos = []string{"h2"} serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig))) } server := &Server{ctx, handler, grpc.NewServer(serverOptions...)} - RegisterGunServiceCustomNameServer(server.server, server, serviceName) + RegisterGunServiceCustomNameServer(server.server, server, options.ServiceName) return server } diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go new file mode 100644 index 00000000..bc24481a --- /dev/null +++ b/transport/v2raywebsocket/client.go @@ -0,0 +1,82 @@ +package v2raywebsocket + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/gorilla/websocket" +) + +var _ adapter.V2RayClientTransport = (*Client)(nil) + +type Client struct { + dialer *websocket.Dialer + uri string + headers http.Header + maxEarlyData uint32 + earlyDataHeaderName string +} + +func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig *tls.Config) adapter.V2RayClientTransport { + wsDialer := &websocket.Dialer{ + NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSClientConfig: tlsConfig, + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: time.Second * 8, + } + var uri url.URL + if tlsConfig == nil { + uri.Scheme = "ws" + } else { + uri.Scheme = "wss" + } + uri.Host = serverAddr.String() + uri.Path = options.Path + if !strings.HasPrefix(uri.Path, "/") { + uri.Path = "/" + uri.Path + } + headers := make(http.Header) + for key, value := range options.Headers { + headers.Set(key, value) + } + return &Client{ + wsDialer, + uri.String(), + headers, + options.MaxEarlyData, + options.EarlyDataHeaderName, + } +} + +func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { + if c.maxEarlyData <= 0 { + conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers) + if err == nil { + return &WebsocketConn{Conn: conn}, nil + } + return nil, wrapDialError(response, err) + } else { + return &EarlyWebsocketConn{Client: c, create: make(chan struct{})}, nil + } +} + +func wrapDialError(response *http.Response, err error) error { + if response == nil { + return err + } + return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status) +} diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go new file mode 100644 index 00000000..213de2c5 --- /dev/null +++ b/transport/v2raywebsocket/conn.go @@ -0,0 +1,153 @@ +package v2raywebsocket + +import ( + "encoding/base64" + "io" + "net" + "net/http" + "os" + "time" + + E "github.com/sagernet/sing/common/exceptions" + + "github.com/gorilla/websocket" +) + +type WebsocketConn struct { + *websocket.Conn + remoteAddr net.Addr + reader io.Reader +} + +func (c *WebsocketConn) Read(b []byte) (n int, err error) { + for { + if c.reader == nil { + _, c.reader, err = c.NextReader() + if err != nil { + return + } + } + n, err = c.reader.Read(b) + if E.IsMulti(err, io.EOF) { + c.reader = nil + continue + } + return + } +} + +func (c *WebsocketConn) Write(b []byte) (n int, err error) { + err = c.WriteMessage(websocket.BinaryMessage, b) + if err != nil { + return + } + return len(b), nil +} + +func (c *WebsocketConn) RemoteAddr() net.Addr { + if c.remoteAddr != nil { + return c.remoteAddr + } + return c.Conn.RemoteAddr() +} + +func (c *WebsocketConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +type EarlyWebsocketConn struct { + *Client + conn *WebsocketConn + create chan struct{} +} + +func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { + if c.conn == nil { + <-c.create + } + return c.conn.Read(b) +} + +func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { + if c.conn != nil { + return c.conn.Write(b) + } + var ( + earlyData []byte + lateData []byte + conn *websocket.Conn + response *http.Response + ) + if len(earlyData) > int(c.maxEarlyData) { + earlyData = earlyData[:c.maxEarlyData] + lateData = lateData[c.maxEarlyData:] + } else { + earlyData = b + } + if len(earlyData) > 0 { + earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) + if c.earlyDataHeaderName == "" { + conn, response, err = c.dialer.Dial(c.uri+earlyDataString, c.headers) + } else { + headers := c.headers.Clone() + headers.Set(c.earlyDataHeaderName, earlyDataString) + conn, response, err = c.dialer.Dial(c.uri, headers) + } + } else { + conn, response, err = c.dialer.Dial(c.uri, c.headers) + } + if err != nil { + return 0, wrapDialError(response, err) + } + c.conn = &WebsocketConn{Conn: conn} + close(c.create) + if len(lateData) > 0 { + _, err = c.conn.Write(lateData) + } + if err != nil { + return + } + return len(b), nil +} + +func (c *EarlyWebsocketConn) Close() error { + if c.conn == nil { + return nil + } + return c.conn.Close() +} + +func (c *EarlyWebsocketConn) LocalAddr() net.Addr { + if c.conn == nil { + return nil + } + return c.conn.LocalAddr() +} + +func (c *EarlyWebsocketConn) RemoteAddr() net.Addr { + if c.conn == nil { + return nil + } + return c.conn.RemoteAddr() +} + +func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error { + if c.conn == nil { + return os.ErrInvalid + } + return c.conn.SetDeadline(t) +} + +func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error { + if c.conn == nil { + return os.ErrInvalid + } + return c.conn.SetReadDeadline(t) +} + +func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error { + if c.conn == nil { + return os.ErrInvalid + } + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/v2raywebsocket/server.go b/transport/v2raywebsocket/server.go new file mode 100644 index 00000000..d8239a3c --- /dev/null +++ b/transport/v2raywebsocket/server.go @@ -0,0 +1,138 @@ +package v2raywebsocket + +import ( + "context" + "crypto/tls" + "encoding/base64" + "net" + "net/http" + "net/netip" + "strings" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/gorilla/websocket" +) + +var _ adapter.V2RayServerTransport = (*Server)(nil) + +type Server struct { + ctx context.Context + handler N.TCPConnectionHandler + errorHandler E.Handler + httpServer *http.Server + path string + maxEarlyData uint32 + earlyDataHeaderName string +} + +func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) *Server { + server := &Server{ + ctx: ctx, + handler: handler, + errorHandler: errorHandler, + path: options.Path, + maxEarlyData: options.MaxEarlyData, + earlyDataHeaderName: options.EarlyDataHeaderName, + } + if !strings.HasPrefix(server.path, "/") { + server.path = "/" + server.path + } + server.httpServer = &http.Server{ + Handler: server, + ReadHeaderTimeout: C.TCPTimeout, + MaxHeaderBytes: http.DefaultMaxHeaderBytes, + TLSConfig: tlsConfig, + } + return server +} + +var upgrader = websocket.Upgrader{ + HandshakeTimeout: C.TCPTimeout, + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" { + if request.URL.Path != s.path { + writer.WriteHeader(http.StatusNotFound) + s.badRequest(request, E.New("bad path: ", request.URL.Path)) + return + } + } + var ( + earlyData []byte + err error + conn net.Conn + ) + if s.earlyDataHeaderName == "" { + if strings.HasPrefix(request.URL.RequestURI(), s.path) { + earlyDataStr := request.URL.RequestURI()[len(s.path):] + earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) + } else { + writer.WriteHeader(http.StatusNotFound) + s.badRequest(request, E.New("bad path: ", request.URL.Path)) + return + } + } else { + earlyDataStr := request.Header.Get(s.earlyDataHeaderName) + if earlyDataStr != "" { + earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) + } + } + if err != nil { + writer.WriteHeader(http.StatusBadRequest) + s.badRequest(request, E.Cause(err, "decode early data")) + return + } + wsConn, err := upgrader.Upgrade(writer, request, nil) + if err != nil { + s.badRequest(request, E.Cause(err, "upgrade websocket connection")) + return + } + var remoteAddr net.Addr + forwardFrom := request.Header.Get("X-Forwarded-For") + if forwardFrom != "" { + for _, from := range strings.Split(forwardFrom, ",") { + originAddr, err := netip.ParseAddr(from) + if err == nil { + remoteAddr = M.SocksaddrFrom(originAddr, 0).TCPAddr() + break + } + } + } + conn = &WebsocketConn{ + Conn: wsConn, + remoteAddr: remoteAddr, + } + if len(earlyData) > 0 { + conn = bufio.NewCachedConn(conn, buf.As(earlyData)) + } + s.handler.NewConnection(request.Context(), conn, M.Metadata{}) +} + +func (s *Server) badRequest(request *http.Request, err error) { + s.errorHandler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) +} + +func (s *Server) Serve(listener net.Listener) error { + if s.httpServer.TLSConfig == nil { + return s.httpServer.Serve(listener) + } else { + return s.httpServer.ServeTLS(listener, "", "") + } +} + +func (s *Server) Close() error { + return common.Close(s.httpServer) +}