package transport import ( "bytes" "context" "io" "net" "net/http" "net/url" "strconv" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" sHTTP "github.com/sagernet/sing/protocol/http" mDNS "github.com/miekg/dns" "golang.org/x/net/http2" ) const MimeType = "application/dns-message" var _ adapter.DNSTransport = (*HTTPSTransport)(nil) func RegisterHTTPS(registry *dns.TransportRegistry) { dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTPS, NewHTTPS) } type HTTPSTransport struct { dns.TransportAdapter logger logger.ContextLogger dialer N.Dialer destination *url.URL headers http.Header transport *http.Transport } func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions) if err != nil { return nil, err } tlsOptions := common.PtrValueOrDefault(options.TLS) tlsOptions.Enabled = true tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions) if err != nil { return nil, err } if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) { tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS)) } if !common.Contains(tlsConfig.NextProtos(), "http/1.1") { tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1")) } headers := options.Headers.Build() host := headers.Get("Host") if host != "" { headers.Del("Host") } else { if tlsConfig.ServerName() != "" { host = tlsConfig.ServerName() } else { host = options.Server } } destinationURL := url.URL{ Scheme: "https", Host: host, } if destinationURL.Host == "" { destinationURL.Host = options.Server } if options.ServerPort != 0 && options.ServerPort != 443 { destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort))) } path := options.Path if path == "" { path = "/dns-query" } err = sHTTP.URLSetPath(&destinationURL, path) if err != nil { return nil, err } serverAddr := options.ServerOptions.Build() if serverAddr.Port == 0 { serverAddr.Port = 443 } return NewHTTPSRaw( dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTPS, tag, options.RemoteDNSServerOptions), logger, transportDialer, &destinationURL, headers, serverAddr, tlsConfig, ), nil } func NewHTTPSRaw( adapter dns.TransportAdapter, logger log.ContextLogger, dialer N.Dialer, destination *url.URL, headers http.Header, serverAddr M.Socksaddr, tlsConfig tls.Config, ) *HTTPSTransport { var transport *http.Transport if tlsConfig != nil { transport = &http.Transport{ ForceAttemptHTTP2: true, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr) if hErr != nil { return nil, hErr } tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig) if hErr != nil { tcpConn.Close() return nil, hErr } return tlsConn, nil }, } } else { transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.DialContext(ctx, network, serverAddr) }, } } return &HTTPSTransport{ TransportAdapter: adapter, logger: logger, dialer: dialer, destination: destination, headers: headers, transport: transport, } } func (t *HTTPSTransport) Reset() { t.transport.CloseIdleConnections() t.transport = t.transport.Clone() } func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { exMessage := *message exMessage.Id = 0 exMessage.Compress = true requestBuffer := buf.NewSize(1 + message.Len()) rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes()) if err != nil { requestBuffer.Release() return nil, err } request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage)) if err != nil { requestBuffer.Release() return nil, err } request.Header = t.headers.Clone() request.Header.Set("Content-Type", MimeType) request.Header.Set("Accept", MimeType) response, err := t.transport.RoundTrip(request) requestBuffer.Release() if err != nil { return nil, err } defer response.Body.Close() if response.StatusCode != http.StatusOK { return nil, E.New("unexpected status: ", response.Status) } var responseMessage mDNS.Msg if response.ContentLength > 0 { responseBuffer := buf.NewSize(int(response.ContentLength)) _, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength)) if err != nil { return nil, err } err = responseMessage.Unpack(responseBuffer.Bytes()) responseBuffer.Release() } else { rawMessage, err = io.ReadAll(response.Body) if err != nil { return nil, err } err = responseMessage.Unpack(rawMessage) } if err != nil { return nil, err } return &responseMessage, nil }