diff --git a/adapter/dns.go b/adapter/dns.go new file mode 100644 index 00000000..a8a54167 --- /dev/null +++ b/adapter/dns.go @@ -0,0 +1,22 @@ +package adapter + +import ( + "context" + "net/netip" + + C "github.com/sagernet/sing-box/constant" + + "golang.org/x/net/dns/dnsmessage" +) + +type DNSClient interface { + Exchange(ctx context.Context, transport DNSTransport, message *dnsmessage.Message) (*dnsmessage.Message, error) + Lookup(ctx context.Context, transport DNSTransport, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) +} + +type DNSTransport interface { + Service + Raw() bool + Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) + Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) +} diff --git a/adapter/router.go b/adapter/router.go index 79291a8d..d094113b 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -13,6 +13,7 @@ import ( type Router interface { Service Outbound(tag string) (Outbound, bool) + DefaultOutbound(network string) Outbound RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error GeoIPReader() *geoip.Reader diff --git a/dns/client.go b/dns/client.go new file mode 100644 index 00000000..d755134b --- /dev/null +++ b/dns/client.go @@ -0,0 +1,327 @@ +package dns + +import ( + "context" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/task" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + + "golang.org/x/net/dns/dnsmessage" +) + +const DefaultTTL = 600 + +var ( + ErrNoRawSupport = E.New("no raw query support by current transport") + ErrNotCached = E.New("not cached") +) + +var _ adapter.DNSClient = (*Client)(nil) + +type Client struct { + cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message] +} + +func NewClient() *Client { + return &Client{ + cache: cache.New[dnsmessage.Question, dnsmessage.Message](), + } +} + +func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message) (*dnsmessage.Message, error) { + if len(message.Questions) == 0 { + return nil, E.New("empty query") + } + question := message.Questions[0] + cachedAnswer, cached := c.cache.Load(question) + if cached { + cachedAnswer.ID = message.ID + return &cachedAnswer, nil + } + if !transport.Raw() { + if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA { + return c.exchangeToLookup(ctx, transport, message, question) + } + return nil, ErrNoRawSupport + } + response, err := transport.Exchange(ctx, message) + if err != nil { + return nil, err + } + c.cache.StoreWithExpire(question, *response, calculateExpire(message)) + return message, err +} + +func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { + dnsName, err := dnsmessage.NewName(domain) + if err != nil { + return nil, wrapError(err) + } + if transport.Raw() { + if strategy == C.DomainStrategyUseIPv4 { + return c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeA) + } else if strategy == C.DomainStrategyUseIPv6 { + return c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeAAAA) + } + var response4 []netip.Addr + var response6 []netip.Addr + err = task.Run(ctx, func() error { + response, err := c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeA) + if err != nil { + return err + } + response4 = response + return nil + }, func() error { + response, err := c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeAAAA) + if err != nil { + return err + } + response6 = response + return nil + }) + if len(response4) == 0 && len(response6) == 0 { + return nil, err + } + return sortAddresses(response4, response6, strategy), nil + } + if strategy == C.DomainStrategyUseIPv4 { + response, err := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + if err != ErrNotCached { + return response, err + } + } else if strategy == C.DomainStrategyUseIPv6 { + response, err := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }) + if err != ErrNotCached { + return response, err + } + } else { + response4, _ := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + response6, _ := c.questionCache(dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }) + if len(response4) > 0 || len(response6) > 0 { + return sortAddresses(response4, response6, strategy), nil + } + } + var rCode dnsmessage.RCode + response, err := transport.Lookup(ctx, domain, strategy) + if err != nil { + err = wrapError(err) + if rCodeError, isRCodeError := err.(RCodeError); !isRCodeError { + return nil, err + } else { + rCode = dnsmessage.RCode(rCodeError) + } + } + header := dnsmessage.Header{ + Response: true, + Authoritative: true, + RCode: rCode, + } + expire := time.Now().Add(time.Second * time.Duration(DefaultTTL)) + if strategy != C.DomainStrategyUseIPv6 { + question4 := dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + } + response4 := common.Filter(response, func(addr netip.Addr) bool { + return addr.Is4() || addr.Is4In6() + }) + message4 := dnsmessage.Message{ + Header: header, + Questions: []dnsmessage.Question{question4}, + } + if len(response4) > 0 { + for _, address := range response4 { + message4.Answers = append(message4.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: question4.Name, + Class: question4.Class, + TTL: DefaultTTL, + }, + Body: &dnsmessage.AResource{ + A: address.As4(), + }, + }) + } + } + c.cache.StoreWithExpire(question4, message4, expire) + } + if strategy != C.DomainStrategyUseIPv4 { + question6 := dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + } + response6 := common.Filter(response, func(addr netip.Addr) bool { + return addr.Is6() && !addr.Is4In6() + }) + message6 := dnsmessage.Message{ + Header: header, + Questions: []dnsmessage.Question{question6}, + } + if len(response6) > 0 { + for _, address := range response6 { + message6.Answers = append(message6.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: question6.Name, + Class: question6.Class, + TTL: DefaultTTL, + }, + Body: &dnsmessage.AAAAResource{ + AAAA: address.As16(), + }, + }) + } + } + c.cache.StoreWithExpire(question6, message6, expire) + } + return response, err +} + +func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.DomainStrategy) []netip.Addr { + if strategy == C.DomainStrategyPreferIPv6 { + return append(response6, response4...) + } else { + return append(response4, response6...) + } +} + +func calculateExpire(message *dnsmessage.Message) time.Time { + timeToLive := DefaultTTL + for _, answer := range message.Answers { + if int(answer.Header.TTL) < timeToLive { + timeToLive = int(answer.Header.TTL) + } + } + return time.Now().Add(time.Second * time.Duration(timeToLive)) +} + +func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) { + domain := question.Name.String() + var strategy C.DomainStrategy + if question.Type == dnsmessage.TypeA { + strategy = C.DomainStrategyUseIPv4 + } else { + strategy = C.DomainStrategyUseIPv6 + } + var rCode dnsmessage.RCode + result, err := c.Lookup(ctx, transport, domain, strategy) + if err != nil { + err = wrapError(err) + if rCodeError, isRCodeError := err.(RCodeError); !isRCodeError { + return nil, err + } else { + rCode = dnsmessage.RCode(rCodeError) + } + } + response := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: message.ID, + RCode: rCode, + RecursionAvailable: true, + RecursionDesired: true, + Response: true, + }, + Questions: message.Questions, + } + for _, address := range result { + var resource dnsmessage.Resource + resource.Header = dnsmessage.ResourceHeader{ + Name: question.Name, + Class: question.Class, + TTL: DefaultTTL, + } + if address.Is4() || address.Is4In6() { + resource.Body = &dnsmessage.AResource{ + A: address.As4(), + } + } else { + resource.Body = &dnsmessage.AAAAResource{ + AAAA: address.As16(), + } + } + } + return &response, nil +} + +func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name dnsmessage.Name, qType dnsmessage.Type) ([]netip.Addr, error) { + question := dnsmessage.Question{ + Name: name, + Type: qType, + Class: dnsmessage.ClassINET, + } + cachedAddresses, err := c.questionCache(question) + if err != ErrNotCached { + return cachedAddresses, err + } + message := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: 0, + RecursionDesired: true, + }, + Questions: []dnsmessage.Question{question}, + } + response, err := c.Exchange(ctx, transport, &message) + if err != nil { + return nil, err + } + return messageToAddresses(response) +} + +func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, error) { + response, cached := c.cache.Load(question) + if !cached { + return nil, ErrNotCached + } + return messageToAddresses(&response) +} + +func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) { + if response.RCode != dnsmessage.RCodeSuccess { + return nil, RCodeError(response.RCode) + } + addresses := make([]netip.Addr, 0, len(response.Answers)) + for _, answer := range response.Answers { + switch resource := answer.Body.(type) { + case *dnsmessage.AResource: + addresses = append(addresses, netip.AddrFrom4(resource.A)) + case *dnsmessage.AAAAResource: + addresses = append(addresses, netip.AddrFrom16(resource.AAAA)) + } + } + return addresses, nil +} + +func wrapError(err error) error { + if dnsErr, isDNSError := err.(*net.DNSError); isDNSError { + if dnsErr.IsNotFound { + return RCodeNameError + } + } + return err +} diff --git a/dns/rcode.go b/dns/rcode.go new file mode 100644 index 00000000..5b7e52cc --- /dev/null +++ b/dns/rcode.go @@ -0,0 +1,33 @@ +package dns + +import F "github.com/sagernet/sing/common/format" + +const ( + RCodeSuccess RCodeError = 0 // NoError + RCodeFormatError RCodeError = 1 // FormErr + RCodeServerFailure RCodeError = 2 // ServFail + RCodeNameError RCodeError = 3 // NXDomain + RCodeNotImplemented RCodeError = 4 // NotImp + RCodeRefused RCodeError = 5 // Refused +) + +type RCodeError uint16 + +func (e RCodeError) Error() string { + switch e { + case RCodeSuccess: + return "success" + case RCodeFormatError: + return "format error" + case RCodeServerFailure: + return "server failure" + case RCodeNameError: + return "name error" + case RCodeNotImplemented: + return "not implemented" + case RCodeRefused: + return "refused" + default: + return F.ToString("unknown error: ", uint16(e)) + } +} diff --git a/dns/transport.go b/dns/transport.go index cbb4526e..e6a2c883 100644 --- a/dns/transport.go +++ b/dns/transport.go @@ -2,17 +2,40 @@ package dns import ( "context" - "net/netip" + "net/url" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - - "golang.org/x/net/dns/dnsmessage" + "github.com/sagernet/sing-box/log" ) -type Transport interface { - adapter.Service - Raw() bool - Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) - Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) +func NewTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, address string) (adapter.DNSTransport, error) { + if address == "local" { + return NewLocalTransport(), nil + } + serverURL, err := url.Parse(address) + if err != nil { + return nil, err + } + host := serverURL.Hostname() + port := serverURL.Port() + if port == "" { + port = "53" + } + destination := M.ParseSocksaddrHostPortStr(host, port) + switch serverURL.Scheme { + case "", "udp": + return NewUDPTransport(ctx, dialer, logger, destination), nil + case "tcp": + return NewTCPTransport(ctx, dialer, logger, destination), nil + case "tls": + return NewTLSTransport(ctx, dialer, logger, destination), nil + case "https": + return NewHTTPSTransport(dialer, serverURL.String()), nil + default: + return nil, E.New("unknown dns scheme: " + serverURL.Scheme) + } } diff --git a/dns/transport_base.go b/dns/transport_base.go new file mode 100644 index 00000000..a52a36b0 --- /dev/null +++ b/dns/transport_base.go @@ -0,0 +1,46 @@ +package dns + +import ( + "context" + "net/netip" + "os" + "sync" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" +) + +type myTransportAdapter struct { + ctx context.Context + dialer N.Dialer + logger log.Logger + destination M.Socksaddr + done chan struct{} + access sync.RWMutex + connection *dnsConnection +} + +func (t *myTransportAdapter) Start() error { + return nil +} + +func (t *myTransportAdapter) Close() error { + select { + case <-t.done: + return os.ErrClosed + default: + } + close(t.done) + return nil +} + +func (t *myTransportAdapter) Raw() bool { + return true +} + +func (t *myTransportAdapter) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { + return nil, os.ErrInvalid +} diff --git a/dns/transport_https.go b/dns/transport_https.go index 831f59fb..039d03a5 100644 --- a/dns/transport_https.go +++ b/dns/transport_https.go @@ -13,6 +13,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "golang.org/x/net/dns/dnsmessage" @@ -20,7 +21,7 @@ import ( const dnsMimeType = "application/dns-message" -var _ Transport = (*HTTPSTransport)(nil) +var _ adapter.DNSTransport = (*HTTPSTransport)(nil) type HTTPSTransport struct { destination string diff --git a/dns/transport_local.go b/dns/transport_local.go index fe9459d5..c5e23f4a 100644 --- a/dns/transport_local.go +++ b/dns/transport_local.go @@ -9,21 +9,22 @@ import ( "github.com/sagernet/sing/common" + "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "golang.org/x/net/dns/dnsmessage" ) -var LocalTransportConstructor func() Transport +var LocalTransportConstructor func() adapter.DNSTransport -func NewLocalTransport() Transport { +func NewLocalTransport() adapter.DNSTransport { if LocalTransportConstructor != nil { return LocalTransportConstructor() } return &LocalTransport{} } -var _ Transport = (*LocalTransport)(nil) +var _ adapter.DNSTransport = (*LocalTransport)(nil) type LocalTransport struct { resolver net.Resolver diff --git a/dns/transport_tcp.go b/dns/transport_tcp.go index 9ab73fdb..e57d6ad7 100644 --- a/dns/transport_tcp.go +++ b/dns/transport_tcp.go @@ -4,7 +4,6 @@ import ( "context" "encoding/binary" "net" - "net/netip" "os" "sync" @@ -15,52 +14,30 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" - C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "golang.org/x/net/dns/dnsmessage" ) -var _ Transport = (*TCPTransport)(nil) +var _ adapter.DNSTransport = (*TCPTransport)(nil) type TCPTransport struct { - ctx context.Context - dialer N.Dialer - logger log.Logger - destination M.Socksaddr - done chan struct{} - access sync.RWMutex - connection *dnsConnection + myTransportAdapter } func NewTCPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *TCPTransport { return &TCPTransport{ - ctx: ctx, - dialer: dialer, - logger: logger, - destination: destination, - done: make(chan struct{}), + myTransportAdapter{ + ctx: ctx, + dialer: dialer, + logger: logger, + destination: destination, + done: make(chan struct{}), + }, } } -func (t *TCPTransport) Start() error { - return nil -} - -func (t *TCPTransport) Close() error { - select { - case <-t.done: - return os.ErrClosed - default: - } - close(t.done) - return nil -} - -func (t *TCPTransport) Raw() bool { - return true -} - func (t *TCPTransport) offer() (*dnsConnection, error) { t.access.RLock() connection := t.connection @@ -207,7 +184,3 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *dnsmessage.Message return nil, ctx.Err() } } - -func (t *TCPTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { - return nil, os.ErrInvalid -} diff --git a/dns/transport_tls.go b/dns/transport_tls.go index 0f41b33f..2a2fb17e 100644 --- a/dns/transport_tls.go +++ b/dns/transport_tls.go @@ -4,9 +4,7 @@ import ( "context" "crypto/tls" "encoding/binary" - "net/netip" "os" - "sync" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -15,52 +13,30 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" - C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "golang.org/x/net/dns/dnsmessage" ) -var _ Transport = (*TLSTransport)(nil) +var _ adapter.DNSTransport = (*TLSTransport)(nil) type TLSTransport struct { - ctx context.Context - dialer N.Dialer - logger log.Logger - destination M.Socksaddr - done chan struct{} - access sync.RWMutex - connection *dnsConnection + myTransportAdapter } func NewTLSTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *TLSTransport { return &TLSTransport{ - ctx: ctx, - dialer: dialer, - logger: logger, - destination: destination, - done: make(chan struct{}), + myTransportAdapter{ + ctx: ctx, + dialer: dialer, + logger: logger, + destination: destination, + done: make(chan struct{}), + }, } } -func (t *TLSTransport) Start() error { - return nil -} - -func (t *TLSTransport) Close() error { - select { - case <-t.done: - return os.ErrClosed - default: - } - close(t.done) - return nil -} - -func (t *TLSTransport) Raw() bool { - return true -} - func (t *TLSTransport) offer(ctx context.Context) (*dnsConnection, error) { t.access.RLock() connection := t.connection @@ -207,7 +183,3 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *dnsmessage.Message return nil, ctx.Err() } } - -func (t *TLSTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { - return nil, os.ErrInvalid -} diff --git a/dns/transport_udp.go b/dns/transport_udp.go index fae4e2f1..81f64142 100644 --- a/dns/transport_udp.go +++ b/dns/transport_udp.go @@ -2,9 +2,7 @@ package dns import ( "context" - "net/netip" "os" - "sync" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -12,52 +10,30 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" - C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "golang.org/x/net/dns/dnsmessage" ) -var _ Transport = (*UDPTransport)(nil) +var _ adapter.DNSTransport = (*UDPTransport)(nil) type UDPTransport struct { - ctx context.Context - dialer N.Dialer - logger log.Logger - destination M.Socksaddr - done chan struct{} - access sync.RWMutex - connection *dnsConnection + myTransportAdapter } func NewUDPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *UDPTransport { return &UDPTransport{ - ctx: ctx, - dialer: dialer, - logger: logger, - destination: destination, - done: make(chan struct{}), + myTransportAdapter{ + ctx: ctx, + dialer: dialer, + logger: logger, + destination: destination, + done: make(chan struct{}), + }, } } -func (t *UDPTransport) Start() error { - return nil -} - -func (t *UDPTransport) Close() error { - select { - case <-t.done: - return os.ErrClosed - default: - } - close(t.done) - return nil -} - -func (t *UDPTransport) Raw() bool { - return true -} - func (t *UDPTransport) offer() (*dnsConnection, error) { t.access.RLock() connection := t.connection @@ -184,7 +160,3 @@ func (t *UDPTransport) Exchange(ctx context.Context, message *dnsmessage.Message return nil, ctx.Err() } } - -func (t *UDPTransport) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { - return nil, os.ErrInvalid -} diff --git a/go.mod b/go.mod index de8c92c1..a89d9e24 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/goccy/go-json v0.9.8 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/oschwald/maxminddb-golang v1.9.0 - github.com/sagernet/sing v0.0.0-20220706103716-44ec149b1efc + github.com/sagernet/sing v0.0.0-20220706131532-6d16497f03a6 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 github.com/sirupsen/logrus v1.8.1 github.com/spf13/cobra v1.5.0 diff --git a/go.sum b/go.sum index 7ee0ea48..ffce45a4 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sagernet/sing v0.0.0-20220706103716-44ec149b1efc h1:TpmuXk61HoJHOY6ScS3t2Bz41HTbuPnffsf6QdnQoSg= -github.com/sagernet/sing v0.0.0-20220706103716-44ec149b1efc/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= +github.com/sagernet/sing v0.0.0-20220706131532-6d16497f03a6 h1:NKDjOKPHP4JOrYomj2Q/tvKDWLmCNLHNQSPZLE5o3I4= +github.com/sagernet/sing v0.0.0-20220706131532-6d16497f03a6/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= diff --git a/option/config.go b/option/config.go index 8608dc27..0871df22 100644 --- a/option/config.go +++ b/option/config.go @@ -12,6 +12,7 @@ type _Options struct { Log *LogOption `json:"log,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` + DNS *DNSOptions `json:"dns,omitempty"` Route *RouteOptions `json:"route,omitempty"` } diff --git a/option/dns.go b/option/dns.go new file mode 100644 index 00000000..17b463ab --- /dev/null +++ b/option/dns.go @@ -0,0 +1,12 @@ +package option + +type DNSOptions struct { + Servers []DNSServerOptions `json:"servers,omitempty"` +} + +type DNSServerOptions struct { + Tag string `json:"tag,omitempty"` + Address string `json:"address"` + Detour string `json:"detour,omitempty"` + AddressResolver string `json:"address_resolver,omitempty"` +} diff --git a/option/route.go b/option/route.go index a55e4140..2d8ed17f 100644 --- a/option/route.go +++ b/option/route.go @@ -101,9 +101,7 @@ type DefaultRule struct { IPCIDR Listable[string] `json:"ip_cidr,omitempty"` SourcePort Listable[uint16] `json:"source_port,omitempty"` Port Listable[uint16] `json:"port,omitempty"` - // ProcessName Listable[string] `json:"process_name,omitempty"` - // ProcessPath Listable[string] `json:"process_path,omitempty"` - Outbound string `json:"outbound,omitempty"` + Outbound string `json:"outbound,omitempty"` } func (r DefaultRule) IsValid() bool { diff --git a/option/types.go b/option/types.go index 86ebe728..f2567891 100644 --- a/option/types.go +++ b/option/types.go @@ -90,3 +90,47 @@ func (l *Listable[T]) UnmarshalJSON(content []byte) error { *l = []T{singleItem} return nil } + +type DomainStrategy C.DomainStrategy + +func (s DomainStrategy) MarshalJSON() ([]byte, error) { + var value string + switch C.DomainStrategy(s) { + case C.DomainStrategyAsIS: + value = "AsIS" + case C.DomainStrategyPreferIPv4: + value = "PreferIPv4" + case C.DomainStrategyPreferIPv6: + value = "PreferIPv6" + case C.DomainStrategyUseIPv4: + value = "UseIPv4" + case C.DomainStrategyUseIPv6: + value = "UseIPv6" + default: + return nil, E.New("unknown domain strategy: ", s) + } + return json.Marshal(value) +} + +func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error { + var value string + err := json.Unmarshal(bytes, &value) + if err != nil { + return err + } + switch value { + case "AsIS": + *s = DomainStrategy(C.DomainStrategyAsIS) + case "PreferIPv4": + *s = DomainStrategy(C.DomainStrategyPreferIPv4) + case "PreferIPv6": + *s = DomainStrategy(C.DomainStrategyPreferIPv6) + case "UseIPv4": + *s = DomainStrategy(C.DomainStrategyUseIPv4) + case "UseIPv6": + *s = DomainStrategy(C.DomainStrategyUseIPv6) + default: + return E.New("unknown domain strategy: ", value) + } + return nil +} diff --git a/outbound/dialer/default.go b/outbound/dialer/default.go index 251ec133..0332e309 100644 --- a/outbound/dialer/default.go +++ b/outbound/dialer/default.go @@ -20,7 +20,7 @@ type defaultDialer struct { net.ListenConfig } -func newDefault(options option.DialerOptions) N.Dialer { +func NewDefault(options option.DialerOptions) N.Dialer { var dialer net.Dialer var listener net.ListenConfig if options.BindInterface != "" { diff --git a/outbound/dialer/detour.go b/outbound/dialer/detour.go index a88b1dad..bc79d35f 100644 --- a/outbound/dialer/detour.go +++ b/outbound/dialer/detour.go @@ -10,27 +10,31 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/option" ) type detourDialer struct { router adapter.Router - options option.DialerOptions + detour string dialer N.Dialer initOnce sync.Once initErr error } -func newDetour(router adapter.Router, options option.DialerOptions) N.Dialer { - return &detourDialer{router: router, options: options} +func NewDetour(router adapter.Router, detour string) N.Dialer { + return &detourDialer{router: router, detour: detour} +} + +func (d *detourDialer) Start() error { + _, err := d.Dialer() + return err } func (d *detourDialer) Dialer() (N.Dialer, error) { d.initOnce.Do(func() { var loaded bool - d.dialer, loaded = d.router.Outbound(d.options.Detour) + d.dialer, loaded = d.router.Outbound(d.detour) if !loaded { - d.initErr = E.New("outbound detour not found: ", d.options.Detour) + d.initErr = E.New("outbound detour not found: ", d.detour) } }) return d.dialer, d.initErr diff --git a/outbound/dialer/dialer.go b/outbound/dialer/dialer.go index 78d52a3c..411163ba 100644 --- a/outbound/dialer/dialer.go +++ b/outbound/dialer/dialer.go @@ -11,12 +11,12 @@ import ( func New(router adapter.Router, options option.DialerOptions) N.Dialer { var dialer N.Dialer if options.Detour == "" { - dialer = newDefault(options) + dialer = NewDefault(options) } else { - dialer = newDetour(router, options) + dialer = NewDetour(router, options.Detour) } if options.OverrideOptions.IsValid() { - dialer = newOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions)) + dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions)) } return dialer } diff --git a/outbound/dialer/override.go b/outbound/dialer/override.go index 8ee4e569..4da9a122 100644 --- a/outbound/dialer/override.go +++ b/outbound/dialer/override.go @@ -22,7 +22,7 @@ type overrideDialer struct { uotEnabled bool } -func newOverride(upstream N.Dialer, options option.OverrideStreamOptions) N.Dialer { +func NewOverride(upstream N.Dialer, options option.OverrideStreamOptions) N.Dialer { return &overrideDialer{ upstream, options.TLS, diff --git a/outbound/dialer/protect.go b/outbound/dialer/protect.go index de789b55..9fd05d21 100644 --- a/outbound/dialer/protect.go +++ b/outbound/dialer/protect.go @@ -5,7 +5,6 @@ package dialer import ( "syscall" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" ) @@ -42,6 +41,6 @@ func ProtectPath(protectPath string) control.Func { err := conn.Control(func(fd uintptr) { innerErr = sendAncillaryFileDescriptors(protectPath, []int{int(fd)}) }) - return common.AnyError(innerErr, err) + return E.Errors(innerErr, err) } } diff --git a/route/router.go b/route/router.go index ffa836d4..f3579623 100644 --- a/route/router.go +++ b/route/router.go @@ -191,6 +191,14 @@ func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { return outbound, loaded } +func (r *Router) DefaultOutbound(network string) adapter.Outbound { + if network == C.NetworkTCP { + return r.defaultOutboundForConnection + } else { + return r.defaultOutboundForPacketConnection + } +} + func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { if metadata.SniffEnabled { _buffer := buf.StackNew()