From 7f841917483e21132f28dcb257d8fa219f51c64f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 10 Jul 2022 14:22:28 +0800 Subject: [PATCH] Minor fixes --- common/dialer/resolve_conn.go | 2 +- common/sniff/quic.go | 17 +++---------- common/sniff/sniff.go | 1 - dns/client.go | 2 ++ inbound/tun.go | 4 +-- option/inbound.go | 12 ++++----- option/types.go | 9 ++++++- outbound/shadowsocks.go | 4 +-- route/router.go | 4 +-- route/rule_dns.go | 46 +++++++++++++++++++++++++++-------- test/box_test.go | 5 ++++ test/docker_test.go | 2 ++ 12 files changed, 69 insertions(+), 39 deletions(-) diff --git a/common/dialer/resolve_conn.go b/common/dialer/resolve_conn.go index 42fc10b2..10ae32d5 100644 --- a/common/dialer/resolve_conn.go +++ b/common/dialer/resolve_conn.go @@ -37,7 +37,7 @@ func (w *ResolveUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { func (w *ResolveUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() - if destination.Family().IsFqdn() { + if destination.IsFqdn() { addresses, err := w.router.Lookup(context.Background(), destination.Fqdn, w.strategy) if err != nil { return err diff --git a/common/sniff/quic.go b/common/sniff/quic.go index 07bcbd6c..6a7df4e6 100644 --- a/common/sniff/quic.go +++ b/common/sniff/quic.go @@ -36,20 +36,9 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 { return nil, E.New("bad version") } - if versionNumber == qtls.Version2 { - if (typeByte&0x30)>>4 == 0b01 { - } else if (typeByte&0x30)>>4 != 0b10 { - // 0-rtt - } else { - return nil, E.New("bad packet type") - } - } else { - if (typeByte&0x30)>>4 == 0x0 { - } else if (typeByte&0x30)>>4 != 0x01 { - // 0-rtt - } else { - return nil, E.New("bad packet type") - } + packetType := (typeByte & 0x30) >> 4 + if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 { + return nil, E.New("bad packet type") } destConnIDLen, err := reader.ReadByte() diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index c33cded9..12055bed 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -28,7 +28,6 @@ func PeekPacket(ctx context.Context, packet []byte, sniffers ...PacketSniffer) ( for _, sniffer := range sniffers { sniffMetadata, err := sniffer(ctx, packet) if err != nil { - println(err.Error()) return nil, err } return sniffMetadata, nil diff --git a/dns/client.go b/dns/client.go index b9270b2c..0f593ec7 100644 --- a/dns/client.go +++ b/dns/client.go @@ -64,10 +64,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m } return nil, ErrNoRawSupport } + messageId := message.ID response, err := transport.Exchange(ctx, message) if err != nil { return nil, err } + response.ID = messageId if !c.disableCache { c.storeCache(question, response) } diff --git a/inbound/tun.go b/inbound/tun.go index 17b8e2c0..2665e43c 100644 --- a/inbound/tun.go +++ b/inbound/tun.go @@ -60,8 +60,8 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.Logger, tag s inboundOptions: options.InboundOptions, tunName: tunName, tunMTU: tunMTU, - inet4Address: netip.Prefix(options.Inet4Address), - inet6Address: netip.Prefix(options.Inet6Address), + inet4Address: options.Inet4Address.Build(), + inet6Address: options.Inet6Address.Build(), autoRoute: options.AutoRoute, hijackDNS: options.HijackDNS, }, nil diff --git a/option/inbound.go b/option/inbound.go index 1d04d457..2e8c190b 100644 --- a/option/inbound.go +++ b/option/inbound.go @@ -144,11 +144,11 @@ type ShadowsocksDestination struct { } type TunInboundOptions struct { - InterfaceName string `json:"interface_name,omitempty"` - MTU uint32 `json:"mtu,omitempty,omitempty"` - Inet4Address ListenPrefix `json:"inet4_address,omitempty"` - Inet6Address ListenPrefix `json:"inet6_address,omitempty"` - AutoRoute bool `json:"auto_route,omitempty"` - HijackDNS bool `json:"hijack_dns,omitempty"` + InterfaceName string `json:"interface_name,omitempty"` + MTU uint32 `json:"mtu,omitempty,omitempty"` + Inet4Address *ListenPrefix `json:"inet4_address,omitempty"` + Inet6Address *ListenPrefix `json:"inet6_address,omitempty"` + AutoRoute bool `json:"auto_route,omitempty"` + HijackDNS bool `json:"hijack_dns,omitempty"` InboundOptions } diff --git a/option/types.go b/option/types.go index 9825e604..c925cbe5 100644 --- a/option/types.go +++ b/option/types.go @@ -161,7 +161,7 @@ type ListenPrefix netip.Prefix func (p ListenPrefix) MarshalJSON() ([]byte, error) { prefix := netip.Prefix(p) if !prefix.IsValid() { - return json.Marshal("") + return json.Marshal(nil) } return json.Marshal(prefix.String()) } @@ -179,3 +179,10 @@ func (p *ListenPrefix) UnmarshalJSON(bytes []byte) error { *p = ListenPrefix(prefix) return nil } + +func (p *ListenPrefix) Build() netip.Prefix { + if p == nil { + return netip.Prefix{} + } + return netip.Prefix(*p) +} diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index d8af37ea..8c1d5108 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -72,11 +72,11 @@ func (h *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr) metadata.Outbound = h.tag metadata.Destination = destination h.logger.WithContext(ctx).Info("outbound packet connection to ", h.serverAddr) - outConn, err := h.dialer.ListenPacket(ctx, destination) + outConn, err := h.dialer.DialContext(ctx, "udp", h.serverAddr) if err != nil { return nil, err } - return h.method.DialPacketConn(&bufio.BindPacketConn{PacketConn: outConn, Addr: h.serverAddr.UDPAddr()}), nil + return h.method.DialPacketConn(outConn), nil } func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { diff --git a/route/router.go b/route/router.go index f373f4c7..130a4f57 100644 --- a/route/router.go +++ b/route/router.go @@ -482,7 +482,7 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def for i, rule := range r.rules { if rule.Match(&metadata) { detour := rule.Outbound() - r.logger.WithContext(ctx).Info("match[", i, "] ", rule.String(), " => ", detour) + r.logger.WithContext(ctx).Debug("match[", i, "] ", rule.String(), " => ", detour) if outbound, loaded := r.Outbound(detour); loaded { return outbound } @@ -501,7 +501,7 @@ func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport { for i, rule := range r.dnsRules { if rule.Match(metadata) { detour := rule.Outbound() - r.dnsLogger.WithContext(ctx).Info("match[", i, "] ", rule.String(), " => ", detour) + r.dnsLogger.WithContext(ctx).Debug("match[", i, "] ", rule.String(), " => ", detour) if transport, loaded := r.transportMap[detour]; loaded { return transport } diff --git a/route/rule_dns.go b/route/rule_dns.go index 7f5c0d75..c7f30ea7 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -41,8 +41,10 @@ func NewDNSRule(router adapter.Router, logger log.Logger, options option.DNSRule var _ adapter.Rule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { - items []RuleItem - outbound string + items []RuleItem + addressItems []RuleItem + allItems []RuleItem + outbound string } func NewDefaultDNSRule(router adapter.Router, logger log.Logger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { @@ -52,12 +54,14 @@ func NewDefaultDNSRule(router adapter.Router, logger log.Logger, options option. if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if options.Network != "" { switch options.Network { case C.NetworkTCP, C.NetworkUDP: item := NewNetworkItem(options.Network) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) default: return nil, E.New("invalid network: ", options.Network) } @@ -65,29 +69,35 @@ func NewDefaultDNSRule(router adapter.Router, logger log.Logger, options option. if len(options.Protocol) > 0 { item := NewProtocolItem(options.Protocol) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { item := NewDomainItem(options.Domain, options.DomainSuffix) - rule.items = append(rule.items, item) + rule.addressItems = append(rule.addressItems, item) + rule.allItems = append(rule.allItems, item) } if len(options.DomainKeyword) > 0 { item := NewDomainKeywordItem(options.DomainKeyword) - rule.items = append(rule.items, item) + rule.addressItems = append(rule.addressItems, item) + rule.allItems = append(rule.allItems, item) } if len(options.DomainRegex) > 0 { item, err := NewDomainRegexItem(options.DomainRegex) if err != nil { return nil, E.Cause(err, "domain_regex") } - rule.items = append(rule.items, item) + rule.addressItems = append(rule.addressItems, item) + rule.allItems = append(rule.allItems, item) } if len(options.Geosite) > 0 { item := NewGeositeItem(router, logger, options.Geosite) - rule.items = append(rule.items, item) + rule.addressItems = append(rule.addressItems, item) + rule.allItems = append(rule.allItems, item) } if len(options.SourceGeoIP) > 0 { item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.SourceIPCIDR) > 0 { item, err := NewIPCIDRItem(true, options.SourceIPCIDR) @@ -95,24 +105,28 @@ func NewDefaultDNSRule(router adapter.Router, logger log.Logger, options option. return nil, E.Cause(err, "source_ipcidr") } rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.SourcePort) > 0 { item := NewPortItem(true, options.SourcePort) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.Port) > 0 { item := NewPortItem(false, options.Port) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.Outbound) > 0 { item := NewOutboundRule(options.Outbound) rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } return rule, nil } func (r *DefaultDNSRule) Start() error { - for _, item := range r.items { + for _, item := range r.allItems { err := common.Start(item) if err != nil { return err @@ -122,7 +136,7 @@ func (r *DefaultDNSRule) Start() error { } func (r *DefaultDNSRule) Close() error { - for _, item := range r.items { + for _, item := range r.allItems { err := common.Close(item) if err != nil { return err @@ -132,7 +146,7 @@ func (r *DefaultDNSRule) Close() error { } func (r *DefaultDNSRule) UpdateGeosite() error { - for _, item := range r.items { + for _, item := range r.allItems { if geositeItem, isSite := item.(*GeositeItem); isSite { err := geositeItem.Update() if err != nil { @@ -149,6 +163,18 @@ func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { return false } } + if len(r.addressItems) > 0 { + var addressMatch bool + for _, item := range r.addressItems { + if item.Match(metadata) { + addressMatch = true + break + } + } + if !addressMatch { + return false + } + } return true } @@ -157,7 +183,7 @@ func (r *DefaultDNSRule) Outbound() string { } func (r *DefaultDNSRule) String() string { - return strings.Join(common.Map(r.items, F.ToString0[RuleItem]), " ") + return strings.Join(common.Map(r.allItems, F.ToString0[RuleItem]), " ") } var _ adapter.Rule = (*LogicalRule)(nil) diff --git a/test/box_test.go b/test/box_test.go index 757685c0..25a47758 100644 --- a/test/box_test.go +++ b/test/box_test.go @@ -7,14 +7,18 @@ import ( "github.com/sagernet/sing-box" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/control" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/protocol/socks" "github.com/stretchr/testify/require" + "time" ) func mkPort(t *testing.T) uint16 { + var lc net.ListenConfig + lc.Control = control.ReuseAddr() for { tcpListener, err := net.ListenTCP("tcp", nil) require.NoError(t, err) @@ -36,6 +40,7 @@ func startInstance(t *testing.T, options option.Options) { t.Cleanup(func() { instance.Close() }) + time.Sleep(time.Second) } func testSuit(t *testing.T, clientPort uint16, testPort uint16) { diff --git a/test/docker_test.go b/test/docker_test.go index 45db60c3..243c5486 100644 --- a/test/docker_test.go +++ b/test/docker_test.go @@ -11,6 +11,7 @@ import ( "github.com/docker/docker/client" "github.com/docker/go-connections/nat" "github.com/stretchr/testify/require" + "time" ) type DockerOptions struct { @@ -64,6 +65,7 @@ func startDockerContainer(t *testing.T, options DockerOptions) { go func() { attach.Reader.WriteTo(os.Stderr) }()*/ + time.Sleep(time.Second) } func cleanContainer(id string) error {