Fix wireguard close

This commit is contained in:
世界 2022-11-06 10:20:23 +08:00
parent b2cd78d279
commit 0ad1bbea11
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -34,6 +34,7 @@ type StackDevice struct {
mtu uint32 mtu uint32
events chan tun.Event events chan tun.Event
outbound chan *stack.PacketBuffer outbound chan *stack.PacketBuffer
done chan struct{}
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
addr4 tcpip.Address addr4 tcpip.Address
addr6 tcpip.Address addr6 tcpip.Address
@ -50,6 +51,7 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
mtu: mtu, mtu: mtu,
events: make(chan tun.Event, 1), events: make(chan tun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256), outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
} }
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
if err != nil { if err != nil {
@ -140,16 +142,20 @@ func (w *StackDevice) File() *os.File {
} }
func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
packetBuffer, ok := <-w.outbound select {
if !ok { case packetBuffer, ok := <-w.outbound:
if !ok {
return 0, os.ErrClosed
}
defer packetBuffer.DecRef()
p = p[offset:]
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
}
return
case <-w.done:
return 0, os.ErrClosed return 0, os.ErrClosed
} }
defer packetBuffer.DecRef()
p = p[offset:]
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
}
return
} }
func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
@ -201,7 +207,7 @@ func (w *StackDevice) Close() error {
endpoint.Abort() endpoint.Abort()
} }
w.stack.Wait() w.stack.Wait()
close(w.outbound) close(w.done)
return nil return nil
} }
@ -246,7 +252,11 @@ func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() { for _, packetBuffer := range list.AsSlice() {
packetBuffer.IncRef() packetBuffer.IncRef()
ep.outbound <- packetBuffer select {
case <-ep.done:
return 0, &tcpip.ErrClosedForSend{}
case ep.outbound <- packetBuffer:
}
} }
return list.Len(), nil return list.Len(), nil
} }