[client] Use fwmark-aware route lookup for raw socket UDP checksum source (#6070)
* Use fwmark-aware route lookup for raw socket UDP checksum source * Guard nil raw socket in sharedsock WriteTo
This commit is contained in:
@@ -10,15 +10,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/gopacket/routing"
|
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
"github.com/mdlayher/socket"
|
"github.com/mdlayher/socket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
@@ -37,8 +35,6 @@ type SharedSocket struct {
|
|||||||
conn6 *socket.Conn
|
conn6 *socket.Conn
|
||||||
port int
|
port int
|
||||||
mtu uint16
|
mtu uint16
|
||||||
routerMux sync.RWMutex
|
|
||||||
router routing.Router
|
|
||||||
packetDemux chan rcvdPacket
|
packetDemux chan rcvdPacket
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
@@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rawSock.router, err = netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
|
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
|
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
|
||||||
@@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
|
|||||||
go rawSock.read(rawSock.conn6.Recvfrom)
|
go rawSock.read(rawSock.conn6.Recvfrom)
|
||||||
}
|
}
|
||||||
|
|
||||||
go rawSock.updateRouter()
|
|
||||||
|
|
||||||
return rawSock, nil
|
return rawSock, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateRouter updates the listener routing table client
|
// resolveSrc returns the source IP the kernel will pick for a packet sent to
|
||||||
// this is needed to avoid outdated information across different client networks
|
// dst by these raw sockets, mirroring the fwmark the kernel will see on send.
|
||||||
func (s *SharedSocket) updateRouter() {
|
func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) {
|
||||||
ticker := time.NewTicker(15 * time.Second)
|
opts := &netlink.RouteGetOptions{}
|
||||||
defer ticker.Stop()
|
if nbnet.AdvancedRouting() {
|
||||||
for {
|
opts.Mark = nbnet.ControlPlaneMark
|
||||||
select {
|
}
|
||||||
case <-s.ctx.Done():
|
routes, err := netlink.RouteGetWithOptions(dst, opts)
|
||||||
return
|
if err != nil {
|
||||||
case <-ticker.C:
|
return nil, fmt.Errorf("route get %s: %w", dst, err)
|
||||||
router, err := netroute.New()
|
}
|
||||||
if err != nil {
|
for _, r := range routes {
|
||||||
log.Errorf("Failed to create and update packet router for stunListener: %s", err)
|
if r.Src != nil {
|
||||||
continue
|
return r.Src, nil
|
||||||
}
|
|
||||||
s.routerMux.Lock()
|
|
||||||
s.router = router
|
|
||||||
s.routerMux.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil, fmt.Errorf("no source IP for %s", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the local address, preferring IPv4 for backward compatibility.
|
// LocalAddr returns the local address, preferring IPv4 for backward compatibility.
|
||||||
@@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
|||||||
DstPort: layers.UDPPort(rUDPAddr.Port),
|
DstPort: layers.UDPPort(rUDPAddr.Port),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.routerMux.RLock()
|
src, err := s.resolveSrc(rUDPAddr.IP)
|
||||||
defer s.routerMux.RUnlock()
|
|
||||||
|
|
||||||
_, _, src, err := s.router.Route(rUDPAddr.IP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("got an error while checking route, err: %w", err)
|
return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
|
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
|
||||||
|
if conn == nil {
|
||||||
|
return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP)
|
||||||
|
}
|
||||||
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
|
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
|
||||||
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)
|
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user