package proxy

import (
	"crypto/subtle"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"sync/atomic"
)

// Proxy is an http.Handler that forwards client requests upstream. It can
// require HTTP basic proxy authentication, serve requests addressed to its own
// host locally, and serve connections through a MITM listener once EnableMitm
// has been called.
type Proxy struct {
	// transcoders maps a content type to the Transcoder registered for it.
	transcoders map[string]Transcoder

	// ml is the MITM listener created by EnableMitm; nil until then.
	ml *mitmListener

	// ReadCount and WriteCount accumulate the amounts read from upstream
	// responses and written to clients. They are updated with sync/atomic,
	// so read them with atomic.LoadUint64.
	ReadCount  uint64
	WriteCount uint64

	// user and pass are the required basic-auth credentials; an empty user
	// disables authentication.
	user string
	pass string

	// host is compared against each request's host in handle; matching
	// requests are served by handleLocalRequest.
	host string

	// cert optionally names a PEM file whose certificates are added to the
	// system root pool in the tls.Config handed to the MITM listener.
	cert *string
}

// Transcoder rewrites an upstream response, read through a ResponseReader,
// into the client response written through a ResponseWriter. The header
// argument is presumably the client's request header (handle passes r.Header
// to proxyResponse) — confirm against proxyResponse.
type Transcoder interface {
	Transcode(*ResponseWriter, *ResponseReader, http.Header) error
}

// New returns a Proxy with an empty transcoder registry. host identifies
// requests that should be served locally rather than forwarded; cert, which
// may be nil, names extra root certificates used when MITM is enabled.
func New(host string, cert *string) *Proxy {
	return &Proxy{
		transcoders: make(map[string]Transcoder),
		host:        host,
		cert:        cert,
	}
}

// EnableMitm builds a certificate faker from ca and key (passed to
// newCertFaker as-is) and starts serving p on a new MITM listener in a
// background goroutine. If p was created with a cert path, the certificates
// in that PEM file are appended to the system root pool and used as RootCAs
// in the listener's tls.Config; otherwise the listener gets a nil config.
func (p *Proxy) EnableMitm(ca, key string) error {
	cf, err := newCertFaker(ca, key)
	if err != nil {
		return err
	}

	var config *tls.Config
	if p.cert != nil {
		roots, err := x509.SystemCertPool()
		if err != nil {
			return err
		}
		pem, err := ioutil.ReadFile(*p.cert)
		if err != nil {
			return err
		}
		if !roots.AppendCertsFromPEM(pem) {
			return errors.New("failed to parse root certificate")
		}
		config = &tls.Config{RootCAs: roots}
	}

	ml := newMitmListener(cf, config)
	p.ml = ml
	go func() {
		// http.Serve always returns a non-nil error; report it instead of
		// letting the MITM listener die silently.
		if err := http.Serve(ml, p); err != nil {
			log.Printf("mitm listener stopped: %s", err)
		}
	}()
	return nil
}

// SetAuthentication requires clients to present basic credentials user and
// pass in the Proxy-Authorization header. An empty user disables the check.
func (p *Proxy) SetAuthentication(user, pass string) {
	p.user = user
	p.pass = pass
}

// AddTranscoder registers transcoder under contentType, replacing any earlier
// registration. The registry map is not guarded by a lock, so register
// transcoders before serving requests.
func (p *Proxy) AddTranscoder(contentType string, transcoder Transcoder) {
	p.transcoders[contentType] = transcoder
}

// Start serves plain HTTP on host and blocks until the server fails.
func (p *Proxy) Start(host string) error {
	return http.ListenAndServe(host, p)
}

// StartTLS serves HTTPS on host using the given certificate and key files and
// blocks until the server fails.
func (p *Proxy) StartTLS(host, cert, key string) error {
	return http.ListenAndServeTLS(host, cert, key, p)
}

// ServeHTTP implements http.Handler by delegating to handle and logging any
// error it returns.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log.Printf("serving request: %s", r.URL)
	if err := p.handle(w, r); err != nil {
		log.Printf("%s while serving request: %s", err, r.URL)
	}
}

// checkHttpBasicAuth reports whether auth, the value of a Proxy-Authorization
// header, carries basic credentials equal to p.user and p.pass.
func (p *Proxy) checkHttpBasicAuth(auth string) bool {
	const prefix = "Basic "
	// The auth-scheme token is case-insensitive (RFC 7235 §2.1).
	if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
		return false
	}
	decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
	if err != nil {
		return false
	}
	values := strings.SplitN(string(decoded), ":", 2)
	if len(values) != 2 {
		return false
	}
	// The header comes from untrusted clients: compare in constant time, and
	// always run both comparisons, so response timing does not reveal how
	// much of the supplied credentials matched.
	userOK := subtle.ConstantTimeCompare([]byte(values[0]), []byte(p.user))
	passOK := subtle.ConstantTimeCompare([]byte(values[1]), []byte(p.pass))
	return userOK&passOK == 1
}

// handle authenticates the client when credentials are configured, then
// dispatches the request. CONNECT requests go to handleConnect. Requests whose
// host equals p.host (or the machine hostname followed by p.host) go to
// handleLocalRequest. Everything else is forwarded upstream and relayed
// through proxyResponse, and the read/written counts are logged and added to
// p.ReadCount and p.WriteCount.
func (p *Proxy) handle(w http.ResponseWriter, r *http.Request) error {
	// TODO: only HTTPS?
	if p.user != "" && !p.checkHttpBasicAuth(r.Header.Get("Proxy-Authorization")) {
		// RFC 7235 §3.2: a 407 response MUST carry a Proxy-Authenticate
		// challenge; WWW-Authenticate is the header for 401 responses.
		w.Header().Set("Proxy-Authenticate", "Basic realm=\"Compy\"")
		w.WriteHeader(http.StatusProxyAuthRequired)
		return nil
	}

	if r.Method == http.MethodConnect {
		return p.handleConnect(w, r)
	}

	host := r.URL.Host
	if host == "" {
		host = r.Host
	}
	if hostname, err := os.Hostname(); host == p.host || (err == nil && host == hostname+p.host) {
		return p.handleLocalRequest(w, r)
	}

	resp, err := forward(r)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return fmt.Errorf("error forwarding request: %w", err)
	}
	defer resp.Body.Close()

	rw := newResponseWriter(w)
	rr := newResponseReader(resp)
	err = p.proxyResponse(rw, rr, r.Header)

	read := rr.counter.Count()
	written := rw.rw.Count()
	// Guard the ratio: an empty upstream body would otherwise log NaN/+Inf.
	if read > 0 {
		log.Printf("transcoded: %d -> %d (%3.1f%%)", read, written, float64(written)/float64(read)*100)
	} else {
		log.Printf("transcoded: %d -> %d", read, written)
	}
	atomic.AddUint64(&p.ReadCount, read)
	atomic.AddUint64(&p.WriteCount, written)
	return err
}

// handleLocalRequest serves requests addressed to the proxy's own host, as
// detected by handle. A GET for the root path is answered with a text/html
// page; the ReadCount/WriteCount totals loaded here are presumably rendered
// into it.
func (p *Proxy) handleLocalRequest(w http.ResponseWriter, r *http.Request) error {
	if r.Method == "GET" && (r.URL.Path == "" || r.URL.Path == "/") {
		w.Header().Set("Content-Type", "text/html")
		read := atomic.LoadUint64(&p.ReadCount)
		written := atomic.LoadUint64(&p.WriteCount)
		io.WriteString(w, fmt.Sprintf(`