[proxy] Keep custom TCP listeners alive after mapping batches (#6415)
This commit is contained in:
@@ -1105,7 +1105,7 @@ func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp
|
||||
router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
|
||||
router.SetObserver(s.meter)
|
||||
router.SetAccessLogger(s.accessLog)
|
||||
portCtx, cancel := context.WithCancel(ctx)
|
||||
portCtx, cancel := context.WithCancel(s.portRouterContext(ctx))
|
||||
|
||||
s.portRouters[port] = &portRouter{
|
||||
router: router,
|
||||
@@ -1121,10 +1121,26 @@ func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp
|
||||
}
|
||||
}()
|
||||
|
||||
s.Logger.Debugf("started per-port router on %s", listenAddr)
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"port": port,
|
||||
"listen_addr": listenAddr,
|
||||
"bound_addr": ln.Addr().String(),
|
||||
"proxy_protocol": s.ProxyProtocol,
|
||||
}).Info("custom TCP listener started")
|
||||
return router, nil
|
||||
}
|
||||
|
||||
// portRouterContext returns the server-lifetime context for custom TCP
|
||||
// listeners. Mapping-batch contexts are cancelled after a batch is applied; a
|
||||
// per-port listener must outlive that batch and only stop on service removal or
|
||||
// server shutdown.
|
||||
func (s *Server) portRouterContext(ctx context.Context) context.Context {
|
||||
if s.ctx != nil {
|
||||
return s.ctx
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// cleanupPortIfEmpty tears down a per-port router if it has no remaining
|
||||
// routes or fallback. The main port is never cleaned up. Active relay
|
||||
// connections are drained before the listener is closed.
|
||||
@@ -1718,6 +1734,13 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
|
||||
s.meter.L4ServiceAdded(types.ServiceModeTCP)
|
||||
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
|
||||
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"domain": mapping.GetDomain(),
|
||||
"target": targetAddr,
|
||||
"port": port,
|
||||
"service": svcID,
|
||||
}).Info("TCP mapping added")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,14 +3,20 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -202,3 +208,117 @@ func TestRedactMappingForLog_HandlesEmptyOrNilFields(t *testing.T) {
|
||||
assert.Nil(t, redacted.Auth, "nil Auth must remain nil")
|
||||
assert.Empty(t, redacted.Path, "empty Path must remain empty")
|
||||
}
|
||||
|
||||
type statusUpdateOnlyClient struct {
|
||||
proto.ProxyServiceClient
|
||||
}
|
||||
|
||||
func (statusUpdateOnlyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
func TestSetupTCPMappingBindsCustomListenPort(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // test port allocated by the OS
|
||||
require.NoError(t, ln.Close())
|
||||
|
||||
meter, err := proxymetrics.New(context.Background(), noop.Meter{})
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := &Server{
|
||||
Logger: quietLifecycleLogger(),
|
||||
mgmtClient: statusUpdateOnlyClient{},
|
||||
meter: meter,
|
||||
mainPort: 8443,
|
||||
portRouters: make(map[uint16]*portRouter),
|
||||
svcPorts: make(map[types.ServiceID][]uint16),
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
srv.portMu.Lock()
|
||||
for p, pr := range srv.portRouters {
|
||||
pr.cancel()
|
||||
require.NoError(t, pr.listener.Close())
|
||||
delete(srv.portRouters, p)
|
||||
}
|
||||
srv.portMu.Unlock()
|
||||
srv.portRouterWg.Wait()
|
||||
})
|
||||
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-tcp",
|
||||
AccountId: "acct-1",
|
||||
Domain: "ssh.example.com",
|
||||
Mode: "tcp",
|
||||
ListenPort: int32(port),
|
||||
Path: []*proto.PathMapping{
|
||||
{Target: "10.0.0.5:22"},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, srv.setupTCPMapping(context.Background(), mapping))
|
||||
|
||||
srv.portMu.RLock()
|
||||
pr := srv.portRouters[port]
|
||||
ports := append([]uint16(nil), srv.svcPorts[types.ServiceID("svc-tcp")]...)
|
||||
srv.portMu.RUnlock()
|
||||
|
||||
require.NotNil(t, pr, "custom TCP mapping must create a per-port router")
|
||||
assert.Equal(t, []uint16{port}, ports, "service must track the custom listen port for cleanup")
|
||||
|
||||
second, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err == nil {
|
||||
_ = second.Close()
|
||||
}
|
||||
require.Error(t, err, "custom TCP listen port must be bound after setup")
|
||||
}
|
||||
|
||||
func TestCustomTCPPortRouterOutlivesMappingBatchContext(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // test port allocated by the OS
|
||||
require.NoError(t, ln.Close())
|
||||
|
||||
meter, err := proxymetrics.New(context.Background(), noop.Meter{})
|
||||
require.NoError(t, err)
|
||||
|
||||
srvCtx, srvCancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(srvCancel)
|
||||
|
||||
srv := &Server{
|
||||
ctx: srvCtx,
|
||||
Logger: quietLifecycleLogger(),
|
||||
meter: meter,
|
||||
mainPort: 8443,
|
||||
portRouters: make(map[uint16]*portRouter),
|
||||
svcPorts: make(map[types.ServiceID][]uint16),
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
srv.portMu.Lock()
|
||||
for p, pr := range srv.portRouters {
|
||||
pr.cancel()
|
||||
if err := pr.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
delete(srv.portRouters, p)
|
||||
}
|
||||
srv.portMu.Unlock()
|
||||
srv.portRouterWg.Wait()
|
||||
})
|
||||
|
||||
batchCtx, cancelBatch := context.WithCancel(context.Background())
|
||||
_, err = srv.getOrCreatePortRouter(batchCtx, port)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancelBatch()
|
||||
|
||||
assert.Never(t, func() bool {
|
||||
second, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err == nil {
|
||||
_ = second.Close()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}, 200*time.Millisecond, 10*time.Millisecond, "custom TCP listener must outlive mapping-batch context cancellation")
|
||||
}
|
||||
|
||||
@@ -81,6 +81,95 @@ func TestIntegration_SyncMappings_HappyPath(t *testing.T) {
|
||||
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
|
||||
}
|
||||
|
||||
func TestIntegration_SyncMappings_CustomTCPMappingDeliveredWithCapabilities(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
tcpSvc := &service.Service{
|
||||
ID: "tcp-custom",
|
||||
AccountID: "test-account-1",
|
||||
Name: "Custom TCP",
|
||||
Domain: "ssh.test.proxy.io",
|
||||
ProxyCluster: "test.proxy.io",
|
||||
Mode: "tcp",
|
||||
ListenPort: 10001,
|
||||
Enabled: true,
|
||||
Targets: []*service.Target{{
|
||||
Host: "10.0.0.5",
|
||||
Port: 22,
|
||||
Protocol: "tcp",
|
||||
TargetId: "peer-ssh",
|
||||
TargetType: "peer",
|
||||
Enabled: true,
|
||||
}},
|
||||
}
|
||||
require.NoError(t, setup.store.CreateService(ctx, tcpSvc))
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
receiveSnapshot := func(proxyID string, caps *proto.ProxyCapabilities) map[string]*proto.ProxyMapping {
|
||||
t.Helper()
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.SyncMappings(streamCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Init{
|
||||
Init: &proto.SyncMappingsInit{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: "test.proxy.io",
|
||||
Capabilities: caps,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
for _, m := range msg.GetMapping() {
|
||||
mappingsByID[m.GetId()] = m
|
||||
}
|
||||
|
||||
err = stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappingsByID
|
||||
}
|
||||
|
||||
legacyMappings := receiveSnapshot("sync-proxy-no-capabilities", nil)
|
||||
assert.NotContains(t, legacyMappings, "tcp-custom",
|
||||
"legacy proxies that do not report capabilities must not receive TCP custom-port mappings")
|
||||
|
||||
supportsCustomPorts := true
|
||||
modernMappings := receiveSnapshot("sync-proxy-custom-ports", &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &supportsCustomPorts,
|
||||
})
|
||||
|
||||
tcpMapping := modernMappings["tcp-custom"]
|
||||
require.NotNil(t, tcpMapping, "capability-aware proxy must receive TCP custom-port mapping")
|
||||
assert.Equal(t, "tcp", tcpMapping.GetMode())
|
||||
assert.Equal(t, int32(10001), tcpMapping.GetListenPort())
|
||||
require.Len(t, tcpMapping.GetPath(), 1)
|
||||
assert.Equal(t, "10.0.0.5:22", tcpMapping.GetPath()[0].GetTarget())
|
||||
assert.NotEmpty(t, tcpMapping.GetAuthToken(), "snapshot mapping must include per-proxy auth token")
|
||||
}
|
||||
|
||||
func TestIntegration_SyncMappings_BackPressure(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user