diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 523beb32..4855e1ae 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -10,15 +10,13 @@ import ( "context" "fmt" "net" - "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/google/gopacket/routing" - "github.com/libp2p/go-netroute" "github.com/mdlayher/socket" log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" @@ -37,8 +35,6 @@ type SharedSocket struct { conn6 *socket.Conn port int mtu uint16 - routerMux sync.RWMutex - router routing.Router packetDemux chan rcvdPacket cancel context.CancelFunc } @@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } }() - rawSock.router, err = netroute.New() - if err != nil { - return nil, fmt.Errorf("failed to create raw socket router: %w", err) - } - rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) @@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error go rawSock.read(rawSock.conn6.Recvfrom) } - go rawSock.updateRouter() - return rawSock, nil } -// updateRouter updates the listener routing table client -// this is needed to avoid outdated information across different client networks -func (s *SharedSocket) updateRouter() { - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - router, err := netroute.New() - if err != nil { - log.Errorf("Failed to create and update packet router for stunListener: %s", err) - continue - } - s.routerMux.Lock() - s.router = router - s.routerMux.Unlock() +// resolveSrc returns the source IP the kernel will pick for a packet sent to +// dst by these raw sockets, mirroring the fwmark the kernel will see on send. +func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) { + opts := &netlink.RouteGetOptions{} + if nbnet.AdvancedRouting() { + opts.Mark = nbnet.ControlPlaneMark + } + routes, err := netlink.RouteGetWithOptions(dst, opts) + if err != nil { + return nil, fmt.Errorf("route get %s: %w", dst, err) + } + for _, r := range routes { + if r.Src != nil { + return r.Src, nil } } + return nil, fmt.Errorf("no source IP for %s", dst) } // LocalAddr returns the local address, preferring IPv4 for backward compatibility. @@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { DstPort: layers.UDPPort(rUDPAddr.Port), } - s.routerMux.RLock() - defer s.routerMux.RUnlock() - - _, _, src, err := s.router.Route(rUDPAddr.IP) + src, err := s.resolveSrc(rUDPAddr.IP) if err != nil { - return 0, fmt.Errorf("got an error while checking route, err: %w", err) + return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err) } rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) + if conn == nil { + return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP) + } if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil { return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)