package http import ( "bufio" "bytes" "context" "io" "net" "net/http" "strings" "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" ) const ( // CRLF is the line ending in HTTP header CRLF = "\r\n" // ENDING is the double line ending between HTTP header and body. ENDING = CRLF + CRLF // max length of HTTP header. Safety precaution for DDoS attack. maxHeaderLength = 8192 ) var ( ErrHeaderToLong = errors.New("Header too long.") ErrHeaderMisMatch = errors.New("Header Mismatch.") ) type Reader interface { Read(io.Reader) (*buf.Buffer, error) } type Writer interface { Write(io.Writer) error } type NoOpReader struct{} func (NoOpReader) Read(io.Reader) (*buf.Buffer, error) { return nil, nil } type NoOpWriter struct{} func (NoOpWriter) Write(io.Writer) error { return nil } type HeaderReader struct { req *http.Request expectedHeader *RequestConfig } func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader { h.expectedHeader = expectedHeader return h } func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) { buffer := buf.New() totalBytes := int32(0) endingDetected := false var headerBuf bytes.Buffer for totalBytes < maxHeaderLength { _, err := buffer.ReadFrom(reader) if err != nil { buffer.Release() return nil, err } if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 { headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING)))) buffer.Advance(int32(n + len(ENDING))) endingDetected = true break } lenEnding := int32(len(ENDING)) if buffer.Len() >= lenEnding { totalBytes += buffer.Len() - lenEnding headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding)) leftover := buffer.BytesFrom(-lenEnding) buffer.Clear() copy(buffer.Extend(lenEnding), leftover) if _, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes()))); err != io.ErrUnexpectedEOF { return nil, err } } } if !endingDetected { buffer.Release() return nil, ErrHeaderToLong } if h.expectedHeader == nil { if buffer.IsEmpty() { buffer.Release() return nil, nil } return buffer, nil } // Parse the request if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes()))); err != nil { return nil, err } else { h.req = req } // Check req path := h.req.URL.Path hasThisURI := false for _, u := range h.expectedHeader.Uri { if u == path { hasThisURI = true } } if !hasThisURI { return nil, ErrHeaderMisMatch } if buffer.IsEmpty() { buffer.Release() return nil, nil } return buffer, nil } type HeaderWriter struct { header *buf.Buffer } func NewHeaderWriter(header *buf.Buffer) *HeaderWriter { return &HeaderWriter{ header: header, } } func (w *HeaderWriter) Write(writer io.Writer) error { if w.header == nil { return nil } err := buf.WriteAllBytes(writer, w.header.Bytes(), nil) w.header.Release() w.header = nil return err } type Conn struct { net.Conn readBuffer *buf.Buffer oneTimeReader Reader oneTimeWriter Writer errorWriter Writer errorMismatchWriter Writer errorTooLongWriter Writer errReason error } func NewConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *Conn { return &Conn{ Conn: conn, oneTimeReader: reader, oneTimeWriter: writer, errorWriter: errorWriter, errorMismatchWriter: errorMismatchWriter, errorTooLongWriter: errorTooLongWriter, } } func (c *Conn) Read(b []byte) (int, error) { if c.oneTimeReader != nil { buffer, err := c.oneTimeReader.Read(c.Conn) if err != nil { c.errReason = err return 0, err } c.readBuffer = buffer c.oneTimeReader = nil } if !c.readBuffer.IsEmpty() { nBytes, _ := c.readBuffer.Read(b) if c.readBuffer.IsEmpty() { c.readBuffer.Release() c.readBuffer = nil } return nBytes, nil } return c.Conn.Read(b) } // Write implements io.Writer. func (c *Conn) Write(b []byte) (int, error) { if c.oneTimeWriter != nil { err := c.oneTimeWriter.Write(c.Conn) c.oneTimeWriter = nil if err != nil { return 0, err } } return c.Conn.Write(b) } // Close implements net.Conn.Close(). func (c *Conn) Close() error { if c.oneTimeWriter != nil && c.errorWriter != nil { // Connection is being closed but header wasn't sent. This means the client request // is probably not valid. Sending back a server error header in this case. // Write response based on error reason switch c.errReason { case ErrHeaderMisMatch: c.errorMismatchWriter.Write(c.Conn) case ErrHeaderToLong: c.errorTooLongWriter.Write(c.Conn) default: c.errorWriter.Write(c.Conn) } } return c.Conn.Close() } func formResponseHeader(config *ResponseConfig) *HeaderWriter { header := buf.New() common.Must2(header.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " "))) common.Must2(header.WriteString(CRLF)) headers := config.PickHeaders() for _, h := range headers { common.Must2(header.WriteString(h)) common.Must2(header.WriteString(CRLF)) } if !config.HasHeader("Date") { common.Must2(header.WriteString("Date: ")) common.Must2(header.WriteString(time.Now().Format(http.TimeFormat))) common.Must2(header.WriteString(CRLF)) } common.Must2(header.WriteString(CRLF)) return &HeaderWriter{ header: header, } } type Authenticator struct { config *Config } func (a Authenticator) GetClientWriter() *HeaderWriter { header := buf.New() config := a.config.Request common.Must2(header.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickURI(), config.GetFullVersion()}, " "))) common.Must2(header.WriteString(CRLF)) headers := config.PickHeaders() for _, h := range headers { common.Must2(header.WriteString(h)) common.Must2(header.WriteString(CRLF)) } common.Must2(header.WriteString(CRLF)) return &HeaderWriter{ header: header, } } func (a Authenticator) GetServerWriter() *HeaderWriter { return formResponseHeader(a.config.Response) } func (a Authenticator) Client(conn net.Conn) net.Conn { if a.config.Request == nil && a.config.Response == nil { return conn } var reader Reader = NoOpReader{} if a.config.Request != nil { reader = new(HeaderReader) } var writer Writer = NoOpWriter{} if a.config.Response != nil { writer = a.GetClientWriter() } return NewConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{}) } func (a Authenticator) Server(conn net.Conn) net.Conn { if a.config.Request == nil && a.config.Response == nil { return conn } return NewConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(), formResponseHeader(resp400), formResponseHeader(resp404), formResponseHeader(resp400)) } func NewAuthenticator(ctx context.Context, config *Config) (Authenticator, error) { return Authenticator{ config: config, }, nil } func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { return NewAuthenticator(ctx, config.(*Config)) })) }