[proxy] Keep custom TCP listeners alive after mapping batches (#6415)

This commit is contained in:
Lee Sang Hoon
2026-06-15 19:21:24 +09:00
committed by GitHub
parent cd777395f2
commit 60067619a1
3 changed files with 234 additions and 2 deletions

View File

@@ -1105,7 +1105,7 @@ func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp
router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc) router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
router.SetObserver(s.meter) router.SetObserver(s.meter)
router.SetAccessLogger(s.accessLog) router.SetAccessLogger(s.accessLog)
portCtx, cancel := context.WithCancel(ctx) portCtx, cancel := context.WithCancel(s.portRouterContext(ctx))
s.portRouters[port] = &portRouter{ s.portRouters[port] = &portRouter{
router: router, router: router,
@@ -1121,10 +1121,26 @@ func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp
} }
}() }()
s.Logger.Debugf("started per-port router on %s", listenAddr) s.Logger.WithFields(log.Fields{
"port": port,
"listen_addr": listenAddr,
"bound_addr": ln.Addr().String(),
"proxy_protocol": s.ProxyProtocol,
}).Info("custom TCP listener started")
return router, nil return router, nil
} }
// portRouterContext picks the parent context for a custom TCP listener.
//
// Mapping-batch contexts are cancelled once the batch has been applied, but a
// per-port listener has to keep running until its service is removed or the
// server shuts down. When the server holds its own context (s.ctx), that
// context is returned so the listener is not tied to the caller's ctx. If
// s.ctx is nil, the supplied ctx is returned unchanged.
func (s *Server) portRouterContext(ctx context.Context) context.Context {
	if s.ctx == nil {
		return ctx
	}
	return s.ctx
}
// cleanupPortIfEmpty tears down a per-port router if it has no remaining // cleanupPortIfEmpty tears down a per-port router if it has no remaining
// routes or fallback. The main port is never cleaned up. Active relay // routes or fallback. The main port is never cleaned up. Active relay
// connections are drained before the listener is closed. // connections are drained before the listener is closed.
@@ -1718,6 +1734,13 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
s.meter.L4ServiceAdded(types.ServiceModeTCP) s.meter.L4ServiceAdded(types.ServiceModeTCP)
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
s.Logger.WithFields(log.Fields{
"domain": mapping.GetDomain(),
"target": targetAddr,
"port": port,
"service": svcID,
}).Info("TCP mapping added")
return nil return nil
} }

View File

@@ -3,14 +3,20 @@ package proxy
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
"net"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"google.golang.org/grpc"
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -202,3 +208,117 @@ func TestRedactMappingForLog_HandlesEmptyOrNilFields(t *testing.T) {
assert.Nil(t, redacted.Auth, "nil Auth must remain nil") assert.Nil(t, redacted.Auth, "nil Auth must remain nil")
assert.Empty(t, redacted.Path, "empty Path must remain empty") assert.Empty(t, redacted.Path, "empty Path must remain empty")
} }
// statusUpdateOnlyClient is a test double that embeds
// proto.ProxyServiceClient and overrides only SendStatusUpdate. The tests
// construct it as a zero-value literal, so the embedded client is left unset.
type statusUpdateOnlyClient struct {
	proto.ProxyServiceClient
}

// SendStatusUpdate ignores all of its arguments and always reports success
// with an empty response.
func (statusUpdateOnlyClient) SendStatusUpdate(
	_ context.Context,
	_ *proto.SendStatusUpdateRequest,
	_ ...grpc.CallOption,
) (*proto.SendStatusUpdateResponse, error) {
	resp := &proto.SendStatusUpdateResponse{}
	return resp, nil
}
// TestSetupTCPMappingBindsCustomListenPort verifies that setupTCPMapping, when
// given a TCP mapping with a custom ListenPort, registers a per-port router,
// records the port against the service ID, and leaves the port bound so that a
// second listener on the same port cannot be opened.
func TestSetupTCPMappingBindsCustomListenPort(t *testing.T) {
	// Ask the OS for a free port, then release it so the server can bind it.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // test port allocated by the OS
	require.NoError(t, ln.Close())

	meter, err := proxymetrics.New(context.Background(), noop.Meter{})
	require.NoError(t, err)

	srv := &Server{
		Logger:      quietLifecycleLogger(),
		mgmtClient:  statusUpdateOnlyClient{},
		meter:       meter,
		mainPort:    8443,
		portRouters: make(map[uint16]*portRouter),
		svcPorts:    make(map[types.ServiceID][]uint16),
	}

	t.Cleanup(func() {
		srv.portMu.Lock()
		for p, pr := range srv.portRouters {
			pr.cancel()
			// The listener may already be closed by the time we get here
			// (the sibling outlive-batch-context test tolerates the same
			// error), so net.ErrClosed is not a failure. Any other error is
			// reported with the non-fatal assert so that the mutex is still
			// released and the router goroutines are still awaited below.
			if err := pr.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
				assert.NoError(t, err, "closing listener for port %d", p)
			}
			delete(srv.portRouters, p)
		}
		srv.portMu.Unlock()
		srv.portRouterWg.Wait()
	})

	mapping := &proto.ProxyMapping{
		Type:       proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
		Id:         "svc-tcp",
		AccountId:  "acct-1",
		Domain:     "ssh.example.com",
		Mode:       "tcp",
		ListenPort: int32(port),
		Path: []*proto.PathMapping{
			{Target: "10.0.0.5:22"},
		},
	}
	require.NoError(t, srv.setupTCPMapping(context.Background(), mapping))

	// Snapshot the bookkeeping under the read lock; copy the slice so the
	// assertions below do not alias server state.
	srv.portMu.RLock()
	pr := srv.portRouters[port]
	ports := append([]uint16(nil), srv.svcPorts[types.ServiceID("svc-tcp")]...)
	srv.portMu.RUnlock()

	require.NotNil(t, pr, "custom TCP mapping must create a per-port router")
	assert.Equal(t, []uint16{port}, ports, "service must track the custom listen port for cleanup")

	// A second bind attempt must fail while the server holds the port. If it
	// unexpectedly succeeds, release the socket before failing the test.
	second, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err == nil {
		_ = second.Close()
	}
	require.Error(t, err, "custom TCP listen port must be bound after setup")
}
// TestCustomTCPPortRouterOutlivesMappingBatchContext checks that a per-port
// router created with a short-lived batch context keeps its port bound after
// that batch context is cancelled, while the server's own context is still
// live.
func TestCustomTCPPortRouterOutlivesMappingBatchContext(t *testing.T) {
	// Reserve a free port from the OS and release it for the server to use.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := uint16(probe.Addr().(*net.TCPAddr).Port) //nolint:gosec // test port allocated by the OS
	require.NoError(t, probe.Close())

	meter, err := proxymetrics.New(context.Background(), noop.Meter{})
	require.NoError(t, err)

	serverCtx, stopServer := context.WithCancel(context.Background())
	t.Cleanup(stopServer)

	srv := &Server{
		ctx:         serverCtx,
		Logger:      quietLifecycleLogger(),
		meter:       meter,
		mainPort:    8443,
		portRouters: make(map[uint16]*portRouter),
		svcPorts:    make(map[types.ServiceID][]uint16),
	}

	// Tear down every router that the test created and wait for its goroutine.
	t.Cleanup(func() {
		srv.portMu.Lock()
		for p, pr := range srv.portRouters {
			pr.cancel()
			closeErr := pr.listener.Close()
			if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
				require.NoError(t, closeErr)
			}
			delete(srv.portRouters, p)
		}
		srv.portMu.Unlock()
		srv.portRouterWg.Wait()
	})

	batchCtx, cancelBatch := context.WithCancel(context.Background())
	_, err = srv.getOrCreatePortRouter(batchCtx, port)
	require.NoError(t, err)
	cancelBatch()

	// portIsFree reports whether the port can be bound by someone else, which
	// would mean the server's listener has gone away.
	addr := fmt.Sprintf(":%d", port)
	portIsFree := func() bool {
		l, listenErr := net.Listen("tcp", addr)
		if listenErr != nil {
			return false
		}
		_ = l.Close()
		return true
	}

	assert.Never(t, portIsFree, 200*time.Millisecond, 10*time.Millisecond, "custom TCP listener must outlive mapping-batch context cancellation")
}

View File

@@ -81,6 +81,95 @@ func TestIntegration_SyncMappings_HappyPath(t *testing.T) {
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain()) assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
} }
// TestIntegration_SyncMappings_CustomTCPMappingDeliveredWithCapabilities stores
// a TCP service with a custom listen port (10001) and then opens two
// SyncMappings streams: one whose init message carries no capabilities and one
// that sets SupportsCustomPorts. It asserts that the first snapshot does not
// contain the service and that the second one does, with the expected mode,
// listen port, target and a non-empty auth token.
func TestIntegration_SyncMappings_CustomTCPMappingDeliveredWithCapabilities(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
ctx := context.Background()
// The service under test: mode "tcp" with an explicit ListenPort and a single
// enabled peer target at 10.0.0.5:22.
tcpSvc := &service.Service{
ID: "tcp-custom",
AccountID: "test-account-1",
Name: "Custom TCP",
Domain: "ssh.test.proxy.io",
ProxyCluster: "test.proxy.io",
Mode: "tcp",
ListenPort: 10001,
Enabled: true,
Targets: []*service.Target{{
Host: "10.0.0.5",
Port: 22,
Protocol: "tcp",
TargetId: "peer-ssh",
TargetType: "peer",
Enabled: true,
}},
}
require.NoError(t, setup.store.CreateService(ctx, tcpSvc))
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
// receiveSnapshot opens a SyncMappings stream (bounded by a 5s timeout),
// sends an init message for proxyID with the given capabilities, and collects
// every received mapping keyed by its ID until a message reports
// InitialSyncComplete.
receiveSnapshot := func(proxyID string, caps *proto.ProxyCapabilities) map[string]*proto.ProxyMapping {
t.Helper()
streamCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.SyncMappings(streamCtx)
require.NoError(t, err)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: proxyID,
Version: "test-v1",
Address: "test.proxy.io",
Capabilities: caps,
},
},
})
require.NoError(t, err)
mappingsByID := make(map[string]*proto.ProxyMapping)
for {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
// Every received message is acknowledged, including the final one: the ack
// is sent before the completion flag is checked.
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
})
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
return mappingsByID
}
// First pass: a proxy that reports no capabilities at all.
legacyMappings := receiveSnapshot("sync-proxy-no-capabilities", nil)
assert.NotContains(t, legacyMappings, "tcp-custom",
"legacy proxies that do not report capabilities must not receive TCP custom-port mappings")
// Second pass: a proxy that advertises custom-port support. The field takes a
// *bool, hence the addressable local.
supportsCustomPorts := true
modernMappings := receiveSnapshot("sync-proxy-custom-ports", &proto.ProxyCapabilities{
SupportsCustomPorts: &supportsCustomPorts,
})
tcpMapping := modernMappings["tcp-custom"]
require.NotNil(t, tcpMapping, "capability-aware proxy must receive TCP custom-port mapping")
assert.Equal(t, "tcp", tcpMapping.GetMode())
assert.Equal(t, int32(10001), tcpMapping.GetListenPort())
require.Len(t, tcpMapping.GetPath(), 1)
assert.Equal(t, "10.0.0.5:22", tcpMapping.GetPath()[0].GetTarget())
assert.NotEmpty(t, tcpMapping.GetAuthToken(), "snapshot mapping must include per-proxy auth token")
}
func TestIntegration_SyncMappings_BackPressure(t *testing.T) { func TestIntegration_SyncMappings_BackPressure(t *testing.T) {
setup := setupIntegrationTest(t) setup := setupIntegrationTest(t)
defer setup.cleanup() defer setup.cleanup()