package route

import (
	"context"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/common/geoip"
	"github.com/sagernet/sing-box/common/geosite"
	"github.com/sagernet/sing-box/common/sniff"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	F "github.com/sagernet/sing/common/format"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"
)

var _ adapter.Router = (*Router)(nil)

// Router picks an outbound for every inbound connection or packet connection
// by evaluating its rules in order and falling back to a per-network default
// outbound when no rule yields a known outbound.
type Router struct {
	ctx    context.Context
	logger log.Logger

	// outboundByTag is replaced wholesale by Initialize.
	outboundByTag map[string]adapter.Outbound
	rules         []adapter.Rule

	// defaultDetour is the configured tag of the preferred default outbound;
	// it may be empty.
	defaultDetour                      string
	defaultOutboundForConnection       adapter.Outbound
	defaultOutboundForPacketConnection adapter.Outbound

	// The need* flags are derived from the rule options in NewRouter and
	// decide whether Start loads (and, if missing, downloads) a database.
	needGeoIPDatabase   bool
	needGeositeDatabase bool
	geoIPOptions        option.GeoIPOptions
	geositeOptions      option.GeositeOptions
	geoIPReader         *geoip.Reader
	geositeReader       *geosite.Reader
}

// NewRouter builds a Router from the route options and parses every rule.
// It returns an error naming the index of the first rule that fails to parse.
func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptions) (*Router, error) {
	router := &Router{
		ctx:                 ctx,
		logger:              logger.WithPrefix("router: "),
		outboundByTag:       make(map[string]adapter.Outbound),
		rules:               make([]adapter.Rule, 0, len(options.Rules)),
		needGeoIPDatabase:   hasGeoRule(options.Rules, isGeoIPRule),
		needGeositeDatabase: hasGeoRule(options.Rules, isGeositeRule),
		geoIPOptions:        common.PtrValueOrDefault(options.GeoIP),
		// Previously left at its zero value, which made every geosite
		// setting (path, download URL, download detour) silently ignored.
		geositeOptions: common.PtrValueOrDefault(options.Geosite),
		defaultDetour:  options.DefaultDetour,
	}
	for i, ruleOptions := range options.Rules {
		rule, err := NewRule(router, logger, ruleOptions)
		if err != nil {
			return nil, E.Cause(err, "parse rule[", i, "]")
		}
		router.rules = append(router.rules, rule)
	}
	return router, nil
}

// Initialize indexes the outbounds by tag and resolves the default outbound
// for connections (TCP) and for packet connections (UDP).
//
// Resolution order, applied independently per network:
//  1. the configured default detour, if it supports the network;
//  2. the first outbound in outbounds that supports the network;
//  3. the outbound returned by defaultOutbound.
//
// It returns an error only when a default detour is configured but no
// outbound carries that tag.
func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
	outboundByTag := make(map[string]adapter.Outbound)
	for _, detour := range outbounds {
		outboundByTag[detour.Tag()] = detour
	}
	var defaultOutboundForConnection adapter.Outbound
	var defaultOutboundForPacketConnection adapter.Outbound
	if r.defaultDetour != "" {
		detour, loaded := outboundByTag[r.defaultDetour]
		if !loaded {
			return E.New("default detour not found: ", r.defaultDetour)
		}
		if common.Contains(detour.Network(), C.NetworkTCP) {
			defaultOutboundForConnection = detour
		}
		if common.Contains(detour.Network(), C.NetworkUDP) {
			defaultOutboundForPacketConnection = detour
		}
	}
	// index/packetIndex are only used to describe untagged outbounds in the
	// log lines below.
	var index, packetIndex int
	if defaultOutboundForConnection == nil {
		for i, detour := range outbounds {
			if common.Contains(detour.Network(), C.NetworkTCP) {
				index = i
				defaultOutboundForConnection = detour
				break
			}
		}
	}
	if defaultOutboundForPacketConnection == nil {
		for i, detour := range outbounds {
			if common.Contains(detour.Network(), C.NetworkUDP) {
				packetIndex = i
				defaultOutboundForPacketConnection = detour
				break
			}
		}
	}
	if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil {
		// Call the factory once and share its result between both slots.
		detour := defaultOutbound()
		if defaultOutboundForConnection == nil {
			defaultOutboundForConnection = detour
		}
		if defaultOutboundForPacketConnection == nil {
			defaultOutboundForPacketConnection = detour
		}
	}
	// Only report the defaults when the two networks use different outbounds.
	if defaultOutboundForConnection != defaultOutboundForPacketConnection {
		var description string
		if defaultOutboundForConnection.Tag() != "" {
			description = defaultOutboundForConnection.Tag()
		} else {
			description = F.ToString(index)
		}
		var packetDescription string
		if defaultOutboundForPacketConnection.Tag() != "" {
			packetDescription = defaultOutboundForPacketConnection.Tag()
		} else {
			packetDescription = F.ToString(packetIndex)
		}
		r.logger.Info("using ", defaultOutboundForConnection.Type(), "[", description, "] as default outbound for connection")
		r.logger.Info("using ", defaultOutboundForPacketConnection.Type(), "[", packetDescription, "] as default outbound for packet connection")
	}
	r.defaultOutboundForConnection = defaultOutboundForConnection
	r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection
	r.outboundByTag = outboundByTag
	return nil
}

// Start loads the databases required by the rules, starts every rule, and
// then lets each rule pull its geosite data before the geosite reader is
// closed. A rule's UpdateGeosite failure is logged, not returned.
func (r *Router) Start() error {
	if r.needGeoIPDatabase {
		err := r.prepareGeoIPDatabase()
		if err != nil {
			return err
		}
	}
	if r.needGeositeDatabase {
		err := r.prepareGeositeDatabase()
		if err != nil {
			return err
		}
	}
	for _, rule := range r.rules {
		err := rule.Start()
		if err != nil {
			return err
		}
	}
	if r.needGeositeDatabase {
		for _, rule := range r.rules {
			err := rule.UpdateGeosite()
			if err != nil {
				r.logger.Error("failed to initialize geosite: ", err)
			}
		}
		// NOTE(review): r.geositeReader is closed here but not cleared, so
		// GeositeReader() keeps returning the closed reader afterwards.
		err := common.Close(r.geositeReader)
		if err != nil {
			return err
		}
	}
	return nil
}

// Close closes the GeoIP reader if one was opened.
func (r *Router) Close() error {
	return common.Close(
		common.PtrOrNil(r.geoIPReader),
	)
}

// GeoIPReader returns the GeoIP reader, or nil if none was loaded.
func (r *Router) GeoIPReader() *geoip.Reader {
	return r.geoIPReader
}

// GeositeReader returns the geosite reader, or nil if none was loaded.
func (r *Router) GeositeReader() *geosite.Reader {
	return r.geositeReader
}

// Outbound looks up an outbound by tag; the boolean reports whether it exists.
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
	outbound, loaded := r.outboundByTag[tag]
	return outbound, loaded
}

// RouteConnection optionally sniffs the stream for a TLS ClientHello or HTTP
// host, selects an outbound via the rules (default: the connection default
// outbound) and hands the connection to it. If the selected outbound does not
// support TCP the connection is closed and an error is returned.
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
	if metadata.SniffEnabled {
		_buffer := buf.StackNew()
		defer common.KeepAlive(_buffer)
		buffer := common.Dup(_buffer)
		defer buffer.Release()
		// Everything the sniffer reads from conn is also written to buffer.
		reader := io.TeeReader(conn, buffer)
		sniffMetadata, err := sniff.PeekStream(ctx, reader, sniff.TLSClientHello, sniff.HTTPHost)
		if err == nil {
			metadata.Protocol = sniffMetadata.Protocol
			metadata.Domain = sniffMetadata.Domain
			if metadata.SniffOverrideDestination && sniff.IsDomainName(metadata.Domain) {
				metadata.Destination.Fqdn = metadata.Domain
			}
			if metadata.Domain != "" {
				r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
			} else {
				r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol)
			}
		}
		if !buffer.IsEmpty() {
			// Presumably replays the already-consumed bytes to the outbound;
			// confirm against bufio.NewCachedConn.
			conn = bufio.NewCachedConn(conn, buffer)
		}
	}
	detour := r.match(ctx, metadata, r.defaultOutboundForConnection)
	if !common.Contains(detour.Network(), C.NetworkTCP) {
		conn.Close()
		return E.New("missing supported outbound, closing connection")
	}
	return detour.NewConnection(ctx, conn, metadata.Destination)
}

// RoutePacketConnection optionally sniffs the first packet for a QUIC
// ClientHello, selects an outbound via the rules (default: the packet
// connection default outbound) and hands the packet connection to it. If the
// selected outbound does not support UDP the connection is closed and an
// error is returned. A failure to read the first packet while sniffing is
// returned as-is.
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
	if metadata.SniffEnabled {
		_buffer := buf.StackNewPacket()
		defer common.KeepAlive(_buffer)
		buffer := common.Dup(_buffer)
		defer buffer.Release()
		_, err := conn.ReadPacket(buffer)
		if err != nil {
			return err
		}
		sniffMetadata, err := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.QUICClientHello)
		// Captured before a possible FQDN override so the cached packet keeps
		// the destination it arrived with.
		originDestination := metadata.Destination
		if err == nil {
			metadata.Protocol = sniffMetadata.Protocol
			metadata.Domain = sniffMetadata.Domain
			if metadata.SniffOverrideDestination && sniff.IsDomainName(metadata.Domain) {
				metadata.Destination.Fqdn = metadata.Domain
			}
			if metadata.Domain != "" {
				r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
			} else {
				r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol)
			}
		}
		conn = bufio.NewCachedPacketConn(conn, buffer, originDestination)
	}
	detour := r.match(ctx, metadata, r.defaultOutboundForPacketConnection)
	if !common.Contains(detour.Network(), C.NetworkUDP) {
		conn.Close()
		return E.New("missing supported outbound, closing packet connection")
	}
	return detour.NewPacketConnection(ctx, conn, metadata.Destination)
}

// match returns the outbound of the first matching rule whose outbound tag is
// known. A matching rule with an unknown tag is logged and skipped; when no
// rule produces an outbound, defaultOutbound is returned.
func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, defaultOutbound adapter.Outbound) adapter.Outbound {
	for i, rule := range r.rules {
		if rule.Match(&metadata) {
			detour := rule.Outbound()
			r.logger.WithContext(ctx).Info("match[", i, "] ", rule.String(), " => ", detour)
			if outbound, loaded := r.Outbound(detour); loaded {
				return outbound
			}
			r.logger.WithContext(ctx).Error("outbound not found: ", detour)
		}
	}
	r.logger.WithContext(ctx).Info("no match")
	return defaultOutbound
}

// hasGeoRule reports whether cond holds for any default rule, or for any
// direct sub-rule of a logical rule.
func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool {
	for _, rule := range rules {
		switch rule.Type {
		case C.RuleTypeDefault:
			if cond(rule.DefaultOptions) {
				return true
			}
		case C.RuleTypeLogical:
			for _, subRule := range rule.LogicalOptions.Rules {
				if cond(subRule) {
					return true
				}
			}
		}
	}
	return false
}

// isGeoIPRule reports whether the rule names at least one GeoIP code other
// than "private", in either its source or destination list.
func isGeoIPRule(rule option.DefaultRule) bool {
	return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode)
}

// isGeositeRule reports whether the rule names any geosite code.
func isGeositeRule(rule option.DefaultRule) bool {
	return len(rule.Geosite) > 0
}

// notPrivateNode reports whether code is anything other than "private".
func notPrivateNode(code string) bool {
	return code != "private"
}

// prepareGeoIPDatabase resolves the GeoIP database path (configured path, or
// "geoip.db" located through C.FindPath), downloads the file if it does not
// exist, and opens it into r.geoIPReader.
func (r *Router) prepareGeoIPDatabase() error {
	var geoPath string
	if r.geoIPOptions.Path != "" {
		geoPath = r.geoIPOptions.Path
	} else {
		geoPath = "geoip.db"
		if foundPath, loaded := C.FindPath(geoPath); loaded {
			geoPath = foundPath
		}
	}
	if !rw.FileExists(geoPath) {
		r.logger.Warn("geoip database not exists: ", geoPath)
		err := r.retryDownload("geoip", geoPath, r.downloadGeoIPDatabase)
		if err != nil {
			return err
		}
	}
	geoReader, codes, err := geoip.Open(geoPath)
	if err != nil {
		return E.Cause(err, "open geoip database")
	}
	r.logger.Info("loaded geoip database: ", len(codes), " codes")
	r.geoIPReader = geoReader
	return nil
}

// prepareGeositeDatabase resolves the geosite database path (configured path,
// or "geosite.db" located through C.FindPath), downloads the file if it does
// not exist, and opens it into r.geositeReader.
func (r *Router) prepareGeositeDatabase() error {
	var geoPath string
	if r.geositeOptions.Path != "" {
		// Fixed: this used to assign the GeoIP path.
		geoPath = r.geositeOptions.Path
	} else {
		geoPath = "geosite.db"
		if foundPath, loaded := C.FindPath(geoPath); loaded {
			geoPath = foundPath
		}
	}
	if !rw.FileExists(geoPath) {
		r.logger.Warn("geosite database not exists: ", geoPath)
		err := r.retryDownload("geosite", geoPath, r.downloadGeositeDatabase)
		if err != nil {
			return err
		}
	}
	geoReader, codes, err := geosite.Open(geoPath)
	if err != nil {
		return E.Cause(err, "open geosite database")
	}
	r.logger.Info("loaded geosite database: ", len(codes), " codes")
	r.geositeReader = geoReader
	return nil
}

// retryDownload calls download(geoPath) up to three times, pausing ten
// seconds between attempts, and returns the last error if all attempts fail.
// name ("geoip" or "geosite") is only used in the log message.
func (r *Router) retryDownload(name string, geoPath string, download func(savePath string) error) error {
	const maxAttempts = 3
	var err error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		err = download(geoPath)
		if err == nil {
			return nil
		}
		r.logger.Error("download ", name, " database: ", err)
		// Best effort: drop any partial file so a later run does not mistake
		// it for a valid database.
		os.Remove(geoPath)
		// No point in waiting once there is no attempt left.
		if attempt < maxAttempts {
			time.Sleep(10 * time.Second)
		}
	}
	return err
}

// downloadGeoIPDatabase downloads the GeoIP database to savePath, using the
// configured download URL and detour or their defaults.
func (r *Router) downloadGeoIPDatabase(savePath string) error {
	downloadURL := r.geoIPOptions.DownloadURL
	if downloadURL == "" {
		downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db"
	}
	return r.downloadDatabase("geoip", downloadURL, r.geoIPOptions.DownloadDetour, savePath)
}

// downloadGeositeDatabase downloads the geosite database to savePath, using
// the configured download URL and detour or their defaults.
func (r *Router) downloadGeositeDatabase(savePath string) error {
	downloadURL := r.geositeOptions.DownloadURL
	if downloadURL == "" {
		downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db"
	}
	return r.downloadDatabase("geosite", downloadURL, r.geositeOptions.DownloadDetour, savePath)
}

// downloadDatabase fetches downloadURL through the outbound tagged detourTag
// (or the default connection outbound when detourTag is empty) and writes the
// response body to savePath, creating parent directories as needed.
//
// It returns an error if the detour tag is unknown, the file cannot be
// created, the request fails, the server answers with a status other than
// 200, or the body cannot be written completely. On error a partial or empty
// file may remain at savePath; retryDownload removes it.
func (r *Router) downloadDatabase(name string, downloadURL string, detourTag string, savePath string) error {
	r.logger.Info("downloading ", name, " database")
	var detour adapter.Outbound
	if detourTag != "" {
		outbound, loaded := r.Outbound(detourTag)
		if !loaded {
			return E.New("detour outbound not found: ", detourTag)
		}
		detour = outbound
	} else {
		detour = r.defaultOutboundForConnection
	}

	if parentDir := filepath.Dir(savePath); parentDir != "" {
		if err := os.MkdirAll(parentDir, 0o755); err != nil {
			return E.Cause(err, "create directory: ", parentDir)
		}
	}

	// O_TRUNC guards against stale trailing bytes if a file already exists.
	saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return E.Cause(err, "open output file: ", savePath)
	}
	// Covers the error paths; the success path closes explicitly below so a
	// failed flush is reported (the second Close is then a harmless no-op).
	defer saveFile.Close()

	httpClient := &http.Client{
		Transport: &http.Transport{
			ForceAttemptHTTP2:   true,
			TLSHandshakeTimeout: 5 * time.Second,
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
			},
		},
	}
	// The client is single-use; release its pooled connections when done.
	defer httpClient.CloseIdleConnections()

	response, err := httpClient.Get(downloadURL)
	if err != nil {
		return err
	}
	defer response.Body.Close()
	// Without this check an error page would be stored as the database.
	if response.StatusCode != http.StatusOK {
		return E.New("unexpected status: ", response.Status)
	}

	_, err = io.Copy(saveFile, response.Body)
	if err != nil {
		return err
	}
	return saveFile.Close()
}