Add WebSocket 0-RTT support (#375)

This commit is contained in:
RPRX 2021-03-14 07:10:10 +00:00 committed by GitHub
parent 9adce5a6c4
commit 60b06877bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 122 additions and 14 deletions

View file

@ -3,6 +3,8 @@ package conf
import ( import (
"encoding/json" "encoding/json"
"math" "math"
"net/url"
"strconv"
"strings" "strings"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -155,9 +157,20 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
Value: value, Value: value,
}) })
} }
var ed uint32
if u, err := url.Parse(path); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
path = u.String()
}
}
config := &websocket.Config{ config := &websocket.Config{
Path: path, Path: path,
Header: header, Header: header,
Ed: ed,
} }
if c.AcceptProxyProtocol { if c.AcceptProxyProtocol {
config.AcceptProxyProtocol = c.AcceptProxyProtocol config.AcceptProxyProtocol = c.AcceptProxyProtocol

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.25.0
// protoc v3.14.0 // protoc v3.15.6
// source: transport/internet/websocket/config.proto // source: transport/internet/websocket/config.proto
package websocket package websocket
@ -89,6 +89,7 @@ type Config struct {
Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"` Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"`
Header []*Header `protobuf:"bytes,3,rep,name=header,proto3" json:"header,omitempty"` Header []*Header `protobuf:"bytes,3,rep,name=header,proto3" json:"header,omitempty"`
AcceptProxyProtocol bool `protobuf:"varint,4,opt,name=accept_proxy_protocol,json=acceptProxyProtocol,proto3" json:"accept_proxy_protocol,omitempty"` AcceptProxyProtocol bool `protobuf:"varint,4,opt,name=accept_proxy_protocol,json=acceptProxyProtocol,proto3" json:"accept_proxy_protocol,omitempty"`
Ed uint32 `protobuf:"varint,5,opt,name=ed,proto3" json:"ed,omitempty"`
} }
func (x *Config) Reset() { func (x *Config) Reset() {
@ -144,6 +145,13 @@ func (x *Config) GetAcceptProxyProtocol() bool {
return false return false
} }
func (x *Config) GetEd() uint32 {
if x != nil {
return x.Ed
}
return 0
}
var File_transport_internet_websocket_config_proto protoreflect.FileDescriptor var File_transport_internet_websocket_config_proto protoreflect.FileDescriptor
var file_transport_internet_websocket_config_proto_rawDesc = []byte{ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
@ -155,7 +163,7 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
0x0a, 0x06, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61,
0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
0x22, 0x99, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x22, 0xa9, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70,
0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12,
0x41, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x41, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
0x29, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x29, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74,
@ -164,7 +172,8 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
0x65, 0x72, 0x12, 0x32, 0x0a, 0x15, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x65, 0x72, 0x12, 0x32, 0x0a, 0x15, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x5f, 0x70, 0x72, 0x6f,
0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28,
0x08, 0x52, 0x13, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72, 0x08, 0x52, 0x13, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72,
0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x85, 0x01, 0x0a, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0e, 0x0a, 0x02, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01,
0x28, 0x0d, 0x52, 0x02, 0x65, 0x64, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x85, 0x01, 0x0a,
0x25, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x25, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70,
0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x77, 0x65, 0x62, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x77, 0x65, 0x62,
0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,

View file

@ -20,4 +20,6 @@ message Config {
repeated Header header = 3; repeated Header header = 3;
bool accept_proxy_protocol = 4; bool accept_proxy_protocol = 4;
uint32 ed = 5;
} }

View file

@ -22,10 +22,11 @@ type connection struct {
remoteAddr net.Addr remoteAddr net.Addr
} }
func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection { func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
return &connection{ return &connection{
conn: conn, conn: conn,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
reader: extraReader,
} }
} }

View file

@ -2,9 +2,12 @@ package websocket
import ( import (
"context" "context"
"encoding/base64"
"io"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
@ -15,11 +18,22 @@ import (
// Dial dials a WebSocket connection to the given destination. // Dial dials a WebSocket connection to the given destination.
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
var conn net.Conn
conn, err := dialWebsocket(ctx, dest, streamSettings) if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
if err != nil { ctx, cancel := context.WithCancel(ctx)
conn = &delayDialConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
dest: dest,
streamSettings: streamSettings,
}
} else {
var err error
if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
return nil, newError("failed to dial WebSocket").Base(err) return nil, newError("failed to dial WebSocket").Base(err)
} }
}
return internet.Connection(conn), nil return internet.Connection(conn), nil
} }
@ -27,7 +41,7 @@ func init() {
common.Must(internet.RegisterTransportDialer(protocolName, Dial)) common.Must(internet.RegisterTransportDialer(protocolName, Dial))
} }
func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
wsSettings := streamSettings.ProtocolSettings.(*Config) wsSettings := streamSettings.ProtocolSettings.(*Config)
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
@ -52,7 +66,12 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
} }
uri := protocol + "://" + host + wsSettings.GetNormalizedPath() uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) header := wsSettings.GetRequestHeader()
if ed != nil {
header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
}
conn, resp, err := dialer.Dial(uri, header)
if err != nil { if err != nil {
var reason string var reason string
if resp != nil { if resp != nil {
@ -61,5 +80,60 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
return nil, newError("failed to dial to (", uri, "): ", reason).Base(err) return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
} }
return newConnection(conn, conn.RemoteAddr()), nil return newConnection(conn, conn.RemoteAddr(), nil), nil
}
type delayDialConn struct {
net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
dest net.Destination
streamSettings *internet.MemoryStreamConfig
}
func (d *delayDialConn) Write(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
ed := b
if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
ed = nil
}
var err error
if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
d.Close()
return 0, newError("failed to dial WebSocket").Base(err)
}
d.dialed <- true
if ed != nil {
return len(ed), nil
}
}
return d.Conn.Write(b)
}
func (d *delayDialConn) Read(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
select {
case <-d.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-d.dialed:
}
}
return d.Conn.Read(b)
}
func (d *delayDialConn) Close() error {
d.closed = true
d.cancel()
if d.Conn == nil {
return nil
}
return d.Conn.Close()
} }

View file

@ -1,8 +1,11 @@
package websocket package websocket
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/base64"
"io"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -51,7 +54,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
} }
} }
h.ln.addConn(newConnection(conn, remoteAddr)) var extraReader io.Reader
if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" {
if ed, err := base64.StdEncoding.DecodeString(str); err == nil && len(ed) > 0 {
extraReader = bytes.NewReader(ed)
}
}
h.ln.addConn(newConnection(conn, remoteAddr, extraReader))
} }
type Listener struct { type Listener struct {

View file

@ -1,6 +1,6 @@
/*Package websocket implements Websocket transport /*Package websocket implements WebSocket transport
Websocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability. WebSocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
*/ */
package websocket package websocket