Add command to connect an address

This commit is contained in:
世界 2023-03-18 20:26:58 +08:00
parent c7f89ad88e
commit e5f3bb6344
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 204 additions and 50 deletions

View file

@ -10,6 +10,7 @@ import (
type ClashServer interface { type ClashServer interface {
Service Service
PreStarter
Mode() string Mode() string
StoreSelected() bool StoreSelected() bool
CacheFile() ClashCacheFile CacheFile() ClashCacheFile

15
adapter/prestart.go Normal file
View file

@ -0,0 +1,15 @@
package adapter
type PreStarter interface {
PreStart() error
}
func PreStart(starter any) error {
if preService, ok := starter.(PreStarter); ok {
err := preService.PreStart()
if err != nil {
return err
}
}
return nil
}

96
box.go
View file

@ -32,8 +32,8 @@ type Box struct {
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
logFile *os.File logFile *os.File
clashServer adapter.ClashServer preServices map[string]adapter.Service
v2rayServer adapter.V2RayServer postServices map[string]adapter.Service
done chan struct{} done chan struct{}
} }
@ -166,22 +166,23 @@ func New(ctx context.Context, options option.Options, platformInterface platform
if err != nil { if err != nil {
return nil, err return nil, err
} }
preServices := make(map[string]adapter.Service)
var clashServer adapter.ClashServer postServices := make(map[string]adapter.Service)
var v2rayServer adapter.V2RayServer
if needClashAPI { if needClashAPI {
clashServer, err = experimental.NewClashServer(router, observableLogFactory, common.PtrValueOrDefault(options.Experimental.ClashAPI)) clashServer, err := experimental.NewClashServer(router, observableLogFactory, common.PtrValueOrDefault(options.Experimental.ClashAPI))
if err != nil { if err != nil {
return nil, E.Cause(err, "create clash api server") return nil, E.Cause(err, "create clash api server")
} }
router.SetClashServer(clashServer) router.SetClashServer(clashServer)
preServices["clash api"] = clashServer
} }
if needV2RayAPI { if needV2RayAPI {
v2rayServer, err = experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(options.Experimental.V2RayAPI)) v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(options.Experimental.V2RayAPI))
if err != nil { if err != nil {
return nil, E.Cause(err, "create v2ray api server") return nil, E.Cause(err, "create v2ray api server")
} }
router.SetV2RayServer(v2rayServer) router.SetV2RayServer(v2rayServer)
preServices["v2ray api"] = v2rayServer
} }
return &Box{ return &Box{
router: router, router: router,
@ -191,12 +192,31 @@ func New(ctx context.Context, options option.Options, platformInterface platform
logFactory: logFactory, logFactory: logFactory,
logger: logFactory.Logger(), logger: logFactory.Logger(),
logFile: logFile, logFile: logFile,
clashServer: clashServer, preServices: preServices,
v2rayServer: v2rayServer, postServices: postServices,
done: make(chan struct{}), done: make(chan struct{}),
}, nil }, nil
} }
func (s *Box) PreStart() error {
err := s.preStart()
if err != nil {
// TODO: remove catch error
defer func() {
v := recover()
if v != nil {
log.Error(E.Cause(err, "origin error"))
debug.PrintStack()
panic("panic on early close: " + fmt.Sprint(v))
}
}()
s.Close()
return err
}
s.logger.Info("sing-box pre-started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)")
return nil
}
func (s *Box) Start() error { func (s *Box) Start() error {
err := s.start() err := s.start()
if err != nil { if err != nil {
@ -210,21 +230,17 @@ func (s *Box) Start() error {
} }
}() }()
s.Close() s.Close()
}
return err return err
}
s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)")
return nil
} }
func (s *Box) start() error { func (s *Box) preStart() error {
if s.clashServer != nil { for serviceName, service := range s.preServices {
err := s.clashServer.Start() err := adapter.PreStart(service)
if err != nil { if err != nil {
return E.Cause(err, "start clash api server") return E.Cause(err, "pre-start ", serviceName)
}
}
if s.v2rayServer != nil {
err := s.v2rayServer.Start()
if err != nil {
return E.Cause(err, "start v2ray api server")
} }
} }
for i, out := range s.outbounds { for i, out := range s.outbounds {
@ -241,10 +257,20 @@ func (s *Box) start() error {
} }
} }
} }
err := s.router.Start() return s.router.Start()
}
func (s *Box) start() error {
err := s.preStart()
if err != nil { if err != nil {
return err return err
} }
for serviceName, service := range s.preServices {
err = service.Start()
if err != nil {
return E.Cause(err, "start ", serviceName)
}
}
for i, in := range s.inbounds { for i, in := range s.inbounds {
err = in.Start() err = in.Start()
if err != nil { if err != nil {
@ -257,8 +283,12 @@ func (s *Box) start() error {
return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]") return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]")
} }
} }
for serviceName, service := range s.postServices {
s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)") err = service.Start()
if err != nil {
return E.Cause(err, "start ", serviceName)
}
}
return nil return nil
} }
@ -270,6 +300,11 @@ func (s *Box) Close() error {
close(s.done) close(s.done)
} }
var errors error var errors error
for serviceName, service := range s.postServices {
errors = E.Append(errors, service.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName)
})
}
for i, in := range s.inbounds { for i, in := range s.inbounds {
errors = E.Append(errors, in.Close(), func(err error) error { errors = E.Append(errors, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", i, "]") return E.Cause(err, "close inbound/", in.Type(), "[", i, "]")
@ -285,21 +320,16 @@ func (s *Box) Close() error {
return E.Cause(err, "close router") return E.Cause(err, "close router")
}) })
} }
for serviceName, service := range s.preServices {
errors = E.Append(errors, service.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName)
})
}
if err := common.Close(s.logFactory); err != nil { if err := common.Close(s.logFactory); err != nil {
errors = E.Append(errors, err, func(err error) error { errors = E.Append(errors, err, func(err error) error {
return E.Cause(err, "close log factory") return E.Cause(err, "close log factory")
}) })
} }
if err := common.Close(s.clashServer); err != nil {
errors = E.Append(errors, err, func(err error) error {
return E.Cause(err, "close clash api server")
})
}
if err := common.Close(s.v2rayServer); err != nil {
errors = E.Append(errors, err, func(err error) error {
return E.Cause(err, "close v2ray api server")
})
}
if s.logFile != nil { if s.logFile != nil {
errors = E.Append(errors, s.logFile.Close(), func(err error) error { errors = E.Append(errors, s.logFile.Close(), func(err error) error {
return E.Cause(err, "close log file") return E.Cause(err, "close log file")

35
cmd/sing-box/cmd_tools.go Normal file
View file

@ -0,0 +1,35 @@
package main
import (
"context"
box "github.com/sagernet/sing-box"
E "github.com/sagernet/sing/common/exceptions"
"github.com/spf13/cobra"
)
var commandTools = &cobra.Command{
Use: "tools",
Short: "experimental tools",
}
func init() {
mainCommand.AddCommand(commandTools)
}
func createPreStartedClient() (*box.Box, error) {
options, err := readConfigAndMerge()
if err != nil {
return nil, err
}
instance, err := box.New(context.Background(), options, nil)
if err != nil {
return nil, E.Cause(err, "create service")
}
err = instance.PreStart()
if err != nil {
return nil, E.Cause(err, "start service")
}
return instance, nil
}

View file

@ -0,0 +1,69 @@
package main
import (
"context"
"os"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
"github.com/spf13/cobra"
)
var commandFlagNetwork string
var commandConnect = &cobra.Command{
Use: "connect [address]",
Short: "connect to a address through default outbound",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
err := connect(args[0])
if err != nil {
log.Fatal(err)
}
},
}
func init() {
commandConnect.Flags().StringVar(&commandFlagNetwork, "network", "tcp", "network type")
commandTools.AddCommand(commandConnect)
}
func connect(address string) error {
switch N.NetworkName(commandFlagNetwork) {
case N.NetworkTCP, N.NetworkUDP:
default:
return E.Cause(N.ErrUnknownNetwork, commandFlagNetwork)
}
instance, err := createPreStartedClient()
if err != nil {
return err
}
outbound := instance.Router().DefaultOutbound(commandFlagNetwork)
if outbound == nil {
return E.New("missing default outbound")
}
conn, err := outbound.DialContext(context.Background(), commandFlagNetwork, M.ParseSocksaddr(address))
if err != nil {
return E.Cause(err, "connect to server")
}
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(bufio.Copy(conn, os.Stdin))
})
group.Append("download", func(ctx context.Context) error {
return common.Error(bufio.Copy(os.Stdout, conn))
})
err = group.Run(context.Background())
if E.IsClosed(err) {
log.Info(err)
} else {
log.Error(err)
}
return nil
}

View file

@ -114,7 +114,7 @@ func NewServer(router adapter.Router, logFactory log.ObservableFactory, options
return server, nil return server, nil
} }
func (s *Server) Start() error { func (s *Server) PreStart() error {
if s.cacheFilePath != "" { if s.cacheFilePath != "" {
cacheFile, err := cachefile.Open(s.cacheFilePath) cacheFile, err := cachefile.Open(s.cacheFilePath)
if err != nil { if err != nil {
@ -122,6 +122,10 @@ func (s *Server) Start() error {
} }
s.cacheFile = cacheFile s.cacheFile = cacheFile
} }
return nil
}
func (s *Server) Start() error {
listener, err := net.Listen("tcp", s.httpServer.Addr) listener, err := net.Listen("tcp", s.httpServer.Addr)
if err != nil { if err != nil {
return E.Cause(err, "external controller listen error") return E.Cause(err, "external controller listen error")