package proxy import ( "crypto/tls" "crypto/x509" "encoding/base64" "errors" "fmt" "io" "io/ioutil" "log" "net" "net/http" "os" "strings" "sync/atomic" ) type Proxy struct { transcoders map[string]Transcoder ml *mitmListener ReadCount uint64 WriteCount uint64 user string pass string host string cert string } type Transcoder interface { Transcode(*ResponseWriter, *ResponseReader, http.Header) error } func New(host string, cert string) *Proxy { p := &Proxy{ transcoders: make(map[string]Transcoder), ml: nil, host: host, cert: cert, } return p } func (p *Proxy) EnableMitm(ca, key string) error { cf, err := newCertFaker(ca, key) if err != nil { return err } var config *tls.Config if p.cert != "" { roots, err := x509.SystemCertPool() if err != nil { return err } pem, err := ioutil.ReadFile(p.cert) if err != nil { return err } ok := roots.AppendCertsFromPEM([]byte(pem)) if !ok { return errors.New("failed to parse root certificate") } config = &tls.Config{RootCAs: roots} } p.ml = newMitmListener(cf, config) go http.Serve(p.ml, p) return nil } func (p *Proxy) SetAuthentication(user, pass string) { p.user = user p.pass = pass } func (p *Proxy) AddTranscoder(contentType string, transcoder Transcoder) { p.transcoders[contentType] = transcoder } func (p *Proxy) Start(host string) error { return http.ListenAndServe(host, p) } func (p *Proxy) StartTLS(host, cert, key string) error { return http.ListenAndServeTLS(host, cert, key, p) } func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("serving request: %s", r.URL) if err := p.handle(w, r); err != nil { log.Printf("%s while serving request: %s", err, r.URL) } } func (p *Proxy) checkHttpBasicAuth(auth string) bool { prefix := "Basic " if !strings.HasPrefix(auth, prefix) { return false } decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) if err != nil { return false } values := strings.SplitN(string(decoded), ":", 2) if len(values) != 2 || values[0] != p.user || values[1] != p.pass { return false } return true } func (p *Proxy) handle(w http.ResponseWriter, r *http.Request) error { // TODO: only HTTPS? if p.user != "" { if !p.checkHttpBasicAuth(r.Header.Get("Proxy-Authorization")) { w.Header().Set("Proxy-Authenticate", "Basic realm=\"Compy\"") w.WriteHeader(http.StatusProxyAuthRequired) return nil } r.Header.Del("Proxy-Authorization") } if r.Method == "CONNECT" { return p.handleConnect(w, r) } host := r.URL.Host if host == "" { host = r.Host } if hostname, err := os.Hostname(); host == p.host || (err == nil && host == hostname+p.host) { return p.handleLocalRequest(w, r) } resp, err := forward(r) if err != nil { w.WriteHeader(http.StatusInternalServerError) return fmt.Errorf("error forwarding request: %s", err) } defer resp.Body.Close() user_agent := r.Header.Get("User-Agent") w.Header().Set("User-Agent", user_agent) rw := newResponseWriter(w) rr := newResponseReader(resp) err = p.proxyResponse(rw, rr, r.Header) read := rr.counter.Count() written := rw.rw.Count() log.Printf("transcoded: %d -> %d (%3.1f%%)", read, written, float64(written)/float64(read)*100) atomic.AddUint64(&p.ReadCount, read) atomic.AddUint64(&p.WriteCount, written) return err } func (p *Proxy) handleLocalRequest(w http.ResponseWriter, r *http.Request) error { if r.Method == "GET" && (r.URL.Path == "" || r.URL.Path == "/") { w.Header().Set("Content-Type", "text/html") read := atomic.LoadUint64(&p.ReadCount) written := atomic.LoadUint64(&p.WriteCount) io.WriteString(w, fmt.Sprintf(` compy

compy

`, read, written, float64(written)/float64(read)*100)) return nil } else if r.Method == "GET" && r.URL.Path == "/cacert" { if p.cert == "" { http.NotFound(w, r) return nil } w.Header().Set("Content-Type", "application/x-x509-ca-cert") http.ServeFile(w, r, p.cert) return nil } else { w.WriteHeader(http.StatusNotImplemented) return nil } } func forward(r *http.Request) (*http.Response, error) { if r.URL.Scheme == "" { if r.TLS != nil && r.TLS.ServerName == r.Host { r.URL.Scheme = "https" } else { r.URL.Scheme = "http" } } if r.URL.Host == "" { r.URL.Host = r.Host } r.RequestURI = "" return http.DefaultTransport.RoundTrip(r) } func (p *Proxy) proxyResponse(w *ResponseWriter, r *ResponseReader, headers http.Header) error { w.takeHeaders(r) transcoder, found := p.transcoders[r.ContentType()] if !found { return w.ReadFrom(r) } w.setChunked() if err := transcoder.Transcode(w, r, headers); err != nil { return fmt.Errorf("transcoding error: %s", err) } return nil } func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) error { if p.ml == nil { return fmt.Errorf("CONNECT received but mitm is not enabled") } w.WriteHeader(http.StatusOK) var conn net.Conn if h, ok := w.(http.Hijacker); ok { conn, _, _ = h.Hijack() } else { fw := w.(FlushWriter) fw.Flush() mconn := newMitmConn(fw, r.Body, r.RemoteAddr) conn = mconn defer func() { <-mconn.closed }() } sconn, err := p.ml.Serve(conn, r.Host) if err != nil { conn.Close() return err } sconn.Close() // TODO: reuse this connection for https requests return nil }