fix(transport): correctly release UDS locker file (#2305)

* fix(transport): correctly release UDS locker file

* use callback function to do some jobs after create listener
This commit is contained in:
A1lo 2023-08-26 16:23:54 +08:00 committed by yuhan6665
parent 2d5475f428
commit 10d6b06578
5 changed files with 49 additions and 49 deletions

View file

@ -23,7 +23,6 @@ type Listener struct {
handler internet.ConnHandler handler internet.ConnHandler
local net.Addr local net.Addr
config *Config config *Config
locker *internet.FileLocker // for unix domain socket
s *grpc.Server s *grpc.Server
} }
@ -110,10 +109,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i
newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx)) newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
return return
} }
locker := ctx.Value(address.Domain())
if locker != nil {
listener.locker = locker.(*internet.FileLocker)
}
} else { // tcp } else { // tcp
streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),

View file

@ -27,7 +27,6 @@ type Listener struct {
handler internet.ConnHandler handler internet.ConnHandler
local net.Addr local net.Addr
config *Config config *Config
locker *internet.FileLocker // for unix domain socket
} }
func (l *Listener) Addr() net.Addr { func (l *Listener) Addr() net.Addr {
@ -35,9 +34,6 @@ func (l *Listener) Addr() net.Addr {
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {
if l.locker != nil {
l.locker.Release()
}
return l.server.Close() return l.server.Close()
} }
@ -180,10 +176,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx)) newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
return return
} }
locker := ctx.Value(address.Domain())
if locker != nil {
listener.locker = locker.(*internet.FileLocker)
}
} else { // tcp } else { // tcp
streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),

View file

@ -21,6 +21,19 @@ type DefaultListener struct {
controllers []control.Func controllers []control.Func
} }
type combinedListener struct {
net.Listener
locker *FileLocker // for unix domain socket
}
func (cl *combinedListener) Close() error {
if cl.locker != nil {
cl.locker.Release()
cl.locker = nil
}
return cl.Listener.Close()
}
func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error { func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
@ -44,6 +57,10 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) { func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
var lc net.ListenConfig var lc net.ListenConfig
var network, address string var network, address string
// callback is called after the Listen function returns
callback := func(l net.Listener, err error) (net.Listener, error) {
return l, err
}
switch addr := addr.(type) { switch addr := addr.(type) {
case *net.TCPAddr: case *net.TCPAddr:
@ -58,23 +75,6 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
network = addr.Network() network = addr.Network()
address = addr.Name address = addr.Name
if s := strings.Split(address, ","); len(s) == 2 {
address = s[0]
perm, perr := strconv.ParseUint(s[1], 8, 32)
if perr != nil {
return nil, newError("failed to parse permission: " + s[1]).Base(perr)
}
defer func(file string, permission os.FileMode) {
if err == nil {
cerr := os.Chmod(address, permission)
if cerr != nil {
err = newError("failed to set permission for " + file).Base(cerr)
}
}
}(address, os.FileMode(perm))
}
if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' { if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' {
// linux abstract unix domain socket is lockfree // linux abstract unix domain socket is lockfree
if len(address) > 1 && address[1] == '@' { if len(address) > 1 && address[1] == '@' {
@ -84,19 +84,48 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
address = string(fullAddr) address = string(fullAddr)
} }
} else { } else {
// split permission from address
var filePerm *os.FileMode
if s := strings.Split(address, ","); len(s) == 2 {
address = s[0]
perm, perr := strconv.ParseUint(s[1], 8, 32)
if perr != nil {
return nil, newError("failed to parse permission: " + s[1]).Base(perr)
}
mode := os.FileMode(perm)
filePerm = &mode
}
// normal unix domain socket needs lock // normal unix domain socket needs lock
locker := &FileLocker{ locker := &FileLocker{
path: address + ".lock", path: address + ".lock",
} }
err := locker.Acquire() if err := locker.Acquire(); err != nil {
if err != nil {
return nil, err return nil, err
} }
ctx = context.WithValue(ctx, address, locker)
// set callback to combine listener and set permission
callback = func(l net.Listener, err error) (net.Listener, error) {
if err != nil {
locker.Release()
return l, err
}
l = &combinedListener{Listener: l, locker: locker}
if filePerm == nil {
return l, nil
}
err = os.Chmod(address, *filePerm)
if err != nil {
l.Close()
return nil, newError("failed to set permission for " + address).Base(err)
}
return l, nil
}
} }
} }
l, err = lc.Listen(ctx, network, address) l, err = lc.Listen(ctx, network, address)
l, err = callback(l, err)
if sockopt != nil && sockopt.AcceptProxyProtocol { if sockopt != nil && sockopt.AcceptProxyProtocol {
policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil } policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
l = &proxyproto.Listener{Listener: l, Policy: policyFunc} l = &proxyproto.Listener{Listener: l, Policy: policyFunc}

View file

@ -24,7 +24,6 @@ type Listener struct {
authConfig internet.ConnectionAuthenticator authConfig internet.ConnectionAuthenticator
config *Config config *Config
addConn internet.ConnHandler addConn internet.ConnHandler
locker *internet.FileLocker // for unix domain socket
} }
// ListenTCP creates a new Listener based on configurations. // ListenTCP creates a new Listener based on configurations.
@ -51,10 +50,6 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe
return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err) return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
} }
newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx)) newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
locker := ctx.Value(address.Domain())
if locker != nil {
l.locker = locker.(*internet.FileLocker)
}
} else { } else {
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
@ -133,9 +128,6 @@ func (v *Listener) Addr() net.Addr {
// Close implements internet.Listener.Close. // Close implements internet.Listener.Close.
func (v *Listener) Close() error { func (v *Listener) Close() error {
if v.locker != nil {
v.locker.Release()
}
return v.listener.Close() return v.listener.Close()
} }

View file

@ -75,7 +75,6 @@ type Listener struct {
listener net.Listener listener net.Listener
config *Config config *Config
addConn internet.ConnHandler addConn internet.ConnHandler
locker *internet.FileLocker // for unix domain socket
} }
func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
@ -101,10 +100,6 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
} }
newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
locker := ctx.Value(address.Domain())
if locker != nil {
l.locker = locker.(*internet.FileLocker)
}
} else { // tcp } else { // tcp
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
@ -153,9 +148,6 @@ func (ln *Listener) Addr() net.Addr {
// Close implements net.Listener.Close(). // Close implements net.Listener.Close().
func (ln *Listener) Close() error { func (ln *Listener) Close() error {
if ln.locker != nil {
ln.locker.Release()
}
return ln.listener.Close() return ln.listener.Close()
} }