Enable splice for freedom outbound (downlink only)

- Add outbound name
- Add outbound conn in ctx
- Refactor splice: it can be turn on from all inbounds and outbounds
- Refactor splice: Add splice copy to vless inbound
- Fix http error test
- Add freedom splice toggle via env var
- Populate outbound obj in context
- Use CanSpliceCopy to mark a connection
- Turn off splice by default
This commit is contained in:
yuhan6665 2023-05-03 22:21:45 -04:00
parent ae2fa30e01
commit efd32b0fb2
32 changed files with 282 additions and 168 deletions

View file

@ -218,11 +218,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
if !destination.IsValid() { if !destination.IsValid() {
panic("Dispatcher: Invalid destination.") panic("Dispatcher: Invalid destination.")
} }
ob := &session.Outbound{ ob := session.OutboundFromContext(ctx)
OriginalTarget: destination, if ob == nil {
Target: destination, ob = &session.Outbound{}
}
ctx = session.ContextWithOutbound(ctx, ob) ctx = session.ContextWithOutbound(ctx, ob)
}
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx) content := session.ContentFromContext(ctx)
if content == nil { if content == nil {
content = new(session.Content) content = new(session.Content)
@ -271,11 +273,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if !destination.IsValid() { if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.") return newError("Dispatcher: Invalid destination.")
} }
ob := &session.Outbound{ ob := session.OutboundFromContext(ctx)
OriginalTarget: destination, if ob == nil {
Target: destination, ob = &session.Outbound{}
}
ctx = session.ContextWithOutbound(ctx, ob) ctx = session.ContextWithOutbound(ctx, ob)
}
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx) content := session.ContentFromContext(ctx)
if content == nil { if content == nil {
content = new(session.Content) content = new(session.Content)

View file

@ -60,6 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) {
sid := session.NewID() sid := session.NewID()
ctx = session.ContextWithID(ctx, sid) ctx = session.ContextWithID(ctx, sid)
var outbound = &session.Outbound{}
if w.recvOrigDest { if w.recvOrigDest {
var dest net.Destination var dest net.Destination
switch getTProxyType(w.stream) { switch getTProxyType(w.stream) {
@ -74,11 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) {
dest = net.DestinationFromAddr(conn.LocalAddr()) dest = net.DestinationFromAddr(conn.LocalAddr())
} }
if dest.IsValid() { if dest.IsValid() {
ctx = session.ContextWithOutbound(ctx, &session.Outbound{ outbound.Target = dest
Target: dest,
})
} }
} }
ctx = session.ContextWithOutbound(ctx, outbound)
if w.uplinkCounter != nil || w.downlinkCounter != nil { if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &stat.CounterConnection{ conn = &stat.CounterConnection{

View file

@ -274,7 +274,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
} }
conn, err := internet.Dial(ctx, dest, h.streamSettings) conn, err := internet.Dial(ctx, dest, h.streamSettings)
return h.getStatCouterConnection(conn), err conn = h.getStatCouterConnection(conn)
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Conn = conn
}
return conn, err
} }
func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection { func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection {

View file

@ -6,6 +6,7 @@ import (
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/features/stats"
) )
type dataHandler func(MultiBuffer) type dataHandler func(MultiBuffer)
@ -40,6 +41,17 @@ func CountSize(sc *SizeCounter) CopyOption {
} }
} }
// AddToStatCounter a CopyOption add to stat counter
func AddToStatCounter(sc stats.Counter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
if sc != nil {
sc.Add(int64(b.Len()))
}
})
}
}
type readError struct { type readError struct {
error error
} }

View file

@ -50,6 +50,16 @@ type Inbound struct {
Conn net.Conn Conn net.Conn
// Timer of the inbound buf copier. May be nil. // Timer of the inbound buf copier. May be nil.
Timer *signal.ActivityTimer Timer *signal.ActivityTimer
// CanSpliceCopy is a property for this connection, set by both inbound and outbound
// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
CanSpliceCopy int
}
func(i *Inbound) SetCanSpliceCopy(canSpliceCopy int) int {
if canSpliceCopy > i.CanSpliceCopy {
i.CanSpliceCopy = canSpliceCopy
}
return i.CanSpliceCopy
} }
// Outbound is the metadata of an outbound connection. // Outbound is the metadata of an outbound connection.
@ -60,6 +70,10 @@ type Outbound struct {
RouteTarget net.Destination RouteTarget net.Destination
// Gateway address // Gateway address
Gateway net.Address Gateway net.Address
// Name of the outbound proxy that handles the connection.
Name string
// Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings
Conn net.Conn
} }
// SniffingRequest controls the behavior of content sniffing. // SniffingRequest controls the behavior of content sniffing.

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
) )
@ -30,6 +31,11 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements OutboundHandler.Dispatch(). // Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Name = "blackhole"
}
nBytes := h.response.WriteTo(link.Writer) nBytes := h.response.WriteTo(link.Writer)
if nBytes > 0 { if nBytes > 0 {
// Sleep a little here to make sure the response is sent to client. // Sleep a little here to make sure the response is sent to client.

View file

@ -96,6 +96,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("invalid outbound") return newError("invalid outbound")
} }
outbound.Name = "dns"
srcNetwork := outbound.Target.Network srcNetwork := outbound.Target.Network

View file

@ -102,12 +102,11 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
} }
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.Name = "dokodemo-door" inbound.Name = "dokodemo-door"
inbound.SetCanSpliceCopy(1)
inbound.User = &protocol.MemoryUser{ inbound.User = &protocol.MemoryUser{
Level: d.config.UserLevel, Level: d.config.UserLevel,
} }
}
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: conn.RemoteAddr(), From: conn.RemoteAddr(),

View file

@ -0,0 +1,9 @@
package proxy
import "github.com/xtls/xray-core/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

View file

@ -13,6 +13,7 @@ import (
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/retry" "github.com/xtls/xray-core/common/retry"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal"
@ -21,11 +22,14 @@ import (
"github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
) )
var useSplice bool
func init() { func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
h := new(Handler) h := new(Handler)
@ -36,6 +40,12 @@ func init() {
} }
return h, nil return h, nil
})) }))
const defaultFlagValue = "NOT_DEFINED_AT_ALL"
value := platform.NewEnvFlag("xray.buf.splice").GetValue(func() string { return defaultFlagValue })
switch value {
case "auto", "enable":
useSplice = true
}
} }
// Handler handles Freedom connections. // Handler handles Freedom connections.
@ -107,6 +117,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified.") return newError("target not specified.")
} }
outbound.Name = "freedom"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(1)
}
destination := outbound.Target destination := outbound.Target
UDPOverride := net.UDPDestination(nil, 0) UDPOverride := net.UDPDestination(nil, 0)
if h.config.DestinationOverride != nil { if h.config.DestinationOverride != nil {
@ -195,17 +210,17 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
responseDone := func() error { responseDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
var reader buf.Reader
if destination.Network == net.Network_TCP { if destination.Network == net.Network_TCP {
reader = buf.NewReader(conn) var writeConn net.Conn
} else { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && useSplice {
reader = NewPacketReader(conn, UDPOverride) writeConn = inbound.Conn
} }
return proxy.CopyRawConnIfExist(ctx, conn, writeConn, link.Writer, timer)
}
reader := NewPacketReader(conn, UDPOverride)
if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
return newError("failed to process response").Base(err) return newError("failed to process response").Base(err)
} }
return nil return nil
} }

View file

@ -73,6 +73,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified.") return newError("target not specified.")
} }
outbound.Name = "http"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(2)
}
target := outbound.Target target := outbound.Target
targetAddr := target.NetAddr() targetAddr := target.NetAddr()

View file

@ -84,12 +84,11 @@ type readerOnly struct {
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.Name = "http" inbound.Name = "http"
inbound.SetCanSpliceCopy(2)
inbound.User = &protocol.MemoryUser{ inbound.User = &protocol.MemoryUser{
Level: s.config.UserLevel, Level: s.config.UserLevel,
} }
}
reader := bufio.NewReaderSize(readerOnly{conn}, buf.Size) reader := bufio.NewReaderSize(readerOnly{conn}, buf.Size)

View file

@ -26,6 +26,7 @@ func (l *Loopback) Process(ctx context.Context, link *transport.Link, _ internet
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified.") return newError("target not specified.")
} }
outbound.Name = "loopback"
destination := outbound.Target destination := outbound.Target
newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx)) newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx))

View file

@ -7,13 +7,24 @@ package proxy
import ( import (
"context" "context"
gotls "crypto/tls"
"io"
"runtime"
"github.com/pires/go-proxyproto"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/reality"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls"
) )
// An Inbound processes inbound connections. // An Inbound processes inbound connections.
@ -47,3 +58,78 @@ type GetInbound interface {
type GetOutbound interface { type GetOutbound interface {
GetOutbound() Outbound GetOutbound() Outbound
} }
// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
var readCounter, writerCounter stats.Counter
if conn != nil {
statConn, ok := conn.(*stat.CounterConnection)
if ok {
conn = statConn.Connection
readCounter = statConn.ReadCounter
writerCounter = statConn.WriteCounter
}
if xc, ok := conn.(*gotls.Conn); ok {
conn = xc.NetConn()
} else if utlsConn, ok := conn.(*tls.UConn); ok {
conn = utlsConn.NetConn()
} else if realityConn, ok := conn.(*reality.Conn); ok {
conn = realityConn.NetConn()
} else if realityUConn, ok := conn.(*reality.UConn); ok {
conn = realityUConn.NetConn()
}
if pc, ok := conn.(*proxyproto.Conn); ok {
conn = pc.Raw()
// 8192 > 4096, there is no need to process pc's bufReader
}
}
return conn, readCounter, writerCounter
}
// CopyRawConnIfExist use the most efficient copy method.
// - If caller don't want to turn on splice, do not pass in both reader conn and writer conn
// - writer are from *transport.Link
func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net.Conn, writer buf.Writer, timer signal.ActivityUpdater) error {
readerConn, readCounter, _ := UnwrapRawConn(readerConn)
writerConn, _, writeCounter := UnwrapRawConn(writerConn)
reader := buf.NewReader(readerConn)
if inbound := session.InboundFromContext(ctx); inbound != nil {
if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
for inbound.CanSpliceCopy != 3 {
if inbound.CanSpliceCopy == 1 {
newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx))
runtime.Gosched() // necessary
w, err := tc.ReadFrom(readerConn)
if readCounter != nil {
readCounter.Add(w)
}
if writeCounter != nil {
writeCounter.Add(w)
}
if err != nil && errors.Cause(err) != io.EOF {
return err
}
return nil
}
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
if readCounter != nil {
readCounter.Add(int64(buffer.Len()))
}
timer.Update()
if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return werr
}
}
if err != nil {
return err
}
}
}
}
newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx))
if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil {
return newError("failed to process response").Base(err)
}
return nil
}

View file

@ -53,6 +53,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified") return newError("target not specified")
} }
outbound.Name = "shadowsocks"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
destination := outbound.Target destination := outbound.Target
network := destination.Network network := destination.Network

View file

@ -71,6 +71,10 @@ func (s *Server) Network() []net.Network {
} }
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks"
inbound.SetCanSpliceCopy(3)
switch network { switch network {
case net.Network_TCP: case net.Network_TCP:
return s.handleConnection(ctx, conn, dispatcher) return s.handleConnection(ctx, conn, dispatcher)
@ -110,13 +114,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
}) })
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
if inbound == nil {
panic("no inbound metadata")
}
inbound.Name = "shadowsocks"
var dest *net.Destination var dest *net.Destination
reader := buf.NewPacketReader(conn) reader := buf.NewPacketReader(conn)
for { for {
mpayload, err := reader.ReadMultiBuffer() mpayload, err := reader.ReadMultiBuffer()

View file

@ -66,6 +66,7 @@ func (i *Inbound) Network() []net.Network {
func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks-2022" inbound.Name = "shadowsocks-2022"
inbound.SetCanSpliceCopy(3)
var metadata M.Metadata var metadata M.Metadata
if inbound.Source.IsValid() { if inbound.Source.IsValid() {

View file

@ -155,6 +155,7 @@ func (i *MultiUserInbound) Network() []net.Network {
func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks-2022-multi" inbound.Name = "shadowsocks-2022-multi"
inbound.SetCanSpliceCopy(3)
var metadata M.Metadata var metadata M.Metadata
if inbound.Source.IsValid() { if inbound.Source.IsValid() {

View file

@ -87,6 +87,7 @@ func (i *RelayInbound) Network() []net.Network {
func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks-2022-relay" inbound.Name = "shadowsocks-2022-relay"
inbound.SetCanSpliceCopy(3)
var metadata M.Metadata var metadata M.Metadata
if inbound.Source.IsValid() { if inbound.Source.IsValid() {

View file

@ -66,12 +66,14 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
if inbound != nil { if inbound != nil {
inboundConn = inbound.Conn inboundConn = inbound.Conn
inbound.SetCanSpliceCopy(3)
} }
outbound := session.OutboundFromContext(ctx) outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified") return newError("target not specified")
} }
outbound.Name = "shadowsocks-2022"
destination := outbound.Target destination := outbound.Target
network := destination.Network network := destination.Network

View file

@ -61,6 +61,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified.") return newError("target not specified.")
} }
outbound.Name = "socks"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(2)
}
// Destination of the inner request. // Destination of the inner request.
destination := outbound.Target destination := outbound.Target

View file

@ -63,12 +63,12 @@ func (s *Server) Network() []net.Network {
// Process implements proxy.Inbound. // Process implements proxy.Inbound.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
if inbound := session.InboundFromContext(ctx); inbound != nil { inbound := session.InboundFromContext(ctx)
inbound.Name = "socks" inbound.Name = "socks"
inbound.SetCanSpliceCopy(2)
inbound.User = &protocol.MemoryUser{ inbound.User = &protocol.MemoryUser{
Level: s.config.UserLevel, Level: s.config.UserLevel,
} }
}
switch network { switch network {
case net.Network_TCP: case net.Network_TCP:

View file

@ -54,6 +54,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified") return newError("target not specified")
} }
outbound.Name = "trojan"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
destination := outbound.Target destination := outbound.Target
network := destination.Network network := destination.Network

View file

@ -214,10 +214,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
} }
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
if inbound == nil {
panic("no inbound metadata")
}
inbound.Name = "trojan" inbound.Name = "trojan"
inbound.SetCanSpliceCopy(3)
inbound.User = user inbound.User = user
sessionPolicy = s.policyManager.ForLevel(user.Level) sessionPolicy = s.policyManager.ForLevel(user.Level)

View file

@ -8,9 +8,7 @@ import (
"crypto/rand" "crypto/rand"
"io" "io"
"math/big" "math/big"
"runtime"
"strconv" "strconv"
"syscall"
"time" "time"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
@ -20,10 +18,8 @@ import (
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless"
"github.com/xtls/xray-core/transport/internet/reality"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls"
) )
const ( const (
@ -206,13 +202,11 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
} }
// XtlsRead filter and read xtls protocol // XtlsRead filter and read xtls protocol
func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer,
input *bytes.Reader, rawInput *bytes.Buffer, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool,
counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool,
isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32, isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32,
) error { ) error {
err := func() error { err := func() error {
var ct stats.Counter
withinPaddingBuffers := true withinPaddingBuffers := true
shouldSwitchToDirectCopy := false shouldSwitchToDirectCopy := false
var remainingContent int32 = -1 var remainingContent int32 = -1
@ -220,40 +214,14 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
currentCommand := 0 currentCommand := 0
for { for {
if shouldSwitchToDirectCopy { if shouldSwitchToDirectCopy {
shouldSwitchToDirectCopy = false var writerConn net.Conn
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
if _, ok := inbound.User.Account.(*vless.MemoryAccount); inbound.User.Account == nil || ok { writerConn = inbound.Conn
iConn := inbound.Conn if inbound.CanSpliceCopy == 2 {
statConn, ok := iConn.(*stat.CounterConnection) inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter
if ok {
iConn = statConn.Connection
}
if tlsConn, ok := iConn.(*tls.Conn); ok {
iConn = tlsConn.NetConn()
} else if realityConn, ok := iConn.(*reality.Conn); ok {
iConn = realityConn.NetConn()
}
if tc, ok := iConn.(*net.TCPConn); ok {
newError("XtlsRead splice").WriteToLog(session.ExportIDToError(ctx))
runtime.Gosched() // necessary
w, err := tc.ReadFrom(conn)
if counter != nil {
counter.Add(w)
}
if statConn != nil && statConn.WriteCounter != nil {
statConn.WriteCounter.Add(w)
}
return err
} }
} }
} return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer)
if rawConn != nil {
reader = buf.NewReadVReader(conn, rawConn, nil)
} else {
reader = buf.NewReader(conn)
}
ct = counter
newError("XtlsRead readV").WriteToLog(session.ExportIDToError(ctx))
} }
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
@ -292,9 +260,6 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
if *numberOfPacketToFilter > 0 { if *numberOfPacketToFilter > 0 {
XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx) XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
} }
if ct != nil {
ct.Add(int64(buffer.Len()))
}
timer.Update() timer.Update()
if werr := writer.WriteMultiBuffer(buffer); werr != nil { if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return werr return werr
@ -312,7 +277,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
} }
// XtlsWrite filter and write xtls protocol // XtlsWrite filter and write xtls protocol
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn,
ctx context.Context, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, ctx context.Context, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
cipher *uint16, remainingServerHello *int32, cipher *uint16, remainingServerHello *int32,
) error { ) error {
@ -349,18 +314,21 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
} }
if shouldSwitchToDirectCopy { if shouldSwitchToDirectCopy {
encryptBuffer, directBuffer := buf.SplitMulti(buffer, xtlsSpecIndex+1) encryptBuffer, directBuffer := buf.SplitMulti(buffer, xtlsSpecIndex+1)
length := encryptBuffer.Len()
if !encryptBuffer.IsEmpty() { if !encryptBuffer.IsEmpty() {
timer.Update() timer.Update()
if werr := writer.WriteMultiBuffer(encryptBuffer); werr != nil { if werr := writer.WriteMultiBuffer(encryptBuffer); werr != nil {
return werr return werr
} }
} }
buffer = directBuffer
writer = buf.NewWriter(conn)
ct = counter
newError("XtlsWrite writeV ", xtlsSpecIndex, " ", length, " ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
time.Sleep(5 * time.Millisecond) // for some device, the first xtls direct packet fails without this delay time.Sleep(5 * time.Millisecond) // for some device, the first xtls direct packet fails without this delay
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter
}
buffer = directBuffer
rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
writer = buf.NewWriter(rawConn)
ct = writerCounter
} }
} }
if !buffer.IsEmpty() { if !buffer.IsEmpty() {

View file

@ -10,11 +10,9 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"syscall"
"time" "time"
"unsafe" "unsafe"
"github.com/pires/go-proxyproto"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
@ -30,7 +28,6 @@ import (
feature_inbound "github.com/xtls/xray-core/features/inbound" feature_inbound "github.com/xtls/xray-core/features/inbound"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless"
"github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/proxy/vless/encoding"
"github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/reality"
@ -182,8 +179,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
sid := session.ExportIDToError(ctx) sid := session.ExportIDToError(ctx)
iConn := connection iConn := connection
statConn, ok := iConn.(*stat.CounterConnection) if statConn, ok := iConn.(*stat.CounterConnection); ok {
if ok {
iConn = statConn.Connection iConn = statConn.Connection
} }
@ -447,14 +443,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
// Flow: requestAddons.Flow, // Flow: requestAddons.Flow,
} }
var netConn net.Conn
var rawConn syscall.RawConn
var input *bytes.Reader var input *bytes.Reader
var rawInput *bytes.Buffer var rawInput *bytes.Buffer
switch requestAddons.Flow { switch requestAddons.Flow {
case vless.XRV: case vless.XRV:
if account.Flow == requestAddons.Flow { if account.Flow == requestAddons.Flow {
inbound.SetCanSpliceCopy(2)
switch request.Command { switch request.Command {
case protocol.RequestCommandUDP: case protocol.RequestCommandUDP:
return newError(requestAddons.Flow + " doesn't support UDP").AtWarning() return newError(requestAddons.Flow + " doesn't support UDP").AtWarning()
@ -467,23 +461,14 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
if tlsConn.ConnectionState().Version != gotls.VersionTLS13 { if tlsConn.ConnectionState().Version != gotls.VersionTLS13 {
return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, tlsConn.ConnectionState().Version).AtWarning() return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, tlsConn.ConnectionState().Version).AtWarning()
} }
netConn = tlsConn.NetConn()
t = reflect.TypeOf(tlsConn.Conn).Elem() t = reflect.TypeOf(tlsConn.Conn).Elem()
p = uintptr(unsafe.Pointer(tlsConn.Conn)) p = uintptr(unsafe.Pointer(tlsConn.Conn))
} else if realityConn, ok := iConn.(*reality.Conn); ok { } else if realityConn, ok := iConn.(*reality.Conn); ok {
netConn = realityConn.NetConn()
t = reflect.TypeOf(realityConn.Conn).Elem() t = reflect.TypeOf(realityConn.Conn).Elem()
p = uintptr(unsafe.Pointer(realityConn.Conn)) p = uintptr(unsafe.Pointer(realityConn.Conn))
} else { } else {
return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning() return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning()
} }
if pc, ok := netConn.(*proxyproto.Conn); ok {
netConn = pc.Raw()
// 8192 > 4096, there is no need to process pc's bufReader
}
if sc, ok := netConn.(syscall.Conn); ok {
rawConn, _ = sc.SyscallConn()
}
i, _ := t.FieldByName("input") i, _ := t.FieldByName("input")
r, _ := t.FieldByName("rawInput") r, _ := t.FieldByName("rawInput")
input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset))
@ -493,6 +478,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning() return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning()
} }
case "": case "":
inbound.SetCanSpliceCopy(3)
if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) { if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) {
return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning() return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning()
} }
@ -540,13 +526,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
var err error var err error
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
var counter stats.Counter ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
if statConn != nil { err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, ctx1, account.ID.Bytes(),
counter = statConn.ReadCounter
}
// TODO enable splice
ctx = session.ContextWithInbound(ctx, nil)
err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, input, rawInput, counter, ctx, account.ID.Bytes(),
&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
@ -592,11 +573,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
var err error var err error
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
var counter stats.Counter err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, ctx, &numberOfPacketToFilter,
if statConn != nil {
counter = statConn.WriteCounter
}
err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter,
&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer

View file

@ -7,7 +7,6 @@ import (
"context" "context"
gotls "crypto/tls" gotls "crypto/tls"
"reflect" "reflect"
"syscall"
"time" "time"
"unsafe" "unsafe"
@ -23,7 +22,6 @@ import (
"github.com/xtls/xray-core/common/xudp" "github.com/xtls/xray-core/common/xudp"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless"
"github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/proxy/vless/encoding"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
@ -71,9 +69,15 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements proxy.Outbound.Process(). // Process implements proxy.Outbound.Process().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified").AtError()
}
outbound.Name = "vless"
inbound := session.InboundFromContext(ctx)
var rec *protocol.ServerSpec var rec *protocol.ServerSpec
var conn stat.Connection var conn stat.Connection
if err := retry.ExponentialBackoff(5, 200).On(func() error { if err := retry.ExponentialBackoff(5, 200).On(func() error {
rec = h.serverPicker.PickServer() rec = h.serverPicker.PickServer()
var err error var err error
@ -88,16 +92,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
defer conn.Close() defer conn.Close()
iConn := conn iConn := conn
statConn, ok := iConn.(*stat.CounterConnection) if statConn, ok := iConn.(*stat.CounterConnection); ok {
if ok {
iConn = statConn.Connection iConn = statConn.Connection
} }
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified").AtError()
}
target := outbound.Target target := outbound.Target
newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx)) newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx))
@ -123,8 +120,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
Flow: account.Flow, Flow: account.Flow,
} }
var netConn net.Conn
var rawConn syscall.RawConn
var input *bytes.Reader var input *bytes.Reader
var rawInput *bytes.Buffer var rawInput *bytes.Buffer
allowUDP443 := false allowUDP443 := false
@ -134,6 +129,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
requestAddons.Flow = requestAddons.Flow[:16] requestAddons.Flow = requestAddons.Flow[:16]
fallthrough fallthrough
case vless.XRV: case vless.XRV:
if inbound != nil {
inbound.SetCanSpliceCopy(2)
}
switch request.Command { switch request.Command {
case protocol.RequestCommandUDP: case protocol.RequestCommandUDP:
if !allowUDP443 && request.Port == 443 { if !allowUDP443 && request.Port == 443 {
@ -146,28 +144,26 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
var t reflect.Type var t reflect.Type
var p uintptr var p uintptr
if tlsConn, ok := iConn.(*tls.Conn); ok { if tlsConn, ok := iConn.(*tls.Conn); ok {
netConn = tlsConn.NetConn()
t = reflect.TypeOf(tlsConn.Conn).Elem() t = reflect.TypeOf(tlsConn.Conn).Elem()
p = uintptr(unsafe.Pointer(tlsConn.Conn)) p = uintptr(unsafe.Pointer(tlsConn.Conn))
} else if utlsConn, ok := iConn.(*tls.UConn); ok { } else if utlsConn, ok := iConn.(*tls.UConn); ok {
netConn = utlsConn.NetConn()
t = reflect.TypeOf(utlsConn.Conn).Elem() t = reflect.TypeOf(utlsConn.Conn).Elem()
p = uintptr(unsafe.Pointer(utlsConn.Conn)) p = uintptr(unsafe.Pointer(utlsConn.Conn))
} else if realityConn, ok := iConn.(*reality.UConn); ok { } else if realityConn, ok := iConn.(*reality.UConn); ok {
netConn = realityConn.NetConn()
t = reflect.TypeOf(realityConn.Conn).Elem() t = reflect.TypeOf(realityConn.Conn).Elem()
p = uintptr(unsafe.Pointer(realityConn.Conn)) p = uintptr(unsafe.Pointer(realityConn.Conn))
} else { } else {
return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning() return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning()
} }
if sc, ok := netConn.(syscall.Conn); ok {
rawConn, _ = sc.SyscallConn()
}
i, _ := t.FieldByName("input") i, _ := t.FieldByName("input")
r, _ := t.FieldByName("rawInput") r, _ := t.FieldByName("rawInput")
input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset))
rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset))
} }
default:
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
} }
var newCtx context.Context var newCtx context.Context
@ -257,11 +253,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning() return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning()
} }
} }
var counter stats.Counter ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice
if statConn != nil { err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, ctx1, &numberOfPacketToFilter,
counter = statConn.WriteCounter
}
err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter,
&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
@ -293,11 +286,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
var counter stats.Counter err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, ctx, account.ID.Bytes(),
if statConn != nil {
counter = statConn.ReadCounter
}
err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, input, rawInput, counter, ctx, account.ID.Bytes(),
&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer

View file

@ -256,10 +256,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
} }
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
if inbound == nil {
panic("no inbound metadata")
}
inbound.Name = "vmess" inbound.Name = "vmess"
inbound.SetCanSpliceCopy(3)
inbound.User = request.User inbound.User = request.User
sessionPolicy = h.policyManager.ForLevel(request.User.Level) sessionPolicy = h.policyManager.ForLevel(request.User.Level)

View file

@ -60,9 +60,18 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements proxy.Outbound.Process(). // Process implements proxy.Outbound.Process().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified").AtError()
}
outbound.Name = "vmess"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
var rec *protocol.ServerSpec var rec *protocol.ServerSpec
var conn stat.Connection var conn stat.Connection
err := retry.ExponentialBackoff(5, 200).On(func() error { err := retry.ExponentialBackoff(5, 200).On(func() error {
rec = h.serverPicker.PickServer() rec = h.serverPicker.PickServer()
rawConn, err := dialer.Dial(ctx, rec.Destination()) rawConn, err := dialer.Dial(ctx, rec.Destination())
@ -78,11 +87,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
} }
defer conn.Close() defer conn.Close()
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified").AtError()
}
target := outbound.Target target := outbound.Target
newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx)) newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx))

View file

@ -75,6 +75,16 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
// Process implements OutboundHandler.Dispatch(). // Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "wireguard"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
if h.bind == nil || h.bind.dialer != dialer || h.net == nil { if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
log.Record(&log.GeneralMessage{ log.Record(&log.GeneralMessage{
Severity: log.Severity_Info, Severity: log.Severity_Info,
@ -101,10 +111,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
h.bind = bind h.bind = bind
} }
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified")
}
// Destination of the inner request. // Destination of the inner request.
destination := outbound.Target destination := outbound.Target
command := protocol.RequestCommandTCP command := protocol.RequestCommandTCP

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"testing" "testing"
"time" "time"
@ -128,9 +129,8 @@ func TestHttpError(t *testing.T) {
} }
resp, err := client.Get("http://127.0.0.1:" + dest.Port.String()) resp, err := client.Get("http://127.0.0.1:" + dest.Port.String())
common.Must(err) if resp != nil && resp.StatusCode != 503 || err != nil && !strings.Contains(err.Error(), "malformed HTTP status code") {
if resp.StatusCode != 503 { t.Error("should not receive http response", err)
t.Error("status: ", resp.StatusCode)
} }
} }
} }

View file

@ -1174,10 +1174,10 @@ func TestVMessGCMMuxUDP(t *testing.T) {
servers, err := InitializeServerConfigs(serverConfig, clientConfig) servers, err := InitializeServerConfigs(serverConfig, clientConfig)
common.Must(err) common.Must(err)
for range "abcd" { for range "ab" {
var errg errgroup.Group var errg errgroup.Group
for i := 0; i < 16; i++ { for i := 0; i < 16; i++ {
errg.Go(testTCPConn(clientPort, 10240, time.Second*20)) errg.Go(testTCPConn(clientPort, 1024, time.Second*10))
errg.Go(testUDPConn(clientUDPPort, 1024, time.Second*10)) errg.Go(testUDPConn(clientUDPPort, 1024, time.Second*10))
} }
if err := errg.Wait(); err != nil { if err := errg.Wait(); err != nil {