[client] Use fwmark-aware route lookup for raw socket UDP checksum source (#6070)

* Use fwmark-aware route lookup for raw socket UDP checksum source

* Guard nil raw socket in sharedsock WriteTo
This commit is contained in:
Viktor Liu
2026-05-05 23:18:22 +09:00
committed by GitHub
parent cd8e71002f
commit 31395f8bd2

View File

@@ -10,15 +10,13 @@ import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/mdlayher/socket"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
@@ -37,8 +35,6 @@ type SharedSocket struct {
conn6 *socket.Conn
port int
mtu uint16
routerMux sync.RWMutex
router routing.Router
packetDemux chan rcvdPacket
cancel context.CancelFunc
}
@@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
}
}()
rawSock.router, err = netroute.New()
if err != nil {
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
}
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil {
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
@@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
go rawSock.read(rawSock.conn6.Recvfrom)
}
go rawSock.updateRouter()
return rawSock, nil
}
// updateRouter updates the listener routing table client
// this is needed to avoid outdated information across different client networks
func (s *SharedSocket) updateRouter() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
router, err := netroute.New()
if err != nil {
log.Errorf("Failed to create and update packet router for stunListener: %s", err)
continue
}
s.routerMux.Lock()
s.router = router
s.routerMux.Unlock()
// resolveSrc returns the source IP the kernel will pick for a packet sent to
// dst by these raw sockets, mirroring the fwmark the kernel will see on send.
func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) {
opts := &netlink.RouteGetOptions{}
if nbnet.AdvancedRouting() {
opts.Mark = nbnet.ControlPlaneMark
}
routes, err := netlink.RouteGetWithOptions(dst, opts)
if err != nil {
return nil, fmt.Errorf("route get %s: %w", dst, err)
}
for _, r := range routes {
if r.Src != nil {
return r.Src, nil
}
}
return nil, fmt.Errorf("no source IP for %s", dst)
}
// LocalAddr returns the local address, preferring IPv4 for backward compatibility.
@@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
DstPort: layers.UDPPort(rUDPAddr.Port),
}
s.routerMux.RLock()
defer s.routerMux.RUnlock()
_, _, src, err := s.router.Route(rUDPAddr.IP)
src, err := s.resolveSrc(rUDPAddr.IP)
if err != nil {
return 0, fmt.Errorf("got an error while checking route, err: %w", err)
return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err)
}
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
if conn == nil {
return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP)
}
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)