diff --git a/outbound/default.go b/outbound/default.go index 2f147a41..81251176 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -39,6 +39,10 @@ func (a *myOutboundAdapter) Network() []string { return a.network } +func (a *myOutboundAdapter) NewError(ctx context.Context, err error) { + NewError(a.logger, ctx, err) +} + func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { ctx = adapter.WithContext(ctx, &metadata) var outConn net.Conn @@ -121,3 +125,12 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro } return bufio.CopyConn(ctx, conn, serverConn) } + +func NewError(logger log.ContextLogger, ctx context.Context, err error) { + common.Close(err) + if E.IsClosedOrCanceled(err) { + logger.DebugContext(ctx, "connection closed: ", err) + return + } + logger.ErrorContext(ctx, err) +} diff --git a/outbound/wireguard.go b/outbound/wireguard.go index 9ca3a16a..c956dade 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -64,7 +64,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context connectAddr = options.ServerOptions.Build() } } - outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved) + outbound.bind = wireguard.NewClientBind(ctx, outbound, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved) localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build) if len(localPrefixes) == 0 { return nil, E.New("missing local address") diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 26fe9967..79c08d47 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -7,8 +7,8 @@ import ( "sync" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/wireguard-go/conn" @@ -18,6 +18,7 @@ var _ conn.Bind = (*ClientBind)(nil) type ClientBind struct { ctx context.Context + errorHandler E.Handler dialer N.Dialer reservedForEndpoint map[M.Socksaddr][3]uint8 connAccess sync.Mutex @@ -28,9 +29,10 @@ type ClientBind struct { reserved [3]uint8 } -func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind { +func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind { return &ClientBind{ ctx: ctx, + errorHandler: errorHandler, dialer: dialer, reservedForEndpoint: make(map[M.Socksaddr][3]uint8), isConnect: isConnect, @@ -67,10 +69,10 @@ func (c *ClientBind) connect() (*wireConn, error) { if c.isConnect { udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr) if err != nil { - return nil, &wireError{err} + return nil, err } c.conn = &wireConn{ - NetPacketConn: &bufio.UnbindPacketConn{ + PacketConn: &bufio.UnbindPacketConn{ ExtendedConn: bufio.NewExtendedConn(udpConn), Addr: c.connectAddr, }, @@ -79,11 +81,11 @@ func (c *ClientBind) connect() (*wireConn, error) { } else { udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) if err != nil { - return nil, &wireError{err} + return nil, err } c.conn = &wireConn{ - NetPacketConn: bufio.NewPacketConn(udpConn), - done: make(chan struct{}), + PacketConn: bufio.NewPacketConn(udpConn), + done: make(chan struct{}), } } return c.conn, nil @@ -102,30 +104,31 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1 func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { udpConn, err := c.connect() if err != nil { - err = &wireError{err} + select { + case <-c.done: + return + default: + } + c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server")) + err = nil return } - buffer := buf.With(b) - destination, err := udpConn.ReadPacket(buffer) + n, addr, err := udpConn.ReadFrom(b) if err != nil { udpConn.Close() select { case <-c.done: default: - err = &wireError{err} + c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet")) } return } - n = buffer.Len() - if buffer.Start() > 0 { - copy(b, buffer.Bytes()) - } if n > 3 { b[1] = 0 b[2] = 0 b[3] = 0 } - ep = Endpoint(destination) + ep = Endpoint(M.SocksaddrFromNet(addr)) return } @@ -167,7 +170,7 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error { b[2] = reserved[1] b[3] = reserved[2] } - err = udpConn.WritePacket(buf.As(b), destination) + _, err = udpConn.WriteTo(b, destination) if err != nil { udpConn.Close() } @@ -179,7 +182,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { } type wireConn struct { - N.NetPacketConn + net.PacketConn access sync.Mutex done chan struct{} } @@ -192,7 +195,7 @@ func (w *wireConn) Close() error { return net.ErrClosed default: } - w.NetPacketConn.Close() + w.PacketConn.Close() close(w.done) return nil } diff --git a/transport/wireguard/error.go b/transport/wireguard/error.go deleted file mode 100644 index db99a1b5..00000000 --- a/transport/wireguard/error.go +++ /dev/null @@ -1,22 +0,0 @@ -package wireguard - -import "net" - -type wireError struct { - cause error -} - -func (w *wireError) Error() string { - return w.cause.Error() -} - -func (w *wireError) Timeout() bool { - if cause, causeNet := w.cause.(net.Error); causeNet { - return cause.Timeout() - } - return false -} - -func (w *wireError) Temporary() bool { - return true -}