Add TLS certificate reload

This commit is contained in:
世界 2022-07-30 22:00:04 +08:00
parent 5f1f55fbe7
commit 70b4577dbe
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
9 changed files with 186 additions and 29 deletions

View file

@ -36,7 +36,7 @@ func (s *androidSearcher) Start() error {
}
err = s.startWatcher()
if err != nil {
s.logger.Debug("create fsnotify watcher: ", err)
s.logger.Warn("create fsnotify watcher: ", err)
}
return nil
}
@ -56,6 +56,7 @@ func (s *androidSearcher) startWatcher() error {
}
func (s *androidSearcher) loopUpdate() {
for {
select {
case _, ok := <-s.watcher.Events:
if !ok {
@ -72,6 +73,7 @@ func (s *androidSearcher) loopUpdate() {
s.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (s *androidSearcher) Close() error {
return common.Close(common.PtrOrNil(s.watcher))

View file

@ -134,3 +134,7 @@ The server private key, in PEM format.
==Server only==
The path to the server private key, in PEM format.
### Reload
For server configuration, certificate and key will be automatically reloaded if modified.

2
go.mod
View file

@ -23,7 +23,7 @@ require (
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/net v0.0.0-20220728211354-c7608f3a8462
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39
)
require (

4
go.sum
View file

@ -279,8 +279,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 h1:aNCnH+Fiqs7ZDTFH6oEFjIfbX2HvgQXJ6uQuUbTobjk=
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -12,6 +12,7 @@ import (
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/http"
@ -22,7 +23,7 @@ var _ adapter.Inbound = (*HTTP)(nil)
type HTTP struct {
myInboundAdapter
authenticator auth.Authenticator
tlsConfig *tls.Config
tlsConfig *TLSConfig
}
func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) (*HTTP, error) {
@ -40,7 +41,7 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge
authenticator: auth.NewAuthenticator(options.Users),
}
if options.TLS != nil {
tlsConfig, err := NewTLSConfig(common.PtrValueOrDefault(options.TLS))
tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
@ -50,9 +51,26 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge
return inbound, nil
}
func (h *HTTP) Start() error {
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
return h.myInboundAdapter.Start()
}
func (h *HTTP) Close() error {
return common.Close(
&h.myInboundAdapter,
common.PtrOrNil(h.tlsConfig),
)
}
func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if h.tlsConfig != nil {
conn = tls.Server(conn, h.tlsConfig)
conn = tls.Server(conn, h.tlsConfig.Config())
}
return http.HandleConnection(ctx, conn, std_bufio.NewReader(conn), h.authenticator, h.upstreamUserHandler(metadata), M.Metadata{})
}

View file

@ -4,11 +4,118 @@ import (
"crypto/tls"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/fsnotify/fsnotify"
)
func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) {
var _ adapter.Service = (*TLSConfig)(nil)
type TLSConfig struct {
config *tls.Config
logger log.Logger
certificate []byte
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
}
func (c *TLSConfig) Config() *tls.Config {
return c.config
}
func (c *TLSConfig) Start() error {
if c.certificatePath == "" && c.keyPath == "" {
return nil
}
err := c.startWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
return nil
}
func (c *TLSConfig) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
}
c.watcher = watcher
go c.loopUpdate()
return nil
}
func (c *TLSConfig) loopUpdate() {
for {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (c *TLSConfig) reloadKeyPair() error {
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
}
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
}
c.key = key
}
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []tls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}
func (c *TLSConfig) Close() error {
if c.watcher != nil {
return c.watcher.Close()
}
return nil
}
func NewTLSConfig(logger log.Logger, options option.InboundTLSOptions) (*TLSConfig, error) {
if !options.Enabled {
return nil, nil
}
@ -76,5 +183,12 @@ func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) {
return nil, E.Cause(err, "parse x509 key pair")
}
tlsConfig.Certificates = []tls.Certificate{keyPair}
return &tlsConfig, nil
return &TLSConfig{
config: &tlsConfig,
logger: logger,
certificate: certificate,
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
}, nil
}

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-vmess"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
N "github.com/sagernet/sing/common/network"
)
@ -23,7 +24,7 @@ type VMess struct {
myInboundAdapter
service *vmess.Service[int]
users []option.VMessUser
tlsConfig *tls.Config
tlsConfig *TLSConfig
}
func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessInboundOptions) (*VMess, error) {
@ -49,19 +50,37 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg
return nil, err
}
if options.TLS != nil {
inbound.tlsConfig, err = NewTLSConfig(common.PtrValueOrDefault(options.TLS))
tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
inbound.tlsConfig = tlsConfig
}
inbound.service = service
inbound.connHandler = inbound
return inbound, nil
}
func (h *VMess) Start() error {
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
return h.myInboundAdapter.Start()
}
func (h *VMess) Close() error {
return common.Close(
&h.myInboundAdapter,
common.PtrOrNil(h.tlsConfig),
)
}
func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if h.tlsConfig != nil {
conn = tls.Server(conn, h.tlsConfig)
conn = tls.Server(conn, h.tlsConfig.Config())
}
return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}

View file

@ -61,7 +61,7 @@ require (
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
golang.org/x/mod v0.5.1 // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
golang.org/x/tools v0.1.9 // indirect

View file

@ -314,8 +314,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 h1:aNCnH+Fiqs7ZDTFH6oEFjIfbX2HvgQXJ6uQuUbTobjk=
golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=