From 10d6b065784efd3f33a02d6d5ad2a1fa162ff346 Mon Sep 17 00:00:00 2001 From: A1lo Date: Sat, 26 Aug 2023 16:23:54 +0800 Subject: [PATCH] fix(transport): correctly release UDS locker file (#2305) * fix(transport): correctly release UDS locker file * use callback function to do some jobs after create listener --- transport/internet/grpc/hub.go | 5 -- transport/internet/http/hub.go | 8 ---- transport/internet/system_listener.go | 69 +++++++++++++++++++-------- transport/internet/tcp/hub.go | 8 ---- transport/internet/websocket/hub.go | 8 ---- 5 files changed, 49 insertions(+), 49 deletions(-) diff --git a/transport/internet/grpc/hub.go b/transport/internet/grpc/hub.go index d3dd6da5..e55f6f77 100644 --- a/transport/internet/grpc/hub.go +++ b/transport/internet/grpc/hub.go @@ -23,7 +23,6 @@ type Listener struct { handler internet.ConnHandler local net.Addr config *Config - locker *internet.FileLocker // for unix domain socket s *grpc.Server } @@ -110,10 +109,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx)) return } - locker := ctx.Value(address.Domain()) - if locker != nil { - listener.locker = locker.(*internet.FileLocker) - } } else { // tcp streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), diff --git a/transport/internet/http/hub.go b/transport/internet/http/hub.go index 551f897e..f0260460 100644 --- a/transport/internet/http/hub.go +++ b/transport/internet/http/hub.go @@ -27,7 +27,6 @@ type Listener struct { handler internet.ConnHandler local net.Addr config *Config - locker *internet.FileLocker // for unix domain socket } func (l *Listener) Addr() net.Addr { @@ -35,9 +34,6 @@ func (l *Listener) Addr() net.Addr { } func (l *Listener) Close() error { - if l.locker != nil { - l.locker.Release() - } return l.server.Close() } @@ -180,10 +176,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx)) return } - locker := ctx.Value(address.Domain()) - if locker != nil { - listener.locker = locker.(*internet.FileLocker) - } } else { // tcp streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index 60979062..1d635897 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -21,6 +21,19 @@ type DefaultListener struct { controllers []control.Func } +type combinedListener struct { + net.Listener + locker *FileLocker // for unix domain socket +} + +func (cl *combinedListener) Close() error { + if cl.locker != nil { + cl.locker.Release() + cl.locker = nil + } + return cl.Listener.Close() +} + func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { @@ -44,6 +57,10 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) { var lc net.ListenConfig var network, address string + // callback is called after the Listen function returns + callback := func(l net.Listener, err error) (net.Listener, error) { + return l, err + } switch addr := addr.(type) { case *net.TCPAddr: @@ -58,23 +75,6 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S network = addr.Network() address = addr.Name - if s := strings.Split(address, ","); len(s) == 2 { - address = s[0] - perm, perr := strconv.ParseUint(s[1], 8, 32) - if perr != nil { - return nil, newError("failed to parse permission: " + s[1]).Base(perr) - } - - defer func(file string, permission os.FileMode) { - if err == nil { - cerr := os.Chmod(address, permission) - if cerr != nil { - err = newError("failed to set permission for " + file).Base(cerr) - } - } - }(address, os.FileMode(perm)) - } - if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' { // linux abstract unix domain socket is lockfree if len(address) > 1 && address[1] == '@' { @@ -84,19 +84,48 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S address = string(fullAddr) } } else { + // split permission from address + var filePerm *os.FileMode + if s := strings.Split(address, ","); len(s) == 2 { + address = s[0] + perm, perr := strconv.ParseUint(s[1], 8, 32) + if perr != nil { + return nil, newError("failed to parse permission: " + s[1]).Base(perr) + } + + mode := os.FileMode(perm) + filePerm = &mode + } // normal unix domain socket needs lock locker := &FileLocker{ path: address + ".lock", } - err := locker.Acquire() - if err != nil { + if err := locker.Acquire(); err != nil { return nil, err } - ctx = context.WithValue(ctx, address, locker) + + // set callback to combine listener and set permission + callback = func(l net.Listener, err error) (net.Listener, error) { + if err != nil { + locker.Release() + return l, err + } + l = &combinedListener{Listener: l, locker: locker} + if filePerm == nil { + return l, nil + } + err = os.Chmod(address, *filePerm) + if err != nil { + l.Close() + return nil, newError("failed to set permission for " + address).Base(err) + } + return l, nil + } } } l, err = lc.Listen(ctx, network, address) + l, err = callback(l, err) if sockopt != nil && sockopt.AcceptProxyProtocol { policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil } l = &proxyproto.Listener{Listener: l, Policy: policyFunc} diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 392228c6..d4b4f8b5 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -24,7 +24,6 @@ type Listener struct { authConfig internet.ConnectionAuthenticator config *Config addConn internet.ConnHandler - locker *internet.FileLocker // for unix domain socket } // ListenTCP creates a new Listener based on configurations. @@ -51,10 +50,6 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err) } newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx)) - locker := ctx.Value(address.Domain()) - if locker != nil { - l.locker = locker.(*internet.FileLocker) - } } else { listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), @@ -133,9 +128,6 @@ func (v *Listener) Addr() net.Addr { // Close implements internet.Listener.Close. func (v *Listener) Close() error { - if v.locker != nil { - v.locker.Release() - } return v.listener.Close() } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index c0cf3446..7951b1f4 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -75,7 +75,6 @@ type Listener struct { listener net.Listener config *Config addConn internet.ConnHandler - locker *internet.FileLocker // for unix domain socket } func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { @@ -101,10 +100,6 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) } newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) - locker := ctx.Value(address.Domain()) - if locker != nil { - l.locker = locker.(*internet.FileLocker) - } } else { // tcp listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), @@ -153,9 +148,6 @@ func (ln *Listener) Addr() net.Addr { // Close implements net.Listener.Close(). func (ln *Listener) Close() error { - if ln.locker != nil { - ln.locker.Release() - } return ln.listener.Close() }