diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index 117ed4c2..210a5869 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -34,6 +34,7 @@ type StackDevice struct { mtu uint32 events chan tun.Event outbound chan *stack.PacketBuffer + done chan struct{} dispatcher stack.NetworkDispatcher addr4 tcpip.Address addr6 tcpip.Address @@ -50,6 +51,7 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er mtu: mtu, events: make(chan tun.Event, 1), outbound: make(chan *stack.PacketBuffer, 256), + done: make(chan struct{}), } err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) if err != nil { @@ -140,16 +142,20 @@ func (w *StackDevice) File() *os.File { } func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { - packetBuffer, ok := <-w.outbound - if !ok { + select { + case packetBuffer, ok := <-w.outbound: + if !ok { + return 0, os.ErrClosed + } + defer packetBuffer.DecRef() + p = p[offset:] + for _, slice := range packetBuffer.AsSlices() { + n += copy(p[n:], slice) + } + return + case <-w.done: return 0, os.ErrClosed } - defer packetBuffer.DecRef() - p = p[offset:] - for _, slice := range packetBuffer.AsSlices() { - n += copy(p[n:], slice) - } - return } func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { @@ -201,7 +207,7 @@ func (w *StackDevice) Close() error { endpoint.Abort() } w.stack.Wait() - close(w.outbound) + close(w.done) return nil } @@ -246,7 +252,11 @@ func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) { func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { for _, packetBuffer := range list.AsSlice() { packetBuffer.IncRef() - ep.outbound <- packetBuffer + select { + case <-ep.done: + return 0, &tcpip.ErrClosedForSend{} + case ep.outbound <- packetBuffer: + } } return list.Len(), nil }