From 70b4577dbea5f51a97ac64e32050669fa0b35079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 30 Jul 2022 22:00:04 +0800 Subject: [PATCH] Add TLS certificate reload --- common/process/searcher_android.go | 30 ++++---- docs/configuration/shared/tls.md | 6 +- go.mod | 2 +- go.sum | 4 +- inbound/http.go | 24 +++++- inbound/tls.go | 118 ++++++++++++++++++++++++++++- inbound/vmess.go | 25 +++++- test/go.mod | 2 +- test/go.sum | 4 +- 9 files changed, 186 insertions(+), 29 deletions(-) diff --git a/common/process/searcher_android.go b/common/process/searcher_android.go index 2da6413c..5e070209 100644 --- a/common/process/searcher_android.go +++ b/common/process/searcher_android.go @@ -36,7 +36,7 @@ func (s *androidSearcher) Start() error { } err = s.startWatcher() if err != nil { - s.logger.Debug("create fsnotify watcher: ", err) + s.logger.Warn("create fsnotify watcher: ", err) } return nil } @@ -56,20 +56,22 @@ func (s *androidSearcher) startWatcher() error { } func (s *androidSearcher) loopUpdate() { - select { - case _, ok := <-s.watcher.Events: - if !ok { - return + for { + select { + case _, ok := <-s.watcher.Events: + if !ok { + return + } + err := s.updatePackages() + if err != nil { + s.logger.Error(E.Cause(err, "update packages list")) + } + case err, ok := <-s.watcher.Errors: + if !ok { + return + } + s.logger.Error(E.Cause(err, "fsnotify error")) } - err := s.updatePackages() - if err != nil { - s.logger.Error(E.Cause(err, "update packages list")) - } - case err, ok := <-s.watcher.Errors: - if !ok { - return - } - s.logger.Error(E.Cause(err, "fsnotify error")) } } diff --git a/docs/configuration/shared/tls.md b/docs/configuration/shared/tls.md index e0072e0a..2ca75939 100644 --- a/docs/configuration/shared/tls.md +++ b/docs/configuration/shared/tls.md @@ -133,4 +133,8 @@ The server private key, in PEM format. ==Server only== -The path to the server private key, in PEM format. \ No newline at end of file +The path to the server private key, in PEM format. + +### Reload + +For server configuration, certificate and key will be automatically reloaded if modified. \ No newline at end of file diff --git a/go.mod b/go.mod index edc35a6f..c3f1fe78 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( go.uber.org/atomic v1.9.0 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/net v0.0.0-20220728211354-c7608f3a8462 - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 + golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 ) require ( diff --git a/go.sum b/go.sum index 9013422c..e93f62fd 100644 --- a/go.sum +++ b/go.sum @@ -279,8 +279,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 h1:aNCnH+Fiqs7ZDTFH6oEFjIfbX2HvgQXJ6uQuUbTobjk= +golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/inbound/http.go b/inbound/http.go index c54f3cf0..ed27c1e6 100644 --- a/inbound/http.go +++ b/inbound/http.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/protocol/http" @@ -22,7 +23,7 @@ var _ adapter.Inbound = (*HTTP)(nil) type HTTP struct { myInboundAdapter authenticator auth.Authenticator - tlsConfig *tls.Config + tlsConfig *TLSConfig } func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) (*HTTP, error) { @@ -40,7 +41,7 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge authenticator: auth.NewAuthenticator(options.Users), } if options.TLS != nil { - tlsConfig, err := NewTLSConfig(common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } @@ -50,9 +51,26 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge return inbound, nil } +func (h *HTTP) Start() error { + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return E.Cause(err, "create TLS config") + } + } + return h.myInboundAdapter.Start() +} + +func (h *HTTP) Close() error { + return common.Close( + &h.myInboundAdapter, + common.PtrOrNil(h.tlsConfig), + ) +} + func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { if h.tlsConfig != nil { - conn = tls.Server(conn, h.tlsConfig) + conn = tls.Server(conn, h.tlsConfig.Config()) } return http.HandleConnection(ctx, conn, std_bufio.NewReader(conn), h.authenticator, h.upstreamUserHandler(metadata), M.Metadata{}) } diff --git a/inbound/tls.go b/inbound/tls.go index 7009c25a..811bae9c 100644 --- a/inbound/tls.go +++ b/inbound/tls.go @@ -4,11 +4,118 @@ import ( "crypto/tls" "os" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + + "github.com/fsnotify/fsnotify" ) -func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) { +var _ adapter.Service = (*TLSConfig)(nil) + +type TLSConfig struct { + config *tls.Config + logger log.Logger + certificate []byte + key []byte + certificatePath string + keyPath string + watcher *fsnotify.Watcher +} + +func (c *TLSConfig) Config() *tls.Config { + return c.config +} + +func (c *TLSConfig) Start() error { + if c.certificatePath == "" && c.keyPath == "" { + return nil + } + err := c.startWatcher() + if err != nil { + c.logger.Warn("create fsnotify watcher: ", err) + } + return nil +} + +func (c *TLSConfig) startWatcher() error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + if c.certificatePath != "" { + err = watcher.Add(c.certificatePath) + if err != nil { + return err + } + } + if c.keyPath != "" { + err = watcher.Add(c.keyPath) + if err != nil { + return err + } + } + c.watcher = watcher + go c.loopUpdate() + return nil +} + +func (c *TLSConfig) loopUpdate() { + for { + select { + case event, ok := <-c.watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write != fsnotify.Write { + continue + } + err := c.reloadKeyPair() + if err != nil { + c.logger.Error(E.Cause(err, "reload TLS key pair")) + } + case err, ok := <-c.watcher.Errors: + if !ok { + return + } + c.logger.Error(E.Cause(err, "fsnotify error")) + } + } +} + +func (c *TLSConfig) reloadKeyPair() error { + if c.certificatePath != "" { + certificate, err := os.ReadFile(c.certificatePath) + if err != nil { + return E.Cause(err, "reload certificate from ", c.certificatePath) + } + c.certificate = certificate + } + if c.keyPath != "" { + key, err := os.ReadFile(c.keyPath) + if err != nil { + return E.Cause(err, "reload key from ", c.keyPath) + } + c.key = key + } + keyPair, err := tls.X509KeyPair(c.certificate, c.key) + if err != nil { + return E.Cause(err, "reload key pair") + } + c.config.Certificates = []tls.Certificate{keyPair} + c.logger.Info("reloaded TLS certificate") + return nil +} + +func (c *TLSConfig) Close() error { + if c.watcher != nil { + return c.watcher.Close() + } + return nil +} + +func NewTLSConfig(logger log.Logger, options option.InboundTLSOptions) (*TLSConfig, error) { if !options.Enabled { return nil, nil } @@ -76,5 +183,12 @@ func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) { return nil, E.Cause(err, "parse x509 key pair") } tlsConfig.Certificates = []tls.Certificate{keyPair} - return &tlsConfig, nil + return &TLSConfig{ + config: &tlsConfig, + logger: logger, + certificate: certificate, + key: key, + certificatePath: options.CertificatePath, + keyPath: options.KeyPath, + }, nil } diff --git a/inbound/vmess.go b/inbound/vmess.go index 9651ccbf..88b19495 100644 --- a/inbound/vmess.go +++ b/inbound/vmess.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-vmess" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" N "github.com/sagernet/sing/common/network" ) @@ -23,7 +24,7 @@ type VMess struct { myInboundAdapter service *vmess.Service[int] users []option.VMessUser - tlsConfig *tls.Config + tlsConfig *TLSConfig } func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessInboundOptions) (*VMess, error) { @@ -49,19 +50,37 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg return nil, err } if options.TLS != nil { - inbound.tlsConfig, err = NewTLSConfig(common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } + inbound.tlsConfig = tlsConfig } inbound.service = service inbound.connHandler = inbound return inbound, nil } +func (h *VMess) Start() error { + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return E.Cause(err, "create TLS config") + } + } + return h.myInboundAdapter.Start() +} + +func (h *VMess) Close() error { + return common.Close( + &h.myInboundAdapter, + common.PtrOrNil(h.tlsConfig), + ) +} + func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { if h.tlsConfig != nil { - conn = tls.Server(conn, h.tlsConfig) + conn = tls.Server(conn, h.tlsConfig.Config()) } return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata)) } diff --git a/test/go.mod b/test/go.mod index cd3fda12..e43e2d03 100644 --- a/test/go.mod +++ b/test/go.mod @@ -61,7 +61,7 @@ require ( go.uber.org/atomic v1.9.0 // indirect golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect golang.org/x/mod v0.5.1 // indirect - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect + golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.1.9 // indirect diff --git a/test/go.sum b/test/go.sum index 455c4996..8aba00a1 100644 --- a/test/go.sum +++ b/test/go.sum @@ -314,8 +314,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 h1:aNCnH+Fiqs7ZDTFH6oEFjIfbX2HvgQXJ6uQuUbTobjk= +golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=