sing-box/common/tls/std_server.go

244 lines
5.7 KiB
Go
Raw Normal View History

2022-09-09 10:45:10 +00:00
package tls
2022-07-25 00:14:09 +00:00
import (
"context"
2022-07-25 00:14:09 +00:00
"crypto/tls"
2022-09-09 10:45:10 +00:00
"net"
2022-07-25 00:14:09 +00:00
"os"
2022-07-30 14:00:04 +00:00
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
2022-07-25 00:14:09 +00:00
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
2022-07-25 00:14:09 +00:00
E "github.com/sagernet/sing/common/exceptions"
2022-07-30 14:00:04 +00:00
"github.com/fsnotify/fsnotify"
2022-07-25 00:14:09 +00:00
)
2022-09-09 10:45:10 +00:00
type STDServerConfig struct {
2022-07-30 14:00:04 +00:00
config *tls.Config
logger log.Logger
acmeService adapter.Service
2022-07-30 14:00:04 +00:00
certificate []byte
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
}
// NextProtos reports the application protocols (ALPN) currently listed on
// the underlying tls.Config.
func (c *STDServerConfig) NextProtos() []string {
	protos := c.config.NextProtos
	return protos
}
// SetNextProtos replaces the application protocol (ALPN) list on the
// underlying tls.Config. The slice is stored as given, not copied.
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
	tlsConfig := c.config
	tlsConfig.NextProtos = nextProto
}
2022-09-23 09:13:18 +00:00
var (
	// errInsecureUnused is returned by newSTDServer when the insecure option
	// is combined with ACME, where it would have no effect.
	errInsecureUnused = E.New("tls: insecure unused")
)
2022-09-10 02:27:00 +00:00
func newSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
2022-09-09 10:45:10 +00:00
if !options.Enabled {
return nil, nil
}
var tlsConfig *tls.Config
var acmeService adapter.Service
var err error
if options.ACME != nil && len(options.ACME.Domain) > 0 {
tlsConfig, acmeService, err = startACME(ctx, common.PtrValueOrDefault(options.ACME))
if err != nil {
return nil, err
}
2022-09-23 09:13:18 +00:00
if options.Insecure {
return nil, errInsecureUnused
}
2022-09-09 10:45:10 +00:00
} else {
tlsConfig = &tls.Config{}
}
if options.ServerName != "" {
tlsConfig.ServerName = options.ServerName
}
if len(options.ALPN) > 0 {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, options.ALPN...)
}
if options.MinVersion != "" {
minVersion, err := ParseTLSVersion(options.MinVersion)
if err != nil {
return nil, E.Cause(err, "parse min_version")
}
tlsConfig.MinVersion = minVersion
}
if options.MaxVersion != "" {
maxVersion, err := ParseTLSVersion(options.MaxVersion)
if err != nil {
return nil, E.Cause(err, "parse max_version")
}
tlsConfig.MaxVersion = maxVersion
}
if options.CipherSuites != nil {
find:
for _, cipherSuite := range options.CipherSuites {
for _, tlsCipherSuite := range tls.CipherSuites() {
if cipherSuite == tlsCipherSuite.Name {
tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
continue find
}
}
return nil, E.New("unknown cipher_suite: ", cipherSuite)
}
}
var certificate []byte
var key []byte
if acmeService == nil {
if options.Certificate != "" {
certificate = []byte(options.Certificate)
} else if options.CertificatePath != "" {
content, err := os.ReadFile(options.CertificatePath)
if err != nil {
return nil, E.Cause(err, "read certificate")
}
certificate = content
}
if options.Key != "" {
key = []byte(options.Key)
} else if options.KeyPath != "" {
content, err := os.ReadFile(options.KeyPath)
if err != nil {
return nil, E.Cause(err, "read key")
}
key = content
}
2022-09-23 09:13:18 +00:00
if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateKeyPair(info.ServerName)
}
} else {
if certificate == nil {
return nil, E.New("missing certificate")
} else if key == nil {
return nil, E.New("missing key")
}
keyPair, err := tls.X509KeyPair(certificate, key)
if err != nil {
return nil, E.Cause(err, "parse x509 key pair")
}
tlsConfig.Certificates = []tls.Certificate{keyPair}
2022-09-09 10:45:10 +00:00
}
}
return &STDServerConfig{
config: tlsConfig,
logger: logger,
acmeService: acmeService,
certificate: certificate,
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
}, nil
}
// Config returns the underlying TLS configuration itself (not a copy). The
// error result is always nil.
func (c *STDServerConfig) Config() (*STDConfig, error) {
	tlsConfig := c.config
	return tlsConfig, nil
}
func (c *STDServerConfig) Client(conn net.Conn) Conn {
return tls.Client(conn, c.config)
2022-07-30 14:00:04 +00:00
}
2022-09-09 10:45:10 +00:00
// Server wraps conn as the server side of a TLS connection that uses this
// configuration's shared tls.Config.
func (c *STDServerConfig) Server(conn net.Conn) Conn {
	tlsConn := tls.Server(conn, c.config)
	return tlsConn
}
func (c *STDServerConfig) Start() error {
if c.acmeService != nil {
return c.acmeService.Start()
} else {
if c.certificatePath == "" && c.keyPath == "" {
return nil
}
err := c.startWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
2022-07-30 14:00:04 +00:00
return nil
}
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) startWatcher() error {
2022-07-30 14:00:04 +00:00
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
}
c.watcher = watcher
go c.loopUpdate()
return nil
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) loopUpdate() {
2022-07-30 14:00:04 +00:00
for {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) reloadKeyPair() error {
2022-07-30 14:00:04 +00:00
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
}
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
}
c.key = key
}
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []tls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) Close() error {
if c.acmeService != nil {
return c.acmeService.Close()
}
2022-07-30 14:00:04 +00:00
if c.watcher != nil {
return c.watcher.Close()
}
return nil
}