diff --git a/adapter/inbound.go b/adapter/inbound.go index aeef3798..6d35c757 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -2,8 +2,11 @@ package adapter import ( "context" + "net/netip" M "github.com/sagernet/sing/common/metadata" + + C "github.com/sagernet/sing-box/constant" ) type Inbound interface { @@ -23,8 +26,10 @@ type InboundContext struct { // cache + DomainStrategy C.DomainStrategy SniffEnabled bool SniffOverrideDestination bool + DestinationAddresses []netip.Addr SourceGeoIPCode string GeoIPCode string @@ -50,5 +55,5 @@ func AppendContext(ctx context.Context) (context.Context, *InboundContext) { return ctx, metadata } metadata = new(InboundContext) - return WithContext(ctx, metadata), nil + return WithContext(ctx, metadata), metadata } diff --git a/adapter/outbound.go b/adapter/outbound.go index 8be9f9f8..1bfd8ba8 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -4,7 +4,6 @@ import ( "context" "net" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -13,6 +12,6 @@ type Outbound interface { Tag() string Network() []string N.Dialer - NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error - NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error + NewConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error + NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error } diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index 2039729d..f2f7fec1 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -10,16 +10,18 @@ import ( ) func New(router adapter.Router, options option.DialerOptions) N.Dialer { - domainStrategy := C.DomainStrategy(options.DomainStrategy) - var dialer N.Dialer if options.Detour == "" { - dialer = NewDefault(options) - dialer = NewResolveDialer(router, dialer, domainStrategy) + return NewDefault(options) } else { - dialer = NewDetour(router, options.Detour) - if domainStrategy != C.DomainStrategyAsIS { - dialer = NewResolveDialer(router, dialer, domainStrategy) - } + return NewDetour(router, options.Detour) + } +} + +func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.Dialer { + dialer := New(router, options.DialerOptions) + domainStrategy := C.DomainStrategy(options.DomainStrategy) + if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED { + dialer = NewResolveDialer(router, dialer, domainStrategy) } if options.OverrideOptions.IsValid() { dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions)) diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index 2584c568..e598be79 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -41,16 +40,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina if err != nil { return nil, err } - var conn net.Conn - var connErrors []error - for _, address := range addresses { - conn, err = d.dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) - if err != nil { - connErrors = append(connErrors, err) - } - return conn, nil - } - return nil, E.Errors(connErrors...) + return DialSerial(ctx, d.dialer, network, destination, addresses) } func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { @@ -67,16 +57,7 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd if err != nil { return nil, err } - var conn net.PacketConn - var connErrors []error - for _, address := range addresses { - conn, err = d.dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port)) - if err != nil { - connErrors = append(connErrors, err) - } - return conn, nil - } - return nil, E.Errors(connErrors...) + return ListenSerial(ctx, d.dialer, destination, addresses) } func (d *ResolveDialer) Upstream() any { diff --git a/common/dialer/serial.go b/common/dialer/serial.go new file mode 100644 index 00000000..b5508e94 --- /dev/null +++ b/common/dialer/serial.go @@ -0,0 +1,39 @@ +package dialer + +import ( + "context" + "net" + "net/netip" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func DialSerial(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) { + var conn net.Conn + var err error + var connErrors []error + for _, address := range destinationAddresses { + conn, err = dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) + if err != nil { + connErrors = append(connErrors, err) + } + return conn, nil + } + return nil, E.Errors(connErrors...) +} + +func ListenSerial(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.PacketConn, error) { + var conn net.PacketConn + var err error + var connErrors []error + for _, address := range destinationAddresses { + conn, err = dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port)) + if err != nil { + connErrors = append(connErrors, err) + } + return conn, nil + } + return nil, E.Errors(connErrors...) +} diff --git a/common/geosite/reader.go b/common/geosite/reader.go index 85b22fa5..a1b39f28 100644 --- a/common/geosite/reader.go +++ b/common/geosite/reader.go @@ -77,9 +77,14 @@ func (r *Reader) readMetadata() error { } func (r *Reader) Read(code string) ([]Item, error) { - if _, exists := r.domainIndex[code]; !exists { + index, exists := r.domainIndex[code] + if !exists { return nil, E.New("code ", code, " not exists!") } + _, err := r.reader.Seek(int64(index), io.SeekCurrent) + if err != nil { + return nil, err + } counter := &rw.ReadCounter{Reader: r.reader} domain := make([]Item, r.domainLength[code]) for i := range domain { @@ -97,7 +102,7 @@ func (r *Reader) Read(code string) ([]Item, error) { } domain[i] = item } - _, err := r.reader.Seek(int64(r.domainIndex[code])-counter.Count(), io.SeekCurrent) + _, err = r.reader.Seek(int64(-index)-counter.Count(), io.SeekCurrent) return domain, err } diff --git a/constant/cgo.go b/constant/cgo.go new file mode 100644 index 00000000..d6ce7035 --- /dev/null +++ b/constant/cgo.go @@ -0,0 +1,3 @@ +package constant + +const CGO_ENABLED = true diff --git a/constant/cgo_disabled.go b/constant/cgo_disabled.go new file mode 100644 index 00000000..51cacad5 --- /dev/null +++ b/constant/cgo_disabled.go @@ -0,0 +1,5 @@ +//go:build !cgo + +package constant + +const CGO_ENABLED = false diff --git a/dns/dialer.go b/dns/dialer.go index d7f47e42..a2a73029 100644 --- a/dns/dialer.go +++ b/dns/dialer.go @@ -4,11 +4,11 @@ import ( "context" "net" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" ) @@ -31,16 +31,7 @@ func (d *DialerWrapper) DialContext(ctx context.Context, network string, destina if err != nil { return nil, err } - var conn net.Conn - var connErrors []error - for _, address := range addresses { - conn, err = d.dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) - if err != nil { - connErrors = append(connErrors, err) - } - return conn, nil - } - return nil, E.Errors(connErrors...) + return dialer.DialSerial(ctx, d.dialer, network, destination, addresses) } func (d *DialerWrapper) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { @@ -51,16 +42,7 @@ func (d *DialerWrapper) ListenPacket(ctx context.Context, destination M.Socksadd if err != nil { return nil, err } - var conn net.PacketConn - var connErrors []error - for _, address := range addresses { - conn, err = d.dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port)) - if err != nil { - connErrors = append(connErrors, err) - } - return conn, nil - } - return nil, E.Errors(connErrors...) + return dialer.ListenSerial(ctx, d.dialer, destination, addresses) } func (d *DialerWrapper) Upstream() any { diff --git a/inbound/default.go b/inbound/default.go index 40ad7507..5c7d5f66 100644 --- a/inbound/default.go +++ b/inbound/default.go @@ -136,6 +136,7 @@ func (a *myInboundAdapter) loopTCPIn() { metadata.Inbound = a.tag metadata.SniffEnabled = a.listenOptions.SniffEnabled metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = C.DomainStrategy(a.listenOptions.DomainStrategy) metadata.Network = C.NetworkTCP metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()) a.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source) @@ -167,6 +168,7 @@ func (a *myInboundAdapter) loopUDPIn() { metadata.Inbound = a.tag metadata.SniffEnabled = a.listenOptions.SniffEnabled metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = C.DomainStrategy(a.listenOptions.DomainStrategy) metadata.Network = C.NetworkUDP metadata.Source = M.SocksaddrFromNetIP(addr) err = a.packetHandler.NewPacket(a.ctx, packetService, buffer, metadata) @@ -191,6 +193,7 @@ func (a *myInboundAdapter) loopUDPInThreadSafe() { metadata.Inbound = a.tag metadata.SniffEnabled = a.listenOptions.SniffEnabled metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = C.DomainStrategy(a.listenOptions.DomainStrategy) metadata.Network = C.NetworkUDP metadata.Source = M.SocksaddrFromNetIP(addr) err = a.packetHandler.NewPacket(a.ctx, packetService, buffer, metadata) diff --git a/option/inbound.go b/option/inbound.go index 5a701397..65c895d1 100644 --- a/option/inbound.go +++ b/option/inbound.go @@ -79,12 +79,13 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error { } type ListenOptions struct { - Listen ListenAddress `json:"listen"` - Port uint16 `json:"listen_port"` - TCPFastOpen bool `json:"tcp_fast_open,omitempty"` - UDPTimeout int64 `json:"udp_timeout,omitempty"` - SniffEnabled bool `json:"sniff,omitempty"` - SniffOverrideDestination bool `json:"sniff_override_destination,omitempty"` + Listen ListenAddress `json:"listen"` + Port uint16 `json:"listen_port"` + TCPFastOpen bool `json:"tcp_fast_open,omitempty"` + UDPTimeout int64 `json:"udp_timeout,omitempty"` + SniffEnabled bool `json:"sniff,omitempty"` + SniffOverrideDestination bool `json:"sniff_override_destination,omitempty"` + DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` } type SimpleInboundOptions struct { diff --git a/option/outbound.go b/option/outbound.go index 5f297c16..c2376b72 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -67,13 +67,17 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error { } type DialerOptions struct { - Detour string `json:"detour,omitempty"` - BindInterface string `json:"bind_interface,omitempty"` - ProtectPath string `json:"protect_path,omitempty"` - RoutingMark int `json:"routing_mark,omitempty"` - ReuseAddr bool `json:"reuse_addr,omitempty"` - ConnectTimeout int `json:"connect_timeout,omitempty"` - TCPFastOpen bool `json:"tcp_fast_open,omitempty"` + Detour string `json:"detour,omitempty"` + BindInterface string `json:"bind_interface,omitempty"` + ProtectPath string `json:"protect_path,omitempty"` + RoutingMark int `json:"routing_mark,omitempty"` + ReuseAddr bool `json:"reuse_addr,omitempty"` + ConnectTimeout int `json:"connect_timeout,omitempty"` + TCPFastOpen bool `json:"tcp_fast_open,omitempty"` +} + +type OutboundDialerOptions struct { + DialerOptions OverrideOptions *OverrideStreamOptions `json:"override,omitempty"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` } @@ -99,13 +103,13 @@ func (o ServerOptions) Build() M.Socksaddr { } type DirectOutboundOptions struct { - DialerOptions + OutboundDialerOptions OverrideAddress string `json:"override_address,omitempty"` OverridePort uint16 `json:"override_port,omitempty"` } type SocksOutboundOptions struct { - DialerOptions + OutboundDialerOptions ServerOptions Version string `json:"version,omitempty"` Username string `json:"username,omitempty"` @@ -114,14 +118,14 @@ type SocksOutboundOptions struct { } type HTTPOutboundOptions struct { - DialerOptions + OutboundDialerOptions ServerOptions Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` } type ShadowsocksOutboundOptions struct { - DialerOptions + OutboundDialerOptions ServerOptions Method string `json:"method"` Password string `json:"password"` diff --git a/option/types.go b/option/types.go index 81f6f06e..dca2ddd4 100644 --- a/option/types.go +++ b/option/types.go @@ -100,13 +100,13 @@ func (s DomainStrategy) MarshalJSON() ([]byte, error) { value = "" // value = "AsIS" case C.DomainStrategyPreferIPv4: - value = "PreferIPv4" + value = "prefer_ipv4" case C.DomainStrategyPreferIPv6: - value = "PreferIPv6" + value = "prefer_ipv6" case C.DomainStrategyUseIPv4: - value = "UseIPv4" + value = "ipv4_only" case C.DomainStrategyUseIPv6: - value = "UseIPv6" + value = "ipv6_only" default: return nil, E.New("unknown domain strategy: ", s) } @@ -122,13 +122,13 @@ func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error { switch value { case "", "AsIS": *s = DomainStrategy(C.DomainStrategyAsIS) - case "PreferIPv4": + case "prefer_ipv4": *s = DomainStrategy(C.DomainStrategyPreferIPv4) - case "PreferIPv6": + case "prefer_ipv6": *s = DomainStrategy(C.DomainStrategyPreferIPv6) - case "UseIPv4": + case "ipv4_only": *s = DomainStrategy(C.DomainStrategyUseIPv4) - case "UseIPv6": + case "ipv6_only": *s = DomainStrategy(C.DomainStrategyUseIPv6) default: return E.New("unknown domain strategy: ", value) diff --git a/outbound/block.go b/outbound/block.go index a54c1944..8e786442 100644 --- a/outbound/block.go +++ b/outbound/block.go @@ -40,14 +40,14 @@ func (h *Block) ListenPacket(ctx context.Context, destination M.Socksaddr) (net. return nil, io.EOF } -func (h *Block) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { +func (h *Block) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { conn.Close() - h.logger.WithContext(ctx).Info("blocked connection to ", destination) + h.logger.WithContext(ctx).Info("blocked connection to ", metadata.Destination) return nil } -func (h *Block) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { +func (h *Block) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { conn.Close() - h.logger.WithContext(ctx).Info("blocked packet connection to ", destination) + h.logger.WithContext(ctx).Info("blocked packet connection to ", metadata.Destination) return nil } diff --git a/outbound/default.go b/outbound/default.go index 7aa36b76..c5c0265f 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -10,7 +10,11 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" ) @@ -33,6 +37,51 @@ func (a *myOutboundAdapter) Network() []string { return a.network } +func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + var outConn net.Conn + var err error + if len(metadata.DestinationAddresses) > 0 { + outConn, err = dialer.DialSerial(ctx, this, C.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) + } else { + outConn, err = this.DialContext(ctx, C.NetworkTCP, metadata.Destination) + } + if err != nil { + return err + } + return bufio.CopyConn(ctx, conn, outConn) +} + +func NewEarlyConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + var outConn net.Conn + var err error + if len(metadata.DestinationAddresses) > 0 { + outConn, err = dialer.DialSerial(ctx, this, C.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) + } else { + outConn, err = this.DialContext(ctx, C.NetworkTCP, metadata.Destination) + } + if err != nil { + return err + } + return CopyEarlyConn(ctx, conn, outConn) +} + +func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + var outConn net.PacketConn + var err error + if len(metadata.DestinationAddresses) > 0 { + outConn, err = dialer.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + } else { + outConn, err = this.ListenPacket(ctx, metadata.Destination) + } + if err != nil { + return err + } + return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) +} + func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { _payload := buf.StackNew() payload := common.Dup(_payload) diff --git a/outbound/direct.go b/outbound/direct.go index 0c489af1..97749cff 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -4,7 +4,6 @@ import ( "context" "net" - "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -32,7 +31,7 @@ func NewDirect(router adapter.Router, logger log.Logger, tag string, options opt tag: tag, network: []string{C.NetworkTCP, C.NetworkUDP}, }, - dialer: dialer.New(router, options.DialerOptions), + dialer: dialer.NewOutbound(router, options.OutboundDialerOptions), } if options.OverrideAddress != "" && options.OverridePort != 0 { outbound.overrideOption = 1 @@ -50,6 +49,7 @@ func NewDirect(router adapter.Router, logger log.Logger, tag string, options opt func (h *Direct) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination switch h.overrideOption { case 1: destination = h.overrideDestination @@ -72,22 +72,15 @@ func (h *Direct) DialContext(ctx context.Context, network string, destination M. func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination h.logger.WithContext(ctx).Info("outbound packet connection") return h.dialer.ListenPacket(ctx, destination) } -func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { - outConn, err := h.DialContext(ctx, C.NetworkTCP, destination) - if err != nil { - return err - } - return bufio.CopyConn(ctx, conn, outConn) +func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewConnection(ctx, h, conn, metadata) } -func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { - outConn, err := h.ListenPacket(ctx, destination) - if err != nil { - return err - } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) +func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewPacketConnection(ctx, h, conn, metadata) } diff --git a/outbound/http.go b/outbound/http.go index 284651bd..2833b34a 100644 --- a/outbound/http.go +++ b/outbound/http.go @@ -5,7 +5,6 @@ import ( "net" "os" - "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/protocol/http" @@ -32,13 +31,14 @@ func NewHTTP(router adapter.Router, logger log.Logger, tag string, options optio tag: tag, network: []string{C.NetworkTCP}, }, - http.NewClient(dialer.New(router, options.DialerOptions), options.ServerOptions.Build(), options.Username, options.Password), + http.NewClient(dialer.NewOutbound(router, options.OutboundDialerOptions), options.ServerOptions.Build(), options.Username, options.Password), } } func (h *HTTP) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination h.logger.WithContext(ctx).Info("outbound connection to ", destination) return h.client.DialContext(ctx, network, destination) } @@ -46,17 +46,14 @@ func (h *HTTP) DialContext(ctx context.Context, network string, destination M.So func (h *HTTP) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination return nil, os.ErrInvalid } -func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { - outConn, err := h.DialContext(ctx, C.NetworkTCP, destination) - if err != nil { - return err - } - return bufio.CopyConn(ctx, conn, outConn) +func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewConnection(ctx, h, conn, metadata) } -func (h *HTTP) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { +func (h *HTTP) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return os.ErrInvalid } diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index 8f3d159f..81510842 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -39,7 +39,7 @@ func NewShadowsocks(router adapter.Router, logger log.Logger, tag string, option tag: tag, network: options.Network.Build(), }, - dialer.New(router, options.DialerOptions), + dialer.NewOutbound(router, options.OutboundDialerOptions), method, options.ServerOptions.Build(), }, nil @@ -48,6 +48,7 @@ func NewShadowsocks(router adapter.Router, logger log.Logger, tag string, option func (h *Shadowsocks) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination switch network { case C.NetworkTCP: h.logger.WithContext(ctx).Info("outbound connection to ", destination) @@ -71,6 +72,7 @@ func (h *Shadowsocks) DialContext(ctx context.Context, network string, destinati func (h *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination h.logger.WithContext(ctx).Info("outbound packet connection to ", h.serverAddr) outConn, err := h.dialer.ListenPacket(ctx, destination) if err != nil { @@ -79,18 +81,10 @@ func (h *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr) return h.method.DialPacketConn(&bufio.BindPacketConn{PacketConn: outConn, Addr: h.serverAddr.UDPAddr()}), nil } -func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { - serverConn, err := h.DialContext(ctx, C.NetworkTCP, destination) - if err != nil { - return err - } - return CopyEarlyConn(ctx, conn, serverConn) +func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewEarlyConnection(ctx, h, conn, metadata) } -func (h *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { - serverConn, err := h.ListenPacket(ctx, destination) - if err != nil { - return err - } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(serverConn)) +func (h *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewPacketConnection(ctx, h, conn, metadata) } diff --git a/outbound/socks.go b/outbound/socks.go index 6ca959a0..833c7c94 100644 --- a/outbound/socks.go +++ b/outbound/socks.go @@ -4,7 +4,6 @@ import ( "context" "net" - "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/protocol/socks" @@ -24,7 +23,7 @@ type Socks struct { } func NewSocks(router adapter.Router, logger log.Logger, tag string, options option.SocksOutboundOptions) (*Socks, error) { - detour := dialer.New(router, options.DialerOptions) + detour := dialer.NewOutbound(router, options.OutboundDialerOptions) var version socks.Version var err error if options.Version != "" { @@ -49,6 +48,7 @@ func NewSocks(router adapter.Router, logger log.Logger, tag string, options opti func (h *Socks) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination switch network { case C.NetworkTCP: h.logger.WithContext(ctx).Info("outbound connection to ", destination) @@ -63,22 +63,15 @@ func (h *Socks) DialContext(ctx context.Context, network string, destination M.S func (h *Socks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag + metadata.Destination = destination h.logger.WithContext(ctx).Info("outbound packet connection to ", destination) return h.client.ListenPacket(ctx, destination) } -func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { - outConn, err := h.DialContext(ctx, C.NetworkTCP, destination) - if err != nil { - return err - } - return bufio.CopyConn(ctx, conn, outConn) +func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewConnection(ctx, h, conn, metadata) } -func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { - outConn, err := h.ListenPacket(ctx, destination) - if err != nil { - return err - } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) +func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewPacketConnection(ctx, h, conn, metadata) } diff --git a/route/router.go b/route/router.go index 3172b412..a1035850 100644 --- a/route/router.go +++ b/route/router.go @@ -56,6 +56,7 @@ type Router struct { dnsClient adapter.DNSClient defaultDomainStrategy C.DomainStrategy + dnsRules []adapter.Rule defaultTransport adapter.DNSTransport transports []adapter.DNSTransport @@ -69,9 +70,11 @@ func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptio dnsLogger: logger.WithPrefix("dns: "), outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), + dnsRules: make([]adapter.Rule, 0, len(dnsOptions.Rules)), needGeoIPDatabase: hasGeoRule(options.Rules, isGeoIPRule) || hasGeoDNSRule(dnsOptions.Rules, isGeoIPDNSRule), needGeositeDatabase: hasGeoRule(options.Rules, isGeositeRule) || hasGeoDNSRule(dnsOptions.Rules, isGeositeDNSRule), geoIPOptions: common.PtrValueOrDefault(options.GeoIP), + geositeOptions: common.PtrValueOrDefault(options.Geosite), defaultDetour: options.Final, dnsClient: dns.NewClient(dnsOptions.DNSClientOptions), defaultDomainStrategy: C.DomainStrategy(dnsOptions.Strategy), @@ -88,7 +91,7 @@ func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptio if err != nil { return nil, E.Cause(err, "parse dns rule[", i, "]") } - router.rules = append(router.rules, dnsRule) + router.dnsRules = append(router.dnsRules, dnsRule) } transports := make([]adapter.DNSTransport, len(dnsOptions.Servers)) dummyTransportMap := make(map[string]adapter.DNSTransport) @@ -259,6 +262,12 @@ func (r *Router) Start() error { return err } } + for _, rule := range r.dnsRules { + err := rule.Start() + if err != nil { + return err + } + } if r.needGeositeDatabase { for _, rule := range r.rules { err := rule.UpdateGeosite() @@ -266,6 +275,12 @@ func (r *Router) Start() error { r.logger.Error("failed to initialize geosite: ", err) } } + for _, rule := range r.dnsRules { + err := rule.UpdateGeosite() + if err != nil { + r.logger.Error("failed to initialize geosite: ", err) + } + } err := common.Close(r.geositeReader) if err != nil { return err @@ -275,6 +290,18 @@ func (r *Router) Start() error { } func (r *Router) Close() error { + for _, rule := range r.rules { + err := rule.Close() + if err != nil { + return err + } + } + for _, rule := range r.dnsRules { + err := rule.Close() + if err != nil { + return err + } + } return common.Close( common.PtrOrNil(r.geoIPReader), ) @@ -325,12 +352,20 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad conn = bufio.NewCachedConn(conn, buffer) } } + if metadata.Destination.IsFqdn() && metadata.DomainStrategy != C.DomainStrategyAsIS { + addresses, err := r.Lookup(adapter.WithContext(ctx, &metadata), metadata.Destination.Fqdn, metadata.DomainStrategy) + if err != nil { + return err + } + metadata.DestinationAddresses = addresses + r.dnsLogger.WithContext(ctx).Info("resolved [", strings.Join(common.Map(metadata.DestinationAddresses, F.ToString0[netip.Addr]), " "), "]") + } detour := r.match(ctx, metadata, r.defaultOutboundForConnection) if !common.Contains(detour.Network(), C.NetworkTCP) { conn.Close() return E.New("missing supported outbound, closing connection") } - return detour.NewConnection(adapter.WithContext(ctx, &metadata), conn, metadata.Destination) + return detour.NewConnection(ctx, conn, metadata) } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { @@ -359,12 +394,20 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m } conn = bufio.NewCachedPacketConn(conn, buffer, originDestination) } + if metadata.Destination.IsFqdn() && metadata.DomainStrategy != C.DomainStrategyAsIS { + addresses, err := r.Lookup(adapter.WithContext(ctx, &metadata), metadata.Destination.Fqdn, metadata.DomainStrategy) + if err != nil { + return err + } + metadata.DestinationAddresses = addresses + r.dnsLogger.WithContext(ctx).Info("resolved [", strings.Join(common.Map(metadata.DestinationAddresses, F.ToString0[netip.Addr]), " "), "]") + } detour := r.match(ctx, metadata, r.defaultOutboundForPacketConnection) if !common.Contains(detour.Network(), C.NetworkUDP) { conn.Close() return E.New("missing supported outbound, closing packet connection") } - return detour.NewPacketConnection(adapter.WithContext(ctx, &metadata), conn, metadata.Destination) + return detour.NewPacketConnection(ctx, conn, metadata) } func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) { @@ -397,10 +440,10 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport { metadata := adapter.ContextFrom(ctx) if metadata == nil { - r.dnsLogger.WithContext(ctx).Info("no context") + r.dnsLogger.WithContext(ctx).Warn("no context") return r.defaultTransport } - for i, rule := range r.rules { + for i, rule := range r.dnsRules { if rule.Match(metadata) { detour := rule.Outbound() r.dnsLogger.WithContext(ctx).Info("match[", i, "] ", rule.String(), " => ", detour) diff --git a/route/rule_cidr.go b/route/rule_cidr.go index 8514593c..05636bdb 100644 --- a/route/rule_cidr.go +++ b/route/rule_cidr.go @@ -41,12 +41,19 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { } } } else { - if metadata.Destination.IsFqdn() { - return false - } - for _, prefix := range r.prefixes { - if prefix.Contains(metadata.Destination.Addr) { - return true + if metadata.Destination.IsIP() { + for _, prefix := range r.prefixes { + if prefix.Contains(metadata.Destination.Addr) { + return true + } + } + } else { + for _, address := range metadata.DestinationAddresses { + for _, prefix := range r.prefixes { + if prefix.Contains(address) { + return true + } + } } } } diff --git a/route/rule_geoip.go b/route/rule_geoip.go index 2b65d041..f171b8f2 100644 --- a/route/rule_geoip.go +++ b/route/rule_geoip.go @@ -3,8 +3,6 @@ package route import ( "strings" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" ) @@ -36,41 +34,26 @@ func NewGeoIPItem(router adapter.Router, logger log.Logger, isSource bool, codes func (r *GeoIPItem) Match(metadata *adapter.InboundContext) bool { geoReader := r.router.GeoIPReader() if geoReader == nil { - return r.match(metadata) + return false } if r.isSource { if metadata.SourceGeoIPCode == "" { metadata.SourceGeoIPCode = geoReader.Lookup(metadata.Source.Addr) } - } else { - if metadata.Destination.IsFqdn() { - return false - } - if metadata.GeoIPCode == "" { - metadata.GeoIPCode = geoReader.Lookup(metadata.Destination.Addr) - } - } - return r.match(metadata) -} - -func (r *GeoIPItem) match(metadata *adapter.InboundContext) bool { - if r.isSource { - if metadata.SourceGeoIPCode == "" { - if !N.IsPublicAddr(metadata.Source.Addr) { - metadata.SourceGeoIPCode = "private" - } - } return r.codeMap[metadata.SourceGeoIPCode] } else { - if metadata.Destination.IsFqdn() { - return false + if metadata.Destination.IsIP() { + if metadata.GeoIPCode == "" { + metadata.GeoIPCode = geoReader.Lookup(metadata.Destination.Addr) + } + return r.codeMap[metadata.GeoIPCode] } - if metadata.GeoIPCode == "" { - if !N.IsPublicAddr(metadata.Destination.Addr) { - metadata.GeoIPCode = "private" + for _, address := range metadata.DestinationAddresses { + if r.codeMap[geoReader.Lookup(address)] { + return true } } - return r.codeMap[metadata.GeoIPCode] + return false } }