Add mitm service

This commit is contained in:
世界 2023-04-02 15:27:13 +08:00
parent 0840fb1a28
commit 794506c42d
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
14 changed files with 392 additions and 21 deletions

18
adapter/mitm.go Normal file
View File

@ -0,0 +1,18 @@
package adapter
import (
"context"
"crypto/tls"
"net"
N "github.com/sagernet/sing/common/network"
)
type MITMService interface {
Service
ProcessConnection(ctx context.Context, conn net.Conn, dialer N.Dialer, metadata InboundContext) (net.Conn, error)
}
type TLSOutbound interface {
NewTLSConnection(ctx context.Context, conn net.Conn, tlsConfig *tls.Config, metadata InboundContext) error
}

View File

@ -45,11 +45,14 @@ type Router interface {
InterfaceMonitor() tun.DefaultInterfaceMonitor
PackageManager() tun.PackageManager
Rules() []Rule
Rules() []RouteRule
IPRules() []IPRule
TimeService
MITMService() MITMService
SetMITMService(service MITMService)
ClashServer() ClashServer
SetClashServer(server ClashServer)
@ -80,6 +83,11 @@ type Rule interface {
String() string
}
type RouteRule interface {
Rule
MITM() bool
}
type DNSRule interface {
Rule
DisableCache() bool

31
box.go
View File

@ -14,6 +14,7 @@ import (
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/inbound"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/mitm"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/outbound"
"github.com/sagernet/sing-box/route"
@ -184,6 +185,14 @@ func New(ctx context.Context, options option.Options, platformInterface platform
router.SetV2RayServer(v2rayServer)
preServices["v2ray api"] = v2rayServer
}
if options.MITM != nil && options.MITM.Enabled {
mitmService, err := mitm.NewService(router, logFactory.NewLogger("mitm"), common.PtrValueOrDefault(options.MITM))
if err != nil {
return nil, E.Cause(err, "create mitm service")
}
postServices["mitm"] = mitmService
router.SetMITMService(mitmService)
}
return &Box{
router: router,
inbounds: inbounds,
@ -271,6 +280,12 @@ func (s *Box) start() error {
return E.Cause(err, "start ", serviceName)
}
}
for serviceName, service := range s.postServices {
err = service.Start()
if err != nil {
return E.Cause(err, "start ", serviceName)
}
}
for i, in := range s.inbounds {
err = in.Start()
if err != nil {
@ -283,12 +298,6 @@ func (s *Box) start() error {
return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]")
}
}
for serviceName, service := range s.postServices {
err = service.Start()
if err != nil {
return E.Cause(err, "start ", serviceName)
}
}
return nil
}
@ -300,16 +309,16 @@ func (s *Box) Close() error {
close(s.done)
}
var errors error
for serviceName, service := range s.postServices {
errors = E.Append(errors, service.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName)
})
}
for i, in := range s.inbounds {
errors = E.Append(errors, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", i, "]")
})
}
for serviceName, service := range s.postServices {
errors = E.Append(errors, service.Close(), func(err error) error {
return E.Cause(err, "close ", serviceName)
})
}
for i, out := range s.outbounds {
errors = E.Append(errors, common.Close(out), func(err error) error {
return E.Cause(err, "close inbound/", out.Type(), "[", i, "]")

View File

@ -34,3 +34,24 @@ func ParseTLSVersion(version string) (uint16, error) {
return 0, E.New("unknown tls version:", version)
}
}
func ConfigFromClientHello(clientHello *tls.ClientHelloInfo) *tls.Config {
minVersion := clientHello.SupportedVersions[0]
maxVersion := minVersion
for _, version := range clientHello.SupportedVersions {
if version > maxVersion {
maxVersion = version
}
if version < minVersion {
minVersion = version
}
}
return &tls.Config{
CipherSuites: clientHello.CipherSuites,
NextProtos: clientHello.SupportedProtos,
ServerName: clientHello.ServerName,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: clientHello.SupportedCurves,
}
}

View File

@ -1,6 +1,7 @@
package tls
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@ -11,7 +12,7 @@ import (
"time"
)
func GenerateKeyPair(timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
func GenerateKeyPair(timeFunc func() time.Time, serverName string, parent *tls.Certificate) (*tls.Certificate, error) {
if timeFunc == nil {
timeFunc = time.Now
}
@ -35,7 +36,24 @@ func GenerateKeyPair(timeFunc func() time.Time, serverName string) (*tls.Certifi
},
DNSNames: []string{serverName},
}
publicDer, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
var (
parentCertificate *x509.Certificate
parentKey crypto.PrivateKey
)
if parent != nil {
if parent.Leaf == nil {
parent.Leaf, err = x509.ParseCertificate(parent.Certificate[0])
if err != nil {
return nil, err
}
}
parentCertificate = parent.Leaf
parentKey = parent.PrivateKey
} else {
parentCertificate = template
parentKey = key
}
publicDer, err := x509.CreateCertificate(rand.Reader, template, parentCertificate, key.Public(), parentKey)
if err != nil {
return nil, err
}

View File

@ -231,7 +231,7 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
}
if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateKeyPair(router.TimeFunc(), info.ServerName)
return GenerateKeyPair(router.TimeFunc(), info.ServerName, nil)
}
} else {
if certificate == nil {

157
mitm/service.go Normal file
View File

@ -0,0 +1,157 @@
package mitm
import (
"context"
"crypto/tls"
"io"
"net"
"os"
"github.com/sagernet/sing-box/adapter"
sTLS "github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"github.com/fsnotify/fsnotify"
)
var _ adapter.MITMService = (*Service)(nil)
type Service struct {
router adapter.Router
logger logger.ContextLogger
tlsCertificate *tls.Certificate
certificate []byte
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
insecure bool
}
func NewService(router adapter.Router, logger logger.ContextLogger, options option.MITMServiceOptions) (*Service, error) {
var tlsCertificate *tls.Certificate
var certificate []byte
var key []byte
if options.Certificate != "" {
certificate = []byte(options.Certificate)
} else if options.CertificatePath != "" {
content, err := os.ReadFile(options.CertificatePath)
if err != nil {
return nil, E.Cause(err, "read certificate")
}
certificate = content
}
if options.Key != "" {
key = []byte(options.Key)
} else if options.KeyPath != "" {
content, err := os.ReadFile(options.KeyPath)
if err != nil {
return nil, E.Cause(err, "read key")
}
key = content
}
if certificate == nil && key != nil {
return nil, E.New("missing certificate")
} else if certificate != nil && key == nil {
return nil, E.New("missing key")
} else if certificate != nil && key != nil {
keyPair, err := tls.X509KeyPair(certificate, key)
if err != nil {
return nil, E.Cause(err, "parse x509 key pair")
}
tlsCertificate = &keyPair
}
service := &Service{
router: router,
logger: logger,
tlsCertificate: tlsCertificate,
certificate: certificate,
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
insecure: options.Insecure,
}
return service, nil
}
func (s *Service) ProcessConnection(ctx context.Context, conn net.Conn, dialer N.Dialer, metadata adapter.InboundContext) (net.Conn, error) {
buffer := buf.NewPacket()
buffer.FullReset()
var clientHello *tls.ClientHelloInfo
_ = tls.Server(bufio.NewReadOnlyConn(io.TeeReader(conn, buffer)), &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
clientHello = argHello
return nil, nil
},
}).HandshakeContext(ctx)
if clientHello == nil {
s.logger.DebugContext(ctx, "not a TLS connection")
return bufio.NewCachedConn(conn, buffer), nil
}
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, dialer, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
} else {
outConn, err = dialer.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
return nil, N.HandshakeFailure(conn, err)
}
tlsConfig := sTLS.ConfigFromClientHello(clientHello)
tlsConfig.InsecureSkipVerify = s.insecure
tlsConfig.Time = s.router.TimeFunc()
if tlsConfig.ServerName == "" {
tlsConfig.ServerName = metadata.Destination.AddrString()
}
serverConn := tls.Client(outConn, tlsConfig)
err = serverConn.HandshakeContext(ctx)
if err != nil {
return nil, N.HandshakeFailure(conn, err)
}
clientConn := tls.Server(bufio.NewCachedConn(conn, buffer), &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
var serverConfig tls.Config
serverConfig.Time = s.router.TimeFunc()
if serverConn.ConnectionState().NegotiatedProtocol != "" {
serverConfig.NextProtos = []string{serverConn.ConnectionState().NegotiatedProtocol}
}
serverConfig.ServerName = clientHello.ServerName
serverConfig.MinVersion = tls.VersionTLS10
serverConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return sTLS.GenerateKeyPair(nil, serverConfig.ServerName, s.tlsCertificate)
}
return &serverConfig, nil
},
})
err = clientConn.HandshakeContext(ctx)
if err != nil {
return nil, E.Cause(err, "mitm TLS handshake")
}
s.logger.DebugContext(ctx, "mitm TLS handshake success")
return nil, bufio.CopyConn(ctx, clientConn, serverConn)
}
func (s *Service) Start() error {
if s.certificatePath != "" || s.keyPath != "" {
err := s.startWatcher()
if err != nil {
s.logger.Warn("create fsnotify watcher: ", err)
}
}
return nil
}
func (s *Service) Close() error {
return common.Close(common.PtrOrNil(s.watcher))
}

View File

@ -0,0 +1,79 @@
package mitm
import (
"crypto/tls"
"os"
E "github.com/sagernet/sing/common/exceptions"
"github.com/fsnotify/fsnotify"
)
func (s *Service) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
if s.certificatePath != "" {
err = watcher.Add(s.certificatePath)
if err != nil {
return err
}
}
if s.keyPath != "" {
err = watcher.Add(s.keyPath)
if err != nil {
return err
}
}
s.watcher = watcher
go s.loopUpdate()
return nil
}
func (s *Service) loopUpdate() {
for {
select {
case event, ok := <-s.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := s.reloadKeyPair()
if err != nil {
s.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-s.watcher.Errors:
if !ok {
return
}
s.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (s *Service) reloadKeyPair() error {
if s.certificatePath != "" {
certificate, err := os.ReadFile(s.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", s.certificatePath)
}
s.certificate = certificate
}
if s.keyPath != "" {
key, err := os.ReadFile(s.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", s.keyPath)
}
s.key = key
}
keyPair, err := tls.X509KeyPair(s.certificate, s.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
s.tlsCertificate = &keyPair
s.logger.Info("reloaded TLS certificate")
return nil
}

View File

@ -16,6 +16,7 @@ type _Options struct {
Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"`
MITM *MITMServiceOptions `json:"mitm,omitempty"`
Experimental *ExperimentalOptions `json:"experimental,omitempty"`
}

10
option/mitm.go Normal file
View File

@ -0,0 +1,10 @@
package option
type MITMServiceOptions struct {
Enabled bool `json:"enabled,omitempty"`
Insecure bool `json:"insecure,omitempty"`
Certificate string `json:"certificate,omitempty"`
CertificatePath string `json:"certificate_path,omitempty"`
Key string `json:"key,omitempty"`
KeyPath string `json:"key_path,omitempty"`
}

View File

@ -80,6 +80,7 @@ type DefaultRule struct {
ClashMode string `json:"clash_mode,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
MITM bool `json:"mitm,omitempty"`
}
func (r DefaultRule) IsValid() bool {
@ -94,6 +95,7 @@ type LogicalRule struct {
Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
MITM bool `json:"mitm,omitempty"`
}
func (r LogicalRule) IsValid() bool {

View File

@ -2,6 +2,7 @@ package outbound
import (
"context"
"crypto/tls"
"net"
"net/netip"
"os"
@ -83,6 +84,21 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
}
func NewTLSConnection(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error {
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
} else {
outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
return N.HandshakeFailure(conn, err)
}
return CopyEarlyConn(ctx, conn, tls.Client(outConn, tlsConfig))
}
func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
if cachedReader, isCached := conn.(N.CachedReader); isCached {
payload := cachedReader.ReadCached()

View File

@ -68,7 +68,7 @@ type Router struct {
inboundByTag map[string]adapter.Inbound
outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound
rules []adapter.Rule
rules []adapter.RouteRule
ipRules []adapter.IPRule
defaultDetour string
defaultOutboundForConnection adapter.Outbound
@ -100,6 +100,7 @@ type Router struct {
timeService adapter.TimeService
clashServer adapter.ClashServer
v2rayServer adapter.V2RayServer
mitmService adapter.MITMService
platformInterface platform.Interface
}
@ -127,7 +128,7 @@ func NewRouter(
logger: logFactory.NewLogger("router"),
dnsLogger: logFactory.NewLogger("dns"),
outboundByTag: make(map[string]adapter.Outbound),
rules: make([]adapter.Rule, 0, len(options.Rules)),
rules: make([]adapter.RouteRule, 0, len(options.Rules)),
ipRules: make([]adapter.IPRule, 0, len(options.IPRules)),
dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
@ -683,6 +684,17 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
conn = statsService.RoutedConnection(metadata.Inbound, detour.Tag(), metadata.User, conn)
}
}
if matchedRule != nil && matchedRule.MITM() {
if r.mitmService == nil {
return E.New("MITM disabled")
}
fallbackConn, err := r.mitmService.ProcessConnection(ctx, conn, detour, metadata)
if fallbackConn != nil {
conn = fallbackConn
} else {
return err
}
}
return detour.NewConnection(ctx, conn, metadata)
}
@ -789,7 +801,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
return detour.NewPacketConnection(ctx, conn, metadata)
}
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.RouteRule, adapter.Outbound) {
if r.processSearcher != nil {
var originDestination netip.AddrPort
if metadata.OriginDestination.IsValid() {
@ -865,7 +877,7 @@ func (r *Router) DefaultMark() int {
return r.defaultMark
}
func (r *Router) Rules() []adapter.Rule {
func (r *Router) Rules() []adapter.RouteRule {
return r.rules
}
@ -892,6 +904,14 @@ func (r *Router) TimeFunc() func() time.Time {
return r.timeService.TimeFunc()
}
func (r *Router) MITMService() adapter.MITMService {
return r.mitmService
}
func (r *Router) SetMITMService(service adapter.MITMService) {
r.mitmService = service
}
func (r *Router) ClashServer() adapter.ClashServer {
return r.clashServer
}

View File

@ -8,7 +8,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
)
func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) {
func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.RouteRule, error) {
switch options.Type {
case "", C.RuleTypeDefault:
if !options.DefaultOptions.IsValid() {
@ -31,10 +31,11 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul
}
}
var _ adapter.Rule = (*DefaultRule)(nil)
var _ adapter.RouteRule = (*DefaultRule)(nil)
type DefaultRule struct {
abstractDefaultRule
mitm bool
}
type RuleItem interface {
@ -48,6 +49,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
invert: options.Invert,
outbound: options.Outbound,
},
options.MITM,
}
if len(options.Inbound) > 0 {
item := NewInboundRule(options.Inbound)
@ -187,10 +189,15 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
return rule, nil
}
func (r *DefaultRule) MITM() bool {
return r.mitm
}
var _ adapter.Rule = (*LogicalRule)(nil)
type LogicalRule struct {
abstractLogicalRule
mitm bool
}
func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
@ -200,6 +207,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
invert: options.Invert,
outbound: options.Outbound,
},
options.MITM,
}
switch options.Mode {
case C.LogicalTypeAnd:
@ -218,3 +226,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
}
return r, nil
}
func (r *LogicalRule) MITM() bool {
return r.mitm
}