// sing-box/common/tls/std_server.go
package tls
import (
"context"
"crypto/tls"
"net"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/fsnotify/fsnotify"
)
var errInsecureUnused = E.New("tls: insecure unused")
2022-09-09 10:45:10 +00:00
type STDServerConfig struct {
2022-07-30 14:00:04 +00:00
config *tls.Config
logger log.Logger
acmeService adapter.Service
2022-07-30 14:00:04 +00:00
certificate []byte
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
}
func (c *STDServerConfig) ServerName() string {
return c.config.ServerName
}
func (c *STDServerConfig) SetServerName(serverName string) {
c.config.ServerName = serverName
}
func (c *STDServerConfig) NextProtos() []string {
return c.config.NextProtos
}
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto
}
2022-09-09 10:45:10 +00:00
// Config exposes the underlying standard-library TLS configuration itself
// (not a copy). The error result is always nil for this implementation.
func (c *STDServerConfig) Config() (config *STDConfig, err error) {
	config = c.config
	return config, nil
}
func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) {
return tls.Client(conn, c.config), nil
2022-07-30 14:00:04 +00:00
}
func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
return tls.Server(conn, c.config), nil
2022-09-09 10:45:10 +00:00
}
func (c *STDServerConfig) Clone() Config {
return &STDServerConfig{
config: c.config.Clone(),
}
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) Start() error {
if c.acmeService != nil {
return c.acmeService.Start()
} else {
if c.certificatePath == "" && c.keyPath == "" {
return nil
}
err := c.startWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
2022-07-30 14:00:04 +00:00
return nil
}
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) startWatcher() error {
2022-07-30 14:00:04 +00:00
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
}
c.watcher = watcher
go c.loopUpdate()
return nil
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) loopUpdate() {
2022-07-30 14:00:04 +00:00
for {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) reloadKeyPair() error {
2022-07-30 14:00:04 +00:00
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
}
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
}
c.key = key
}
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []tls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}
2022-09-09 10:45:10 +00:00
func (c *STDServerConfig) Close() error {
if c.acmeService != nil {
return c.acmeService.Close()
}
2022-07-30 14:00:04 +00:00
if c.watcher != nil {
return c.watcher.Close()
}
return nil
}
2023-02-21 06:53:00 +00:00
func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
if !options.Enabled {
return nil, nil
}
var tlsConfig *tls.Config
var acmeService adapter.Service
var err error
if options.ACME != nil && len(options.ACME.Domain) > 0 {
tlsConfig, acmeService, err = startACME(ctx, common.PtrValueOrDefault(options.ACME))
//nolint:staticcheck
if err != nil {
return nil, err
}
if options.Insecure {
return nil, errInsecureUnused
}
} else {
tlsConfig = &tls.Config{}
}
2023-02-21 06:53:00 +00:00
tlsConfig.Time = router.TimeFunc()
if options.ServerName != "" {
tlsConfig.ServerName = options.ServerName
}
if len(options.ALPN) > 0 {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, options.ALPN...)
}
if options.MinVersion != "" {
minVersion, err := ParseTLSVersion(options.MinVersion)
if err != nil {
return nil, E.Cause(err, "parse min_version")
}
tlsConfig.MinVersion = minVersion
}
if options.MaxVersion != "" {
maxVersion, err := ParseTLSVersion(options.MaxVersion)
if err != nil {
return nil, E.Cause(err, "parse max_version")
}
tlsConfig.MaxVersion = maxVersion
}
if options.CipherSuites != nil {
find:
for _, cipherSuite := range options.CipherSuites {
for _, tlsCipherSuite := range tls.CipherSuites() {
if cipherSuite == tlsCipherSuite.Name {
tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
continue find
}
}
return nil, E.New("unknown cipher_suite: ", cipherSuite)
}
}
var certificate []byte
var key []byte
if acmeService == nil {
if options.Certificate != "" {
certificate = []byte(options.Certificate)
} else if options.CertificatePath != "" {
content, err := os.ReadFile(options.CertificatePath)
if err != nil {
return nil, E.Cause(err, "read certificate")
}
certificate = content
}
if options.Key != "" {
key = []byte(options.Key)
} else if options.KeyPath != "" {
content, err := os.ReadFile(options.KeyPath)
if err != nil {
return nil, E.Cause(err, "read key")
}
key = content
}
if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
2023-02-21 06:53:00 +00:00
return GenerateKeyPair(router.TimeFunc(), info.ServerName)
}
} else {
if certificate == nil {
return nil, E.New("missing certificate")
} else if key == nil {
return nil, E.New("missing key")
}
keyPair, err := tls.X509KeyPair(certificate, key)
if err != nil {
return nil, E.Cause(err, "parse x509 key pair")
}
tlsConfig.Certificates = []tls.Certificate{keyPair}
}
}
return &STDServerConfig{
config: tlsConfig,
logger: logger,
acmeService: acmeService,
certificate: certificate,
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
}, nil
}