[client] Use fwmark-aware route lookup for raw socket UDP checksum source (#6070)

* Use fwmark-aware route lookup for raw socket UDP checksum source

* Guard nil raw socket in sharedsock WriteTo
This commit is contained in:
Viktor Liu
2026-05-05 23:18:22 +09:00
committed by GitHub
parent cd8e71002f
commit 31395f8bd2

View File

@@ -10,15 +10,13 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/mdlayher/socket" "github.com/mdlayher/socket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -37,8 +35,6 @@ type SharedSocket struct {
conn6 *socket.Conn conn6 *socket.Conn
port int port int
mtu uint16 mtu uint16
routerMux sync.RWMutex
router routing.Router
packetDemux chan rcvdPacket packetDemux chan rcvdPacket
cancel context.CancelFunc cancel context.CancelFunc
} }
@@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
} }
}() }()
rawSock.router, err = netroute.New()
if err != nil {
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
}
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
@@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
go rawSock.read(rawSock.conn6.Recvfrom) go rawSock.read(rawSock.conn6.Recvfrom)
} }
go rawSock.updateRouter()
return rawSock, nil return rawSock, nil
} }
// updateRouter updates the listener routing table client // resolveSrc returns the source IP the kernel will pick for a packet sent to
// this is needed to avoid outdated information across different client networks // dst by these raw sockets, mirroring the fwmark the kernel will see on send.
func (s *SharedSocket) updateRouter() { func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) {
ticker := time.NewTicker(15 * time.Second) opts := &netlink.RouteGetOptions{}
defer ticker.Stop() if nbnet.AdvancedRouting() {
for { opts.Mark = nbnet.ControlPlaneMark
select { }
case <-s.ctx.Done(): routes, err := netlink.RouteGetWithOptions(dst, opts)
return if err != nil {
case <-ticker.C: return nil, fmt.Errorf("route get %s: %w", dst, err)
router, err := netroute.New() }
if err != nil { for _, r := range routes {
log.Errorf("Failed to create and update packet router for stunListener: %s", err) if r.Src != nil {
continue return r.Src, nil
}
s.routerMux.Lock()
s.router = router
s.routerMux.Unlock()
} }
} }
return nil, fmt.Errorf("no source IP for %s", dst)
} }
// LocalAddr returns the local address, preferring IPv4 for backward compatibility. // LocalAddr returns the local address, preferring IPv4 for backward compatibility.
@@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
DstPort: layers.UDPPort(rUDPAddr.Port), DstPort: layers.UDPPort(rUDPAddr.Port),
} }
s.routerMux.RLock() src, err := s.resolveSrc(rUDPAddr.IP)
defer s.routerMux.RUnlock()
_, _, src, err := s.router.Route(rUDPAddr.IP)
if err != nil { if err != nil {
return 0, fmt.Errorf("got an error while checking route, err: %w", err) return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err)
} }
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
if conn == nil {
return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP)
}
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil { if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err) return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)