[proxy] Keep custom TCP listeners alive after mapping batches (#6415)
This commit is contained in:
@@ -1105,7 +1105,7 @@ func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp
|
|||||||
router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
|
router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
|
||||||
router.SetObserver(s.meter)
|
router.SetObserver(s.meter)
|
||||||
router.SetAccessLogger(s.accessLog)
|
router.SetAccessLogger(s.accessLog)
|
||||||
portCtx, cancel := context.WithCancel(ctx)
|
portCtx, cancel := context.WithCancel(s.portRouterContext(ctx))
|
||||||
|
|
||||||
s.portRouters[port] = &portRouter{
|
s.portRouters[port] = &portRouter{
|
||||||
router: router,
|
router: router,
|
||||||
@@ -1121,10 +1121,26 @@ func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s.Logger.Debugf("started per-port router on %s", listenAddr)
|
s.Logger.WithFields(log.Fields{
|
||||||
|
"port": port,
|
||||||
|
"listen_addr": listenAddr,
|
||||||
|
"bound_addr": ln.Addr().String(),
|
||||||
|
"proxy_protocol": s.ProxyProtocol,
|
||||||
|
}).Info("custom TCP listener started")
|
||||||
return router, nil
|
return router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// portRouterContext returns the server-lifetime context for custom TCP
|
||||||
|
// listeners. Mapping-batch contexts are cancelled after a batch is applied; a
|
||||||
|
// per-port listener must outlive that batch and only stop on service removal or
|
||||||
|
// server shutdown.
|
||||||
|
func (s *Server) portRouterContext(ctx context.Context) context.Context {
|
||||||
|
if s.ctx != nil {
|
||||||
|
return s.ctx
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupPortIfEmpty tears down a per-port router if it has no remaining
|
// cleanupPortIfEmpty tears down a per-port router if it has no remaining
|
||||||
// routes or fallback. The main port is never cleaned up. Active relay
|
// routes or fallback. The main port is never cleaned up. Active relay
|
||||||
// connections are drained before the listener is closed.
|
// connections are drained before the listener is closed.
|
||||||
@@ -1718,6 +1734,13 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
|||||||
|
|
||||||
s.meter.L4ServiceAdded(types.ServiceModeTCP)
|
s.meter.L4ServiceAdded(types.ServiceModeTCP)
|
||||||
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
|
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
|
||||||
|
|
||||||
|
s.Logger.WithFields(log.Fields{
|
||||||
|
"domain": mapping.GetDomain(),
|
||||||
|
"target": targetAddr,
|
||||||
|
"port": port,
|
||||||
|
"service": svcID,
|
||||||
|
}).Info("TCP mapping added")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,14 +3,20 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -202,3 +208,117 @@ func TestRedactMappingForLog_HandlesEmptyOrNilFields(t *testing.T) {
|
|||||||
assert.Nil(t, redacted.Auth, "nil Auth must remain nil")
|
assert.Nil(t, redacted.Auth, "nil Auth must remain nil")
|
||||||
assert.Empty(t, redacted.Path, "empty Path must remain empty")
|
assert.Empty(t, redacted.Path, "empty Path must remain empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type statusUpdateOnlyClient struct {
|
||||||
|
proto.ProxyServiceClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statusUpdateOnlyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||||
|
return &proto.SendStatusUpdateResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupTCPMappingBindsCustomListenPort(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
port := uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // test port allocated by the OS
|
||||||
|
require.NoError(t, ln.Close())
|
||||||
|
|
||||||
|
meter, err := proxymetrics.New(context.Background(), noop.Meter{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srv := &Server{
|
||||||
|
Logger: quietLifecycleLogger(),
|
||||||
|
mgmtClient: statusUpdateOnlyClient{},
|
||||||
|
meter: meter,
|
||||||
|
mainPort: 8443,
|
||||||
|
portRouters: make(map[uint16]*portRouter),
|
||||||
|
svcPorts: make(map[types.ServiceID][]uint16),
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
srv.portMu.Lock()
|
||||||
|
for p, pr := range srv.portRouters {
|
||||||
|
pr.cancel()
|
||||||
|
require.NoError(t, pr.listener.Close())
|
||||||
|
delete(srv.portRouters, p)
|
||||||
|
}
|
||||||
|
srv.portMu.Unlock()
|
||||||
|
srv.portRouterWg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
mapping := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
|
Id: "svc-tcp",
|
||||||
|
AccountId: "acct-1",
|
||||||
|
Domain: "ssh.example.com",
|
||||||
|
Mode: "tcp",
|
||||||
|
ListenPort: int32(port),
|
||||||
|
Path: []*proto.PathMapping{
|
||||||
|
{Target: "10.0.0.5:22"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, srv.setupTCPMapping(context.Background(), mapping))
|
||||||
|
|
||||||
|
srv.portMu.RLock()
|
||||||
|
pr := srv.portRouters[port]
|
||||||
|
ports := append([]uint16(nil), srv.svcPorts[types.ServiceID("svc-tcp")]...)
|
||||||
|
srv.portMu.RUnlock()
|
||||||
|
|
||||||
|
require.NotNil(t, pr, "custom TCP mapping must create a per-port router")
|
||||||
|
assert.Equal(t, []uint16{port}, ports, "service must track the custom listen port for cleanup")
|
||||||
|
|
||||||
|
second, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||||
|
if err == nil {
|
||||||
|
_ = second.Close()
|
||||||
|
}
|
||||||
|
require.Error(t, err, "custom TCP listen port must be bound after setup")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomTCPPortRouterOutlivesMappingBatchContext(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
port := uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // test port allocated by the OS
|
||||||
|
require.NoError(t, ln.Close())
|
||||||
|
|
||||||
|
meter, err := proxymetrics.New(context.Background(), noop.Meter{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srvCtx, srvCancel := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(srvCancel)
|
||||||
|
|
||||||
|
srv := &Server{
|
||||||
|
ctx: srvCtx,
|
||||||
|
Logger: quietLifecycleLogger(),
|
||||||
|
meter: meter,
|
||||||
|
mainPort: 8443,
|
||||||
|
portRouters: make(map[uint16]*portRouter),
|
||||||
|
svcPorts: make(map[types.ServiceID][]uint16),
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
srv.portMu.Lock()
|
||||||
|
for p, pr := range srv.portRouters {
|
||||||
|
pr.cancel()
|
||||||
|
if err := pr.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
delete(srv.portRouters, p)
|
||||||
|
}
|
||||||
|
srv.portMu.Unlock()
|
||||||
|
srv.portRouterWg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
batchCtx, cancelBatch := context.WithCancel(context.Background())
|
||||||
|
_, err = srv.getOrCreatePortRouter(batchCtx, port)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cancelBatch()
|
||||||
|
|
||||||
|
assert.Never(t, func() bool {
|
||||||
|
second, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||||
|
if err == nil {
|
||||||
|
_ = second.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}, 200*time.Millisecond, 10*time.Millisecond, "custom TCP listener must outlive mapping-batch context cancellation")
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,6 +81,95 @@ func TestIntegration_SyncMappings_HappyPath(t *testing.T) {
|
|||||||
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
|
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIntegration_SyncMappings_CustomTCPMappingDeliveredWithCapabilities(t *testing.T) {
|
||||||
|
setup := setupIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
tcpSvc := &service.Service{
|
||||||
|
ID: "tcp-custom",
|
||||||
|
AccountID: "test-account-1",
|
||||||
|
Name: "Custom TCP",
|
||||||
|
Domain: "ssh.test.proxy.io",
|
||||||
|
ProxyCluster: "test.proxy.io",
|
||||||
|
Mode: "tcp",
|
||||||
|
ListenPort: 10001,
|
||||||
|
Enabled: true,
|
||||||
|
Targets: []*service.Target{{
|
||||||
|
Host: "10.0.0.5",
|
||||||
|
Port: 22,
|
||||||
|
Protocol: "tcp",
|
||||||
|
TargetId: "peer-ssh",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
require.NoError(t, setup.store.CreateService(ctx, tcpSvc))
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
receiveSnapshot := func(proxyID string, caps *proto.ProxyCapabilities) map[string]*proto.ProxyMapping {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
streamCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stream, err := client.SyncMappings(streamCtx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = stream.Send(&proto.SyncMappingsRequest{
|
||||||
|
Msg: &proto.SyncMappingsRequest_Init{
|
||||||
|
Init: &proto.SyncMappingsInit{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: "test.proxy.io",
|
||||||
|
Capabilities: caps,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||||
|
for {
|
||||||
|
msg, err := stream.Recv()
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, m := range msg.GetMapping() {
|
||||||
|
mappingsByID[m.GetId()] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
err = stream.Send(&proto.SyncMappingsRequest{
|
||||||
|
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mappingsByID
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyMappings := receiveSnapshot("sync-proxy-no-capabilities", nil)
|
||||||
|
assert.NotContains(t, legacyMappings, "tcp-custom",
|
||||||
|
"legacy proxies that do not report capabilities must not receive TCP custom-port mappings")
|
||||||
|
|
||||||
|
supportsCustomPorts := true
|
||||||
|
modernMappings := receiveSnapshot("sync-proxy-custom-ports", &proto.ProxyCapabilities{
|
||||||
|
SupportsCustomPorts: &supportsCustomPorts,
|
||||||
|
})
|
||||||
|
|
||||||
|
tcpMapping := modernMappings["tcp-custom"]
|
||||||
|
require.NotNil(t, tcpMapping, "capability-aware proxy must receive TCP custom-port mapping")
|
||||||
|
assert.Equal(t, "tcp", tcpMapping.GetMode())
|
||||||
|
assert.Equal(t, int32(10001), tcpMapping.GetListenPort())
|
||||||
|
require.Len(t, tcpMapping.GetPath(), 1)
|
||||||
|
assert.Equal(t, "10.0.0.5:22", tcpMapping.GetPath()[0].GetTarget())
|
||||||
|
assert.NotEmpty(t, tcpMapping.GetAuthToken(), "snapshot mapping must include per-proxy auth token")
|
||||||
|
}
|
||||||
|
|
||||||
func TestIntegration_SyncMappings_BackPressure(t *testing.T) {
|
func TestIntegration_SyncMappings_BackPressure(t *testing.T) {
|
||||||
setup := setupIntegrationTest(t)
|
setup := setupIntegrationTest(t)
|
||||||
defer setup.cleanup()
|
defer setup.cleanup()
|
||||||
|
|||||||
Reference in New Issue
Block a user