package httpupgrade_test

import (
	"context"
	"runtime"
	"testing"
	"time"

	"github.com/xtls/xray-core/common"
	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/common/protocol/tls/cert"
	"github.com/xtls/xray-core/testing/servers/tcp"
	"github.com/xtls/xray-core/transport/internet"
	. "github.com/xtls/xray-core/transport/internet/httpupgrade"
	"github.com/xtls/xray-core/transport/internet/stat"
	"github.com/xtls/xray-core/transport/internet/tls"
)

// serveFixedResponse is an accept callback for ListenHTTPUpgrade. For each
// accepted connection it starts a goroutine that waits up to two seconds for
// one read and, if that read succeeds, writes the literal "Response" before
// closing the connection.
func serveFixedResponse(conn stat.Connection) {
	go func() {
		defer conn.Close()

		var buf [1024]byte
		// Best effort: the deadline only bounds how long this goroutine can
		// linger when the peer never sends anything.
		_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
		if _, err := conn.Read(buf[:]); err != nil {
			return
		}
		common.Must2(conn.Write([]byte("Response")))
	}()
}

// exchangeAndExpect dials localhost:port with the given stream settings, sends
// payload, reads a single reply and reports a test error when the reply is not
// want. The connection is closed before returning. Dial, write, read and close
// failures go through common.Must, matching the rest of this file.
func exchangeAndExpect(t *testing.T, port net.Port, settings *internet.MemoryStreamConfig, payload, want string) {
	t.Helper()

	conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), port), settings)
	common.Must(err)

	_, err = conn.Write([]byte(payload))
	common.Must(err)

	var buf [1024]byte
	n, err := conn.Read(buf[:])
	common.Must(err)
	if got := string(buf[:n]); got != want {
		t.Error("response: ", got)
	}

	common.Must(conn.Close())
}

// Test_listenHTTPUpgradeAndDial starts a listener on a freshly picked port and
// performs two consecutive dial/write/read round trips against it, each on its
// own connection, expecting "Response" both times.
func Test_listenHTTPUpgradeAndDial(t *testing.T) {
	listenPort := tcp.PickPort()
	listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
		ProtocolName: "httpupgrade",
		ProtocolSettings: &Config{
			Path: "httpupgrade",
		},
	}, serveFixedResponse)
	common.Must(err)

	streamSettings := &internet.MemoryStreamConfig{
		ProtocolName:     "httpupgrade",
		ProtocolSettings: &Config{Path: "httpupgrade"},
	}
	exchangeAndExpect(t, listenPort, streamSettings, "Test connection 1", "Response")
	exchangeAndExpect(t, listenPort, streamSettings, "Test connection 2", "Response")

	common.Must(listen.Close())
}

// Test_listenHTTPUpgradeAndDialWithHeaders repeats the two round trips of
// Test_listenHTTPUpgradeAndDial against a listener whose config carries a
// "User-Agent" header entry; the dialing side is configured without headers.
func Test_listenHTTPUpgradeAndDialWithHeaders(t *testing.T) {
	listenPort := tcp.PickPort()
	listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
		ProtocolName: "httpupgrade",
		ProtocolSettings: &Config{
			Path: "httpupgrade",
			Header: map[string]string{
				"User-Agent": "Mozilla",
			},
		},
	}, serveFixedResponse)
	common.Must(err)

	streamSettings := &internet.MemoryStreamConfig{
		ProtocolName:     "httpupgrade",
		ProtocolSettings: &Config{Path: "httpupgrade"},
	}
	exchangeAndExpect(t, listenPort, streamSettings, "Test connection 1", "Response")
	exchangeAndExpect(t, listenPort, streamSettings, "Test connection 2", "Response")

	common.Must(listen.Close())
}

// TestDialWithRemoteAddr dials with an "X-Forwarded-For: 1.1.1.1" header and
// has the server write back the string form of its connection's RemoteAddr.
// The client expects to read "1.1.1.1:0".
func TestDialWithRemoteAddr(t *testing.T) {
	listenPort := tcp.PickPort()
	listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
		ProtocolName: "httpupgrade",
		ProtocolSettings: &Config{
			Path: "httpupgrade",
		},
	}, func(conn stat.Connection) {
		go func(c stat.Connection) {
			defer c.Close()

			var b [1024]byte
			// A failed read ends the handler quietly instead of panicking in
			// this goroutine.
			if _, err := c.Read(b[:]); err != nil {
				return
			}

			_, err := c.Write([]byte(c.RemoteAddr().String()))
			common.Must(err)
		}(conn)
	})
	common.Must(err)

	// exchangeAndExpect also closes the client connection, which this test
	// previously left open.
	exchangeAndExpect(t, listenPort, &internet.MemoryStreamConfig{
		ProtocolName:     "httpupgrade",
		ProtocolSettings: &Config{Path: "httpupgrade", Header: map[string]string{"X-Forwarded-For": "1.1.1.1"}},
	}, "Test connection 1", "1.1.1.1:0")

	common.Must(listen.Close())
}

// Test_listenHTTPUpgradeAndDial_TLS listens and dials with TLS security
// settings built from a freshly generated "localhost" certificate, and fails
// if certificate generation, listen, dial and close together take five seconds
// or more. The test is skipped on arm64.
func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) {
	if runtime.GOARCH == "arm64" {
		// Report the test as skipped rather than letting it pass without
		// having run.
		t.Skip("not run on arm64")
	}

	listenPort := tcp.PickPort()

	start := time.Now()

	streamSettings := &internet.MemoryStreamConfig{
		ProtocolName: "httpupgrade",
		ProtocolSettings: &Config{
			Path: "httpupgrades",
		},
		SecurityType: "tls",
		SecuritySettings: &tls.Config{
			AllowInsecure: true,
			Certificate:   []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
		},
	}
	listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
		go func() {
			_ = conn.Close()
		}()
	})
	common.Must(err)
	defer listen.Close()

	conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
	common.Must(err)
	_ = conn.Close()

	end := time.Now()
	if !end.Before(start.Add(time.Second * 5)) {
		t.Error("end: ", end, " start: ", start)
	}
}