diff --git a/go.mod b/go.mod index 3726c99d..d2fe07c0 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/refraction-networking/utls v1.2.0 github.com/sagernet/sing v0.0.0-20220801112236-1bb95f9661fc github.com/sagernet/sing-shadowsocks v0.0.0-20220801112336-a91eacdd01e1 + github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb github.com/stretchr/testify v1.8.1 github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e @@ -27,6 +28,7 @@ require ( golang.org/x/sys v0.2.0 google.golang.org/grpc v1.51.0 google.golang.org/protobuf v1.28.1 + gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h12.io/socks v1.0.3 ) @@ -36,16 +38,20 @@ require ( github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect github.com/francoispqt/gojay v1.2.13 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/google/btree v1.0.1 // indirect github.com/google/pprof v0.0.0-20221112000123-84eb7ad69597 // indirect github.com/klauspost/compress v1.15.12 // indirect github.com/klauspost/cpuid/v2 v2.2.0 // indirect + github.com/kr/pretty v0.2.1 // indirect github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/onsi/ginkgo/v2 v2.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect + go.uber.org/atomic v1.10.0 // indirect golang.org/x/exp v0.0.0-20221111204811-129d8d6c17ab // indirect golang.org/x/mod v0.7.0 // indirect golang.org/x/text v0.4.0 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.3.0 // indirect google.golang.org/genproto v0.0.0-20221111202108-142d8a6fa32e // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 94fd0f0c..c7b058b5 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -94,8 +96,9 @@ github.com/klauspost/compress v1.15.12/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrD github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.0 h1:4ZexSFt8agMNzNisrsilL6RClWDC5YJnLHNIfTy4iuc= github.com/klauspost/cpuid/v2 v2.2.0/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= @@ -143,6 +146,8 @@ github.com/sagernet/sing v0.0.0-20220801112236-1bb95f9661fc h1:x7H64IiqyrpxPWl/K github.com/sagernet/sing v0.0.0-20220801112236-1bb95f9661fc/go.mod h1:GbtQfZSpmtD3cXeD1qX2LCMwY8dH+bnnInDTqd92IsM= github.com/sagernet/sing-shadowsocks v0.0.0-20220801112336-a91eacdd01e1 h1:RYvOc69eSNMN0dwVugrDts41Nn7Ar/C/n/fvytvFcp4= github.com/sagernet/sing-shadowsocks v0.0.0-20220801112336-a91eacdd01e1/go.mod h1:NqZjiXszgVCMQ4gVDa2V+drhS8NMfGqUqDF86EacEFc= +github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c h1:vK2wyt9aWYHHvNLWniwijBu/n4pySypiKRhN32u/JGo= +github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c/go.mod h1:euOmN6O5kk9dQmgSS8Df4psAl3TCjxOz0NW60EWkSaI= github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb h1:XfLJSPIOUX+osiMraVgIrMR27uMXnRJWGm1+GL8/63U= github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb/go.mod h1:bR6DqgcAl1zTcOX8/pE2Qkj9XO00eCNqmKb7lXP8EAg= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -191,6 +196,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.starlark.net v0.0.0-20221028183056-acb66ad56dd2 h1:5/KzhcSqd4UgY51l17r7C5g/JiE6DRw1Vq7VJfQHuMc= go.starlark.net v0.0.0-20221028183056-acb66ad56dd2/go.mod h1:kIVgS18CjmEC3PqMd5kaJSGEifyV/CeB9x506ZJ1Vbk= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -263,6 +270,8 @@ golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -327,6 +336,8 @@ gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h1:m5lcgWnL3OElQNVyp3qcncItJ2c0sQlSGjYK2+nJTA4= +gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c/go.mod h1:TIvkJD0sxe8pIob3p6T8IzxXunlp6yfgktvTNp+DGNM= h12.io/socks v1.0.3 h1:Ka3qaQewws4j4/eDQnOdpr4wXsC//dXtWvftlIcCQUo= h12.io/socks v1.0.3/go.mod h1:AIhxy1jOId/XCz9BO+EIgNL2rQiPTBNnOfnVnQ+3Eck= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go new file mode 100644 index 00000000..c4dec367 --- /dev/null +++ b/infra/conf/wireguard.go @@ -0,0 +1,111 @@ +package conf + +import ( + "encoding/base64" + "encoding/hex" + + "github.com/golang/protobuf/proto" + "github.com/xtls/xray-core/proxy/wireguard" +) + +type WireGuardPeerConfig struct { + PublicKey string `json:"publicKey"` + PreSharedKey string `json:"preSharedKey"` + Endpoint string `json:"endpoint"` + KeepAlive int `json:"keepAlive"` + AllowedIPs []string `json:"allowedIPs,omitempty"` +} + +func (c *WireGuardPeerConfig) Build() (proto.Message, error) { + var err error + config := new(wireguard.PeerConfig) + + config.PublicKey, err = parseWireGuardKey(c.PublicKey) + if err != nil { + return nil, err + } + + if c.PreSharedKey != "" { + config.PreSharedKey, err = parseWireGuardKey(c.PreSharedKey) + if err != nil { + return nil, err + } + } else { + config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000" + } + + config.Endpoint = c.Endpoint + // default 0 + config.KeepAlive = int32(c.KeepAlive) + if c.AllowedIPs == nil { + config.AllowedIps = []string{"0.0.0.0/0", "::0/0"} + } else { + config.AllowedIps = c.AllowedIPs + } + + return config, nil +} + +type WireGuardConfig struct { + SecretKey string `json:"secretKey"` + Address []string `json:"address"` + Peers []*WireGuardPeerConfig `json:"peers"` + MTU int `json:"mtu"` + NumWorkers int `json:"workers"` +} + +func (c *WireGuardConfig) Build() (proto.Message, error) { + config := new(wireguard.DeviceConfig) + + var err error + config.SecretKey, err = parseWireGuardKey(c.SecretKey) + if err != nil { + return nil, err + } + + if c.Address == nil { + // bogon ips + config.Endpoint = []string{"10.0.0.1", "fd59:7153:2388:b5fd:0000:0000:0000:0001"} + } else { + config.Endpoint = c.Address + } + + if c.Peers != nil { + config.Peers = make([]*wireguard.PeerConfig, len(c.Peers)) + for i, p := range c.Peers { + msg, err := p.Build() + if err != nil { + return nil, err + } + config.Peers[i] = msg.(*wireguard.PeerConfig) + } + } + + if c.MTU == 0 { + config.Mtu = 1420 + } else { + config.Mtu = int32(c.MTU) + } + // these a fallback code exists in github.com/nanoda0523/wireguard-go code, + // we don't need to process fallback manually + config.NumWorkers = int32(c.NumWorkers) + + return config, nil +} + +func parseWireGuardKey(str string) (string, error) { + if len(str) != 64 { + // may in base64 form + dat, err := base64.StdEncoding.DecodeString(str) + if err != nil { + return "", err + } + if len(dat) != 32 { + return "", newError("key should be 32 bytes: " + str) + } + return hex.EncodeToString(dat), err + } else { + // already hex form + return str, nil + } +} diff --git a/infra/conf/wireguard_test.go b/infra/conf/wireguard_test.go new file mode 100644 index 00000000..f0136bf0 --- /dev/null +++ b/infra/conf/wireguard_test.go @@ -0,0 +1,49 @@ +package conf_test + +import ( + "testing" + + . "github.com/xtls/xray-core/infra/conf" + "github.com/xtls/xray-core/proxy/wireguard" +) + +func TestWireGuardOutbound(t *testing.T) { + creator := func() Buildable { + return new(WireGuardConfig) + } + + runMultiTestCase(t, []TestCase{ + { + Input: `{ + "secretKey": "uJv5tZMDltsiYEn+kUwb0Ll/CXWhMkaSCWWhfPEZM3A=", + "address": ["10.1.1.1", "fd59:7153:2388:b5fd:0000:0000:1234:0001"], + "peers": [ + { + "publicKey": "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a", + "endpoint": "127.0.0.1:1234" + } + ], + "mtu": 1300, + "workers": 2 + }`, + Parser: loadJSON(creator), + Output: &wireguard.DeviceConfig{ + // key converted into hex form + SecretKey: "b89bf9b5930396db226049fe914c1bd0b97f0975a13246920965a17cf1193370", + Endpoint: []string{"10.1.1.1", "fd59:7153:2388:b5fd:0000:0000:1234:0001"}, + Peers: []*wireguard.PeerConfig{ + { + // also can read from hex form directly + PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a", + PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000", + Endpoint: "127.0.0.1:1234", + KeepAlive: 0, + AllowedIps: []string{"0.0.0.0/0", "::0/0"}, + }, + }, + Mtu: 1300, + NumWorkers: 2, + }, + }, + }) +} diff --git a/infra/conf/xray.go b/infra/conf/xray.go index 76f00804..cda512da 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -40,6 +40,7 @@ var ( "trojan": func() interface{} { return new(TrojanClientConfig) }, "mtproto": func() interface{} { return new(MTProtoClientConfig) }, "dns": func() interface{} { return new(DNSOutboundConfig) }, + "wireguard": func() interface{} { return new(WireGuardConfig) }, }, "protocol", "settings") ctllog = log.New(os.Stderr, "xctl> ", 0) diff --git a/main/distro/all/all.go b/main/distro/all/all.go index db9d5c4d..f92542d5 100644 --- a/main/distro/all/all.go +++ b/main/distro/all/all.go @@ -48,6 +48,7 @@ import ( _ "github.com/xtls/xray-core/proxy/vless/outbound" _ "github.com/xtls/xray-core/proxy/vmess/inbound" _ "github.com/xtls/xray-core/proxy/vmess/outbound" + _ "github.com/xtls/xray-core/proxy/wireguard" // Transports _ "github.com/xtls/xray-core/transport/internet/domainsocket" diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go new file mode 100644 index 00000000..1136f5ed --- /dev/null +++ b/proxy/wireguard/bind.go @@ -0,0 +1,254 @@ +package wireguard + +import ( + "context" + "errors" + "io" + "net" + "net/netip" + "strconv" + "sync" + + "github.com/sagernet/wireguard-go/conn" + xnet "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/features/dns" + "github.com/xtls/xray-core/transport/internet" +) + +type netReadInfo struct { + // status + waiter sync.WaitGroup + // param + buff []byte + // result + bytes int + endpoint conn.Endpoint + err error +} + +type netBindClient struct { + workers int + dialer internet.Dialer + dns dns.Client + dnsOption dns.IPOption + + readQueue chan *netReadInfo +} + +func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { + ipStr, port, _, err := splitAddrPort(s) + if err != nil { + return nil, err + } + + var addr net.IP + if IsDomainName(ipStr) { + ips, err := n.dns.LookupIP(ipStr, n.dnsOption) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, dns.ErrEmptyResponse + } + addr = ips[0] + } else { + addr = net.ParseIP(ipStr) + } + if addr == nil { + return nil, errors.New("failed to parse ip: " + ipStr) + } + + var ip xnet.Address + if p4 := addr.To4(); len(p4) == net.IPv4len { + ip = xnet.IPAddress(p4[:]) + } else { + ip = xnet.IPAddress(addr[:]) + } + + dst := xnet.Destination{ + Address: ip, + Port: xnet.Port(port), + Network: xnet.Network_UDP, + } + + return &netEndpoint{ + dst: dst, + }, nil +} + +func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { + bind.readQueue = make(chan *netReadInfo) + + fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) { + defer func() { + if r := recover(); r != nil { + cap = 0 + ep = nil + err = errors.New("channel closed") + } + }() + + r := &netReadInfo{ + buff: buff, + } + r.waiter.Add(1) + bind.readQueue <- r + r.waiter.Wait() // wait read goroutine done, or we will miss the result + return r.bytes, r.endpoint, r.err + } + workers := bind.workers + if workers <= 0 { + workers = 1 + } + arr := make([]conn.ReceiveFunc, workers) + for i := 0; i < workers; i++ { + arr[i] = fun + } + + return arr, uint16(uport), nil +} + +func (bind *netBindClient) Close() error { + if bind.readQueue != nil { + close(bind.readQueue) + } + return nil +} + +func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { + c, err := bind.dialer.Dial(context.Background(), endpoint.dst) + if err != nil { + return err + } + endpoint.conn = c + + go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { + for { + v, ok := <-readQueue + if !ok { + return + } + i, err := c.Read(v.buff) + v.bytes = i + v.endpoint = endpoint + v.err = err + v.waiter.Done() + if err != nil && errors.Is(err, io.EOF) { + endpoint.conn = nil + return + } + } + }(bind.readQueue, endpoint) + + return nil +} + +func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error { + var err error + + nend, ok := endpoint.(*netEndpoint) + if !ok { + return conn.ErrWrongEndpointType + } + + if nend.conn == nil { + err = bind.connectTo(nend) + if err != nil { + return err + } + } + + _, err = nend.conn.Write(buff) + + return err +} + +func (bind *netBindClient) SetMark(mark uint32) error { + return nil +} + +type netEndpoint struct { + dst xnet.Destination + conn net.Conn +} + +func (netEndpoint) ClearSrc() {} + +func (e netEndpoint) DstIP() netip.Addr { + return toNetIpAddr(e.dst.Address) +} + +func (e netEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e netEndpoint) DstToBytes() []byte { + var dat []byte + if e.dst.Address.Family().IsIPv4() { + dat = e.dst.Address.IP().To4()[:] + } else { + dat = e.dst.Address.IP().To16()[:] + } + dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) + return dat +} + +func (e netEndpoint) DstToString() string { + return e.dst.NetAddr() +} + +func (e netEndpoint) SrcToString() string { + return "" +} + +func toNetIpAddr(addr xnet.Address) netip.Addr { + if addr.Family().IsIPv4() { + ip := addr.IP() + return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) + } else { + ip := addr.IP() + arr := [16]byte{} + for i := 0; i < 16; i++ { + arr[i] = ip[i] + } + return netip.AddrFrom16(arr) + } +} + +func stringsLastIndexByte(s string, b byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == b { + return i + } + } + return -1 +} + +func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { + i := stringsLastIndexByte(s, ':') + if i == -1 { + return "", 0, false, errors.New("not an ip:port") + } + + ip = s[:i] + portStr := s[i+1:] + if len(ip) == 0 { + return "", 0, false, errors.New("no IP") + } + if len(portStr) == 0 { + return "", 0, false, errors.New("no port") + } + port64, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s)) + } + port = uint16(port64) + if ip[0] == '[' { + if len(ip) < 2 || ip[len(ip)-1] != ']' { + return "", 0, false, errors.New("missing ]") + } + ip = ip[1 : len(ip)-1] + v6 = true + } + + return ip, port, v6, nil +} diff --git a/proxy/wireguard/config.pb.go b/proxy/wireguard/config.pb.go new file mode 100644 index 00000000..149fa958 --- /dev/null +++ b/proxy/wireguard/config.pb.go @@ -0,0 +1,294 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.9 +// source: proxy/wireguard/config.proto + +package wireguard + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type PeerConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` + PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"` + Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"` + KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` + AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"` +} + +func (x *PeerConfig) Reset() { + *x = PeerConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_wireguard_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PeerConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PeerConfig) ProtoMessage() {} + +func (x *PeerConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_wireguard_config_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead. +func (*PeerConfig) Descriptor() ([]byte, []int) { + return file_proxy_wireguard_config_proto_rawDescGZIP(), []int{0} +} + +func (x *PeerConfig) GetPublicKey() string { + if x != nil { + return x.PublicKey + } + return "" +} + +func (x *PeerConfig) GetPreSharedKey() string { + if x != nil { + return x.PreSharedKey + } + return "" +} + +func (x *PeerConfig) GetEndpoint() string { + if x != nil { + return x.Endpoint + } + return "" +} + +func (x *PeerConfig) GetKeepAlive() int32 { + if x != nil { + return x.KeepAlive + } + return 0 +} + +func (x *PeerConfig) GetAllowedIps() []string { + if x != nil { + return x.AllowedIps + } + return nil +} + +type DeviceConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SecretKey string `protobuf:"bytes,1,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"` + Endpoint []string `protobuf:"bytes,2,rep,name=endpoint,proto3" json:"endpoint,omitempty"` + Peers []*PeerConfig `protobuf:"bytes,3,rep,name=peers,proto3" json:"peers,omitempty"` + Mtu int32 `protobuf:"varint,4,opt,name=mtu,proto3" json:"mtu,omitempty"` + NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"` +} + +func (x *DeviceConfig) Reset() { + *x = DeviceConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_wireguard_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeviceConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeviceConfig) ProtoMessage() {} + +func (x *DeviceConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_wireguard_config_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeviceConfig.ProtoReflect.Descriptor instead. +func (*DeviceConfig) Descriptor() ([]byte, []int) { + return file_proxy_wireguard_config_proto_rawDescGZIP(), []int{1} +} + +func (x *DeviceConfig) GetSecretKey() string { + if x != nil { + return x.SecretKey + } + return "" +} + +func (x *DeviceConfig) GetEndpoint() []string { + if x != nil { + return x.Endpoint + } + return nil +} + +func (x *DeviceConfig) GetPeers() []*PeerConfig { + if x != nil { + return x.Peers + } + return nil +} + +func (x *DeviceConfig) GetMtu() int32 { + if x != nil { + return x.Mtu + } + return 0 +} + +func (x *DeviceConfig) GetNumWorkers() int32 { + if x != nil { + return x.NumWorkers + } + return 0 +} + +var File_proxy_wireguard_config_proto protoreflect.FileDescriptor + +var file_proxy_wireguard_config_proto_rawDesc = []byte{ + 0x0a, 0x1c, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, + 0x64, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, + 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, + 0x75, 0x61, 0x72, 0x64, 0x22, 0xad, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, + 0x65, 0x79, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x72, 0x65, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64, + 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, + 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69, + 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c, + 0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69, + 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, + 0x64, 0x49, 0x70, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65, + 0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x36, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x20, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, + 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x75, + 0x6d, 0x5f, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x0a, 0x6e, 0x75, 0x6d, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x73, 0x42, 0x5e, 0x0a, 0x18, 0x63, + 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, + 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, + 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, + 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, + 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +} + +var ( + file_proxy_wireguard_config_proto_rawDescOnce sync.Once + file_proxy_wireguard_config_proto_rawDescData = file_proxy_wireguard_config_proto_rawDesc +) + +func file_proxy_wireguard_config_proto_rawDescGZIP() []byte { + file_proxy_wireguard_config_proto_rawDescOnce.Do(func() { + file_proxy_wireguard_config_proto_rawDescData = protoimpl.X.CompressGZIP(file_proxy_wireguard_config_proto_rawDescData) + }) + return file_proxy_wireguard_config_proto_rawDescData +} + +var file_proxy_wireguard_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_proxy_wireguard_config_proto_goTypes = []interface{}{ + (*PeerConfig)(nil), // 0: xray.proxy.wireguard.PeerConfig + (*DeviceConfig)(nil), // 1: xray.proxy.wireguard.DeviceConfig +} +var file_proxy_wireguard_config_proto_depIdxs = []int32{ + 0, // 0: xray.proxy.wireguard.DeviceConfig.peers:type_name -> xray.proxy.wireguard.PeerConfig + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_proxy_wireguard_config_proto_init() } +func file_proxy_wireguard_config_proto_init() { + if File_proxy_wireguard_config_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_proxy_wireguard_config_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PeerConfig); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_wireguard_config_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeviceConfig); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_proxy_wireguard_config_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_proxy_wireguard_config_proto_goTypes, + DependencyIndexes: file_proxy_wireguard_config_proto_depIdxs, + MessageInfos: file_proxy_wireguard_config_proto_msgTypes, + }.Build() + File_proxy_wireguard_config_proto = out.File + file_proxy_wireguard_config_proto_rawDesc = nil + file_proxy_wireguard_config_proto_goTypes = nil + file_proxy_wireguard_config_proto_depIdxs = nil +} diff --git a/proxy/wireguard/config.proto b/proxy/wireguard/config.proto new file mode 100644 index 00000000..dde3b41b --- /dev/null +++ b/proxy/wireguard/config.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package xray.proxy.wireguard; +option csharp_namespace = "Xray.Proxy.WireGuard"; +option go_package = "github.com/xtls/xray-core/proxy/wireguard"; +option java_package = "com.xray.proxy.wireguard"; +option java_multiple_files = true; + +message PeerConfig { + string public_key = 1; + string pre_shared_key = 2; + string endpoint = 3; + int32 keep_alive = 4; + repeated string allowed_ips = 5; +} + +message DeviceConfig { + string secret_key = 1; + repeated string endpoint = 2; + repeated PeerConfig peers = 3; + int32 mtu = 4; + int32 num_workers = 5; +} \ No newline at end of file diff --git a/proxy/wireguard/errors.generated.go b/proxy/wireguard/errors.generated.go new file mode 100644 index 00000000..8319e07d --- /dev/null +++ b/proxy/wireguard/errors.generated.go @@ -0,0 +1,9 @@ +package wireguard + +import "github.com/xtls/xray-core/common/errors" + +type errPathObjHolder struct{} + +func newError(values ...interface{}) *errors.Error { + return errors.New(values...).WithPathObj(errPathObjHolder{}) +} diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go new file mode 100644 index 00000000..4d1cb7f6 --- /dev/null +++ b/proxy/wireguard/tun.go @@ -0,0 +1,303 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. + */ + +package wireguard + +import ( + "context" + "fmt" + "net" + "net/netip" + "os" + + "github.com/sagernet/wireguard-go/tun" + "github.com/xtls/xray-core/features/dns" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + incomingPacket chan *bufferv2.View + mtu int + dnsClient dns.Client + hasV4, hasV6 bool +} + +type Net netTun + +func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + HandleLocal: true, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 10), + incomingPacket: make(chan *bufferv2.View), + dnsClient: dnsClient, + mtu: mtu, + } + dev.ep.AddNotify(dev) + tcpipErr := dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + + dev.events <- tun.EventUp + return dev, (*Net)(dev), nil +} + +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +func (tun *netTun) File() *os.File { + return nil +} + +func (tun *netTun) Events() chan tun.Event { + return tun.events +} + +func (tun *netTun) Read(buf []byte, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + return view.Read(buf[offset:]) +} + +func (tun *netTun) Write(buf []byte, offset int) (int, error) { + packet := buf[offset:] + if len(packet) == 0 { + return 0, nil + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + } + + return len(buf), nil +} + +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt == nil { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +func (tun *netTun) Flush() error { + return nil +} + +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + if tun.events != nil { + close(tun.events) + } + + tun.ep.Close() + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) +} + +func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) +} + +func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { + if addr == nil { + return net.ListenTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} + +func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { + return net.DialUDPAddrPort(laddr, netip.AddrPort{}) +} + +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + ip, _ := netip.AddrFromSlice(laddr.IP) + la = netip.AddrPortFrom(ip, uint16(laddr.Port)) + } + if raddr != nil { + ip, _ := netip.AddrFromSlice(raddr.IP) + ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + +func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { + return net.DialUDP(laddr, nil) +} + +func (n *Net) HasV4() bool { + return n.hasV4 +} + +func (n *Net) HasV6() bool { + return n.hasV6 +} + +func IsDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go new file mode 100644 index 00000000..51cee876 --- /dev/null +++ b/proxy/wireguard/wireguard.go @@ -0,0 +1,263 @@ +/* + +Some of codes are copied from https://github.com/octeep/wireproxy, license below. + +Copyright (c) 2022 Wind T.F. Wong + +Permission to use, copy, modify, and distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +*/ + +package wireguard + +import ( + "bytes" + "context" + "fmt" + "net/netip" + "strings" + + "github.com/sagernet/wireguard-go/device" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/log" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" + "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/features/dns" + "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/internet" +) + +// Handler is an outbound connection that silently swallow the entire payload. +type Handler struct { + conf *DeviceConfig + net *Net + bind *netBindClient + policyManager policy.Manager + dns dns.Client + // cached configuration + ipc string + endpoints []netip.Addr +} + +// New creates a new wireguard handler. +func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { + v := core.MustFromContext(ctx) + + endpoints, err := parseEndpoints(conf) + if err != nil { + return nil, err + } + + return &Handler{ + conf: conf, + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + dns: v.GetFeature(dns.ClientType()).(dns.Client), + ipc: createIPCRequest(conf), + endpoints: endpoints, + }, nil +} + +// Process implements OutboundHandler.Dispatch(). +func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + if h.bind == nil || h.bind.dialer != dialer || h.net == nil { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Info, + Content: "switching dialer", + }) + // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer + bind := &netBindClient{ + dialer: dialer, + workers: int(h.conf.NumWorkers), + dns: h.dns, + } + + net, err := h.makeVirtualTun(bind) + if err != nil { + bind.Close() + return newError("failed to create virtual tun interface").Base(err) + } + + h.net = net + if h.bind != nil { + h.bind.Close() + } + h.bind = bind + } + + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified") + } + // Destination of the inner request. + destination := outbound.Target + command := protocol.RequestCommandTCP + if destination.Network == net.Network_UDP { + command = protocol.RequestCommandUDP + } + + // resolve dns + addr := destination.Address + if addr.Family().IsDomain() { + ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ + IPv4Enable: h.net.HasV4(), + IPv6Enable: h.net.HasV6(), + }) + if err != nil { + return newError("failed to lookup DNS").Base(err) + } else if len(ips) == 0 { + return dns.ErrEmptyResponse + } + addr = net.IPAddress(ips[0]) + } + + p := h.policyManager.ForLevel(0) + + ctx, cancel := context.WithCancel(ctx) + timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) + addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) + + var requestFunc func() error + var responseFunc func() error + + if command == protocol.RequestCommandTCP { + conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) + if err != nil { + return newError("failed to create TCP connection").Base(err) + } + + requestFunc = func() error { + defer timer.SetTimeout(p.Timeouts.DownlinkOnly) + return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) + } + responseFunc = func() error { + defer timer.SetTimeout(p.Timeouts.UplinkOnly) + return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + } + } else if command == protocol.RequestCommandUDP { + conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) + if err != nil { + return newError("failed to create UDP connection").Base(err) + } + + requestFunc = func() error { + defer timer.SetTimeout(p.Timeouts.DownlinkOnly) + return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) + } + responseFunc = func() error { + defer timer.SetTimeout(p.Timeouts.UplinkOnly) + return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + } + } + + responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) + if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { + return newError("connection ends").Base(err) + } + + return nil +} + +// serialize the config into an IPC request +func createIPCRequest(conf *DeviceConfig) string { + var request bytes.Buffer + + request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) + + for _, peer := range conf.Peers { + request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", + peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey)) + + for _, ip := range peer.AllowedIps { + request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) + } + } + + return request.String()[:request.Len()] +} + +// convert endpoint string to netip.Addr +func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { + endpoints := make([]netip.Addr, len(conf.Endpoint)) + for i, str := range conf.Endpoint { + var addr netip.Addr + if strings.Contains(str, "/") { + prefix, err := netip.ParsePrefix(str) + if err != nil { + return nil, err + } + addr = prefix.Addr() + if prefix.Bits() != addr.BitLen() { + return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") + } + } else { + var err error + addr, err = netip.ParseAddr(str) + if err != nil { + return nil, err + } + } + endpoints[i] = addr + } + + return endpoints, nil +} + +// creates a tun interface on netstack given a configuration +func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) { + tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu)) + if err != nil { + return nil, err + } + + bind.dnsOption.IPv4Enable = tnet.HasV4() + bind.dnsOption.IPv6Enable = tnet.HasV6() + + // dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */) + dev := device.NewDevice(tun, bind, &device.Logger{ + Verbosef: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Debug, + Content: fmt.Sprintf(format, args...), + }) + }, + Errorf: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Error, + Content: fmt.Sprintf(format, args...), + }) + }, + }, int(h.conf.NumWorkers)) + err = dev.IpcSet(h.ipc) + if err != nil { + return nil, err + } + + err = dev.Up() + if err != nil { + return nil, err + } + + return tnet, nil +} + +func init() { + common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { + return New(ctx, config.(*DeviceConfig)) + })) +}