diff --git a/adapter/experimental.go b/adapter/experimental.go index e223aa50..d290a8f2 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -10,6 +10,7 @@ import ( type ClashServer interface { Service + PreStarter Mode() string StoreSelected() bool CacheFile() ClashCacheFile diff --git a/adapter/prestart.go b/adapter/prestart.go new file mode 100644 index 00000000..c1b5f581 --- /dev/null +++ b/adapter/prestart.go @@ -0,0 +1,15 @@ +package adapter + +type PreStarter interface { + PreStart() error +} + +func PreStart(starter any) error { + if preService, ok := starter.(PreStarter); ok { + err := preService.PreStart() + if err != nil { + return err + } + } + return nil +} diff --git a/box.go b/box.go index adf72bbb..fa4393ca 100644 --- a/box.go +++ b/box.go @@ -25,16 +25,16 @@ import ( var _ adapter.Service = (*Box)(nil) type Box struct { - createdAt time.Time - router adapter.Router - inbounds []adapter.Inbound - outbounds []adapter.Outbound - logFactory log.Factory - logger log.ContextLogger - logFile *os.File - clashServer adapter.ClashServer - v2rayServer adapter.V2RayServer - done chan struct{} + createdAt time.Time + router adapter.Router + inbounds []adapter.Inbound + outbounds []adapter.Outbound + logFactory log.Factory + logger log.ContextLogger + logFile *os.File + preServices map[string]adapter.Service + postServices map[string]adapter.Service + done chan struct{} } func New(ctx context.Context, options option.Options, platformInterface platform.Interface) (*Box, error) { @@ -166,37 +166,57 @@ func New(ctx context.Context, options option.Options, platformInterface platform if err != nil { return nil, err } - - var clashServer adapter.ClashServer - var v2rayServer adapter.V2RayServer + preServices := make(map[string]adapter.Service) + postServices := make(map[string]adapter.Service) if needClashAPI { - clashServer, err = experimental.NewClashServer(router, observableLogFactory, common.PtrValueOrDefault(options.Experimental.ClashAPI)) + clashServer, err := experimental.NewClashServer(router, observableLogFactory, common.PtrValueOrDefault(options.Experimental.ClashAPI)) if err != nil { return nil, E.Cause(err, "create clash api server") } router.SetClashServer(clashServer) + preServices["clash api"] = clashServer } if needV2RayAPI { - v2rayServer, err = experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(options.Experimental.V2RayAPI)) + v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(options.Experimental.V2RayAPI)) if err != nil { return nil, E.Cause(err, "create v2ray api server") } router.SetV2RayServer(v2rayServer) + preServices["v2ray api"] = v2rayServer } return &Box{ - router: router, - inbounds: inbounds, - outbounds: outbounds, - createdAt: createdAt, - logFactory: logFactory, - logger: logFactory.Logger(), - logFile: logFile, - clashServer: clashServer, - v2rayServer: v2rayServer, - done: make(chan struct{}), + router: router, + inbounds: inbounds, + outbounds: outbounds, + createdAt: createdAt, + logFactory: logFactory, + logger: logFactory.Logger(), + logFile: logFile, + preServices: preServices, + postServices: postServices, + done: make(chan struct{}), }, nil } +func (s *Box) PreStart() error { + err := s.preStart() + if err != nil { + // TODO: remove catch error + defer func() { + v := recover() + if v != nil { + log.Error(E.Cause(err, "origin error")) + debug.PrintStack() + panic("panic on early close: " + fmt.Sprint(v)) + } + }() + s.Close() + return err + } + s.logger.Info("sing-box pre-started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)") + return nil +} + func (s *Box) Start() error { err := s.start() if err != nil { @@ -210,21 +230,17 @@ func (s *Box) Start() error { } }() s.Close() + return err } - return err + s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)") + return nil } -func (s *Box) start() error { - if s.clashServer != nil { - err := s.clashServer.Start() +func (s *Box) preStart() error { + for serviceName, service := range s.preServices { + err := adapter.PreStart(service) if err != nil { - return E.Cause(err, "start clash api server") - } - } - if s.v2rayServer != nil { - err := s.v2rayServer.Start() - if err != nil { - return E.Cause(err, "start v2ray api server") + return E.Cause(err, "pre-start ", serviceName) } } for i, out := range s.outbounds { @@ -241,10 +257,20 @@ func (s *Box) start() error { } } } - err := s.router.Start() + return s.router.Start() +} + +func (s *Box) start() error { + err := s.preStart() if err != nil { return err } + for serviceName, service := range s.preServices { + err = service.Start() + if err != nil { + return E.Cause(err, "start ", serviceName) + } + } for i, in := range s.inbounds { err = in.Start() if err != nil { @@ -257,8 +283,12 @@ func (s *Box) start() error { return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]") } } - - s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)") + for serviceName, service := range s.postServices { + err = service.Start() + if err != nil { + return E.Cause(err, "start ", serviceName) + } + } return nil } @@ -270,6 +300,11 @@ func (s *Box) Close() error { close(s.done) } var errors error + for serviceName, service := range s.postServices { + errors = E.Append(errors, service.Close(), func(err error) error { + return E.Cause(err, "close ", serviceName) + }) + } for i, in := range s.inbounds { errors = E.Append(errors, in.Close(), func(err error) error { return E.Cause(err, "close inbound/", in.Type(), "[", i, "]") @@ -285,21 +320,16 @@ func (s *Box) Close() error { return E.Cause(err, "close router") }) } + for serviceName, service := range s.preServices { + errors = E.Append(errors, service.Close(), func(err error) error { + return E.Cause(err, "close ", serviceName) + }) + } if err := common.Close(s.logFactory); err != nil { errors = E.Append(errors, err, func(err error) error { return E.Cause(err, "close log factory") }) } - if err := common.Close(s.clashServer); err != nil { - errors = E.Append(errors, err, func(err error) error { - return E.Cause(err, "close clash api server") - }) - } - if err := common.Close(s.v2rayServer); err != nil { - errors = E.Append(errors, err, func(err error) error { - return E.Cause(err, "close v2ray api server") - }) - } if s.logFile != nil { errors = E.Append(errors, s.logFile.Close(), func(err error) error { return E.Cause(err, "close log file") diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go new file mode 100644 index 00000000..bc36b861 --- /dev/null +++ b/cmd/sing-box/cmd_tools.go @@ -0,0 +1,35 @@ +package main + +import ( + "context" + + box "github.com/sagernet/sing-box" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/spf13/cobra" +) + +var commandTools = &cobra.Command{ + Use: "tools", + Short: "experimental tools", +} + +func init() { + mainCommand.AddCommand(commandTools) +} + +func createPreStartedClient() (*box.Box, error) { + options, err := readConfigAndMerge() + if err != nil { + return nil, err + } + instance, err := box.New(context.Background(), options, nil) + if err != nil { + return nil, E.Cause(err, "create service") + } + err = instance.PreStart() + if err != nil { + return nil, E.Cause(err, "start service") + } + return instance, nil +} diff --git a/cmd/sing-box/cmd_tools_connect.go b/cmd/sing-box/cmd_tools_connect.go new file mode 100644 index 00000000..ebf0fb92 --- /dev/null +++ b/cmd/sing-box/cmd_tools_connect.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "os" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/task" + + "github.com/spf13/cobra" +) + +var commandFlagNetwork string + +var commandConnect = &cobra.Command{ + Use: "connect [address]", + Short: "connect to a address through default outbound", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + err := connect(args[0]) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandConnect.Flags().StringVar(&commandFlagNetwork, "network", "tcp", "network type") + commandTools.AddCommand(commandConnect) +} + +func connect(address string) error { + switch N.NetworkName(commandFlagNetwork) { + case N.NetworkTCP, N.NetworkUDP: + default: + return E.Cause(N.ErrUnknownNetwork, commandFlagNetwork) + } + instance, err := createPreStartedClient() + if err != nil { + return err + } + outbound := instance.Router().DefaultOutbound(commandFlagNetwork) + if outbound == nil { + return E.New("missing default outbound") + } + conn, err := outbound.DialContext(context.Background(), commandFlagNetwork, M.ParseSocksaddr(address)) + if err != nil { + return E.Cause(err, "connect to server") + } + var group task.Group + group.Append("upload", func(ctx context.Context) error { + return common.Error(bufio.Copy(conn, os.Stdin)) + }) + group.Append("download", func(ctx context.Context) error { + return common.Error(bufio.Copy(os.Stdout, conn)) + }) + err = group.Run(context.Background()) + if E.IsClosed(err) { + log.Info(err) + } else { + log.Error(err) + } + return nil +} diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index a7bd1b95..981de8c4 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -114,7 +114,7 @@ func NewServer(router adapter.Router, logFactory log.ObservableFactory, options return server, nil } -func (s *Server) Start() error { +func (s *Server) PreStart() error { if s.cacheFilePath != "" { cacheFile, err := cachefile.Open(s.cacheFilePath) if err != nil { @@ -122,6 +122,10 @@ func (s *Server) Start() error { } s.cacheFile = cacheFile } + return nil +} + +func (s *Server) Start() error { listener, err := net.Listen("tcp", s.httpServer.Addr) if err != nil { return E.Cause(err, "external controller listen error")