Detect and capture openconnect traffic using eBPF/XDP

NobinPegasus

Member
Joined
Feb 4, 2022
Messages
50
Reaction score
0
Credits
447
I'm new to networking and am also learning how to use eBPF. I'm currently working on a project where I have to capture the inner packets of OpenConnect traffic. This is my code:


xdp_dump.c
Code:
// Copyright (c) 2019 Dropbox, Inc.
// Full license can be found in the LICENSE file.

// XDP dump is simple program that dumps new IPv4 TCP connections through perf events.

#include "bpf_helpers.h"

// Ethernet header (6 + 6 + 2 = 14 bytes).
// Packed so the struct can be overlaid directly on the start of the frame
// without compiler-inserted padding.
struct ethhdr {
  __u8 h_dest[6];    // destination MAC address
  __u8 h_source[6];  // source MAC address
  __u16 h_proto;     // EtherType, kept in the byte order found in the packet
} __attribute__((packed));

// IPv4 header, fixed part only (20 bytes); IP options are not represented.
// Multi-byte fields are kept in the byte order found in the packet.
struct iphdr {
  // ihl is declared before version so that it occupies the low nibble of the
  // first byte — NOTE(review): this relies on little-endian bit-field
  // allocation by the compiler; confirm for the target.
  __u8 ihl : 4;      // header length in 32-bit words (xdp_dump multiplies by 4)
  __u8 version : 4;  // IP version
  __u8 tos;          // type of service
  __u16 tot_len;     // total length
  __u16 id;          // identification
  __u16 frag_off;    // flags + fragment offset
  __u8 ttl;          // time to live
  __u8 protocol;     // L4 protocol number (6 = TCP)
  __u16 check;       // header checksum
  __u32 saddr;       // source address
  __u32 daddr;       // destination address
} __attribute__((packed));

// TCP header, fixed part only (20 bytes); TCP options are not represented.
// Multi-byte fields are kept in the byte order found in the packet.
//
// The packed attribute is attached to the struct declaration itself (as for
// ethhdr and iphdr). Previously it stood alone after the closing "};" and
// therefore did not apply to this struct.
struct tcphdr {
  __u16 source;   // source port
  __u16 dest;     // destination port
  __u32 seq;      // sequence number
  __u32 ack_seq;  // acknowledgement number
  union {
    struct {
      // Field order has been converted LittleEndian -> BigEndian
      // in order to simplify flag checking (no need to ntohs()).
      // Low byte: ns, reserved, data offset; high byte: the eight flag bits.
      __u16 ns : 1,
      reserved : 3,
      doff : 4,
      fin : 1,
      syn : 1,
      rst : 1,
      psh : 1,
      ack : 1,
      urg : 1,
      ece : 1,
      cwr : 1;
    };
  };
  __u16 window;   // receive window
  __u16 check;    // checksum
  __u16 urg_ptr;  // urgent pointer
} __attribute__((packed));

// PerfEvent eBPF map.
// A BPF_MAP_TYPE_PERF_EVENT_ARRAY with 128 entries; xdp_dump passes it to
// bpf_perf_event_output() to publish one sample per captured packet.
// BPF_MAP_DEF / BPF_MAP_ADD are not defined in this file — presumably they
// come from bpf_helpers.h.
BPF_MAP_DEF(perfmap) = {
    .map_type = BPF_MAP_TYPE_PERF_EVENT_ARRAY,
    .max_entries = 128,
};
BPF_MAP_ADD(perfmap);

// PerfEvent item: the fixed-size record emitted for each captured packet.
// It holds verbatim copies of the Ethernet, IPv4 and TCP headers
// (14 + 20 + 20 = 54 bytes). The userspace decoder must use the same layout.
struct perf_event_item {
    struct ethhdr eth_hdr;
    struct iphdr ip_hdr;
    // Individual TCP fields, superseded by the full tcp_hdr copy below:
    // __u16 source;
    // __u16 dest;
    // __u32 seq;
    // __u32 ack_seq; 
    struct tcphdr tcp_hdr;
} __attribute__((packed));

// Guard against accidental layout changes: userspace relies on this size.
_Static_assert(sizeof(struct perf_event_item) == 54, "wrong size of perf_event_item");


// XDP program.
//
// Parses Ethernet -> IPv4 -> TCP and, for every IPv4/TCP packet, emits a perf
// event made of a perf_event_item (copies of the three headers) followed by
// the raw packet bytes. Traffic is never filtered: every well-formed packet
// is passed up the stack.
//
// Returns:
//   XDP_PASS    - non-IPv4, non-TCP, and all successfully parsed packets
//   XDP_ABORTED - the packet is too short for the header being parsed, or the
//                 IPv4 header length field is invalid
SEC("xdp")
int xdp_dump(struct xdp_md *ctx) {
  void *data_end = (void *)(long)ctx->data_end;
  void *data = (void *)(long)ctx->data;
  __u64 packet_size = data_end - data;

  // L2
  struct ethhdr *ether = data;
  if (data + sizeof(*ether) > data_end) {
    return XDP_ABORTED;
  }

  // L3
  // h_proto is compared in packet byte order: htons(ETH_P_IP) -> 0x08
  // (NOTE(review): this constant assumes a little-endian host).
  if (ether->h_proto != 0x08) {
    // Non IPv4
    return XDP_PASS;
  }
  data += sizeof(*ether);
  struct iphdr *ip = data;
  if (data + sizeof(*ip) > data_end) {
    return XDP_ABORTED;
  }

  // L4
  // Filter on the protocol BEFORE touching the TCP header. Otherwise short
  // non-TCP packets (e.g. small UDP or ICMP datagrams) would fail the TCP
  // bounds check below and be dropped via XDP_ABORTED, and other protocols
  // would be reported with garbage in tcp_hdr.
  if (ip->protocol != 0x06) {  // IPPROTO_TCP -> 6
    return XDP_PASS;
  }

  // ihl counts 32-bit words; anything below 5 is shorter than the fixed
  // IPv4 header and cannot be valid.
  if (ip->ihl < 5) {
    return XDP_ABORTED;
  }
  data += ip->ihl * 4;
  struct tcphdr *tcp = data;
  if (data + sizeof(*tcp) > data_end) {
    return XDP_ABORTED;
  }

  // Emit a perf event for every TCP packet.
  struct perf_event_item evt = {
    .eth_hdr = *ether,
    .ip_hdr = *ip,
    .tcp_hdr = *tcp,
  };

  // flags for bpf_perf_event_output() actually contain 2 parts (each 32bit long):
  //
  // bits 0-31: either
  // - Just index in eBPF map
  // or
  // - "BPF_F_CURRENT_CPU" kernel will use current CPU_ID as eBPF map index
  //
  // bits 32-63: may be used to tell kernel to append first N bytes
  // of original packet (ctx) to the end of the data.

  // So total perf event length will be sizeof(evt) + packet_size
  __u64 flags = BPF_F_CURRENT_CPU | (packet_size << 32);
  bpf_perf_event_output(ctx, &perfmap, flags, &evt, sizeof(evt));

  return XDP_PASS;
}

// License string, placed in the "license" ELF section of the object file.
char _license[] SEC("license") = "GPL";
Here's my userspace Go program, which captures the incoming packets on the attached interface and parses the Ethernet, IP, and TCP header information. It also creates a gRPC client to send the captured information to a server.

Code:
package main

import (
    "bytes"
    "context"
    "encoding/binary"
    "encoding/hex"
    "errors"
    "flag"
    "fmt"
    "net"
    "os"
    "os/signal"

    "github.com/cilium/ebpf"
    "github.com/cilium/ebpf/perf"
    "github.com/vishvananda/netlink"
    "github.com/vishvananda/netlink/nl"

    pb "github.com/inspektors-io/grpc-nobin/grpc-test" // Update with your actual package name

    "google.golang.org/grpc"
)

//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang XdpDump ./bpf/xdp_dump.c -- -I../header

var (
    // iface is the name of the network interface to attach the XDP program
    // to; it is populated from the -iface command-line flag.
    iface string
    // conn is the gRPC client connection used to forward captured packets.
    conn  *grpc.ClientConn
)

const (
    // METADATA_SIZE is the number of bytes at the start of every perf sample
    // occupied by the header copies the XDP program emits (perfEventItem /
    // struct perf_event_item): 14 (Ethernet) + 20 (IPv4) + 20 (TCP) = 54.
    // The raw packet bytes appended by the kernel start at this offset.
    //
    // The previous value, 12, no longer matched the 54-byte event header,
    // so the "raw" payload started 42 bytes too early.
    // The name is kept as-is because other code in this file refers to it.
    METADATA_SIZE = 54
)

// Collect receives the loaded eBPF objects from LoadAndAssign. The struct
// tags name the objects in the compiled C source: the "xdp_dump" program and
// the "perfmap" perf event array.
type Collect struct {
    Prog    *ebpf.Program `ebpf:"xdp_dump"`
    PerfMap *ebpf.Map     `ebpf:"perfmap"`
}

// perfEventItem mirrors the C struct perf_event_item emitted by the XDP
// program: Ethernet (14 bytes), IPv4 (20 bytes) and TCP (20 bytes) header
// copies, 54 bytes in total. It contains only fixed-size fields so that it
// can be filled with binary.Read.
//
// NOTE(review): main decodes samples with binary.LittleEndian while the
// headers are copied verbatim from the packet, so multi-byte fields
// (ports, lengths, addresses, ...) presumably still need a byte swap
// (see ntohs) before being interpreted as numbers.
type perfEventItem struct {
    // EthHdr is the Ethernet header.
    EthHdr struct {
        DestMAC   [6]uint8
        SourceMAC [6]uint8
        Proto     uint16
    }
    // IpHdr is the fixed part of the IPv4 header.
    IpHdr struct {
        VersionIHL  byte // version and header length packed into one byte
        TOS         byte
        TotalLen    uint16
        ID          uint16
        FragmentOff uint16
        TTL         uint8
        Protocol    uint8
        Checksum    uint16
        SrcIP       uint32
        DstIP       uint32
    }
    // Tcphdr is the fixed part of the TCP header.
    Tcphdr struct {
        Source uint16
        Dest   uint16
        Seq    uint32
        AckSeq uint32
        Flags  uint16 // data offset, reserved bits, NS and the eight flag bits (2 bytes); see extractFlags
        Window uint16
        Check  uint16
        UrgPtr uint16
    }
}

// main runs the dumper and exits with the status code returned by run.
// All work lives in run because os.Exit does not execute deferred functions:
// exiting from inside the setup code would leave the XDP program attached to
// the interface.
func main() {
    os.Exit(run())
}

// run loads the XDP program, attaches it to the interface given by -iface,
// forwards every perf sample to the gRPC server at localhost:50051 and, on
// Ctrl-C, prints a summary and detaches the program.
//
// It returns the process exit code: 0 on a clean shutdown, 1 if setup failed.
func run() int {
    flag.StringVar(&iface, "iface", "", "interface attached xdp program")
    flag.Parse()

    if iface == "" {
        fmt.Println("interface is not specified.")
        return 1
    }

    link, err := netlink.LinkByName(iface)
    if err != nil {
        fmt.Printf("Failed to get interface by name: %v\n", err)
        return 1
    }

    spec, err := LoadXdpDump()
    if err != nil {
        fmt.Printf("Failed to load XDP dump: %v\n", err)
        return 1
    }

    collect := &Collect{}
    if err := spec.LoadAndAssign(collect, nil); err != nil {
        fmt.Printf("Failed to load and assign XDP program: %v\n", err)
        return 1
    }

    if err := netlink.LinkSetXdpFdWithFlags(link, collect.Prog.FD(), nl.XDP_FLAGS_SKB_MODE); err != nil {
        fmt.Printf("Failed to attach XDP program to interface: %v\n", err)
        return 1
    }

    // From here on every return path must detach the program again.
    defer func() {
        if err := netlink.LinkSetXdpFdWithFlags(link, -1, nl.XDP_FLAGS_SKB_MODE); err != nil {
            fmt.Printf("Error detaching program: %v\n", err)
        }
    }()

    ctrlC := make(chan os.Signal, 1)
    signal.Notify(ctrlC, os.Interrupt)

    perfEvent, err := perf.NewReader(collect.PerfMap, 4096)
    if err != nil {
        fmt.Printf("Failed to create perf event reader: %v\n", err)
        return 1
    }

    fmt.Println("All new TCP connection requests (SYN) coming to this host will be dumped here.")
    fmt.Println()

    // Connect to gRPC server
    conn, err = grpc.Dial("localhost:50051", grpc.WithInsecure())
    if err != nil {
        perfEvent.Close()
        fmt.Printf("Failed to connect to gRPC server: %v\n", err)
        return 1
    }

    // Create gRPC client
    client := pb.NewUserServiceClient(conn)

    // The counters are written only by the reader goroutine and read by this
    // function only after done is closed, so no further locking is needed.
    var received, lost, counter int
    done := make(chan struct{})

    go func() {
        defer close(done)

        var event perfEventItem
        for {
            evnt, err := perfEvent.Read()
            if err != nil {
                if errors.Is(err, perf.ErrClosed) {
                    return
                }
                fmt.Printf("Error reading perf event: %v\n", err)
                continue
            }

            // A record reporting lost samples carries no payload. Account
            // for it here; trying to decode it would fail and the loss would
            // never be counted.
            if evnt.LostSamples != 0 {
                lost += int(evnt.LostSamples)
                continue
            }

            reader := bytes.NewReader(evnt.RawSample)
            if err := binary.Read(reader, binary.LittleEndian, &event); err != nil {
                fmt.Printf("Error decoding perf event: %v\n", err)
                continue
            }
            received++

            fmt.Printf("TCP Header:\n")

            // Extracting flags
            flags := extractFlags(event.Tcphdr.Flags)
            fmt.Println("Extracted Flags:")
            fmt.Println("NS:", flags["ns"])
            fmt.Println("RES:", flags["res"])
            fmt.Println("DOFF:", flags["doff"])
            fmt.Println("FIN:", flags["fin"])
            fmt.Println("SYN:", flags["syn"])
            fmt.Println("RST:", flags["rst"])
            fmt.Println("PSH:", flags["psh"])
            fmt.Println("ACK:", flags["ack"])
            fmt.Println("URG:", flags["urg"])
            fmt.Println("ECE:", flags["ece"])
            fmt.Println("CWR:", flags["cwr"])
            fmt.Printf("  Window: %d\n", event.Tcphdr.Window)
            fmt.Printf("  Checksum: %d\n", event.Tcphdr.Check)
            fmt.Printf("  Urgent Pointer: %d\n", event.Tcphdr.UrgPtr)

            counter++
            fmt.Printf("Counter: %d\n", counter)

            // The successful decode above consumed the whole fixed-size
            // header, so the sample is long enough for this slice.
            rawData := evnt.RawSample[METADATA_SIZE:]
            if len(rawData) > 0 {
                fmt.Println(hex.Dump(rawData))
            }

            // Send data to gRPC server
            if err := sendDataToServer(client, int32(counter), event, rawData); err != nil {
                fmt.Printf("Failed to send data to gRPC server: %v\n", err)
                continue
            }
            fmt.Println("Data sent successfully to gRPC server")
        }
    }()

    <-ctrlC

    // Closing the reader interrupts the pending Read so the goroutine exits;
    // closing the connection aborts any RPC still in flight so that shutdown
    // cannot block on an unresponsive server. Wait for the goroutine before
    // reading the counters.
    if err := perfEvent.Close(); err != nil {
        fmt.Printf("Error closing perf event reader: %v\n", err)
    }
    if err := conn.Close(); err != nil {
        fmt.Printf("Error closing gRPC connection: %v\n", err)
    }
    <-done

    fmt.Println("\nSummary:")
    fmt.Printf("\t%d Event(s) Received\n", received)
    fmt.Printf("\t%d Event(s) Lost(e.g. small buffer, delays in processing)\n", lost)
    fmt.Println("\nDetaching program and exiting...")
    return 0
}

// sendDataToServer forwards one captured packet to the gRPC server.
//
// The Ethernet, IPv4 and TCP header fields of event are copied into their
// protobuf counterparts exactly as decoded (widened to uint32, no byte-order
// conversion), packetNumber is attached as the sequence number, and
// rawDumpString is sent unchanged as the raw byte payload.
//
// It returns the error reported by the SendUserData RPC, if any.
func sendDataToServer(client pb.UserServiceClient, packetNumber int32, event perfEventItem, rawDumpString []byte) error {
    eth, ip, tcp := event.EthHdr, event.IpHdr, event.Tcphdr

    req := &pb.UserRequest{
        IpHeader: &pb.IpHeader{
            SourceIp:      ip.SrcIP,
            DestinationIp: ip.DstIP,
            VersionIhl:    uint32(ip.VersionIHL),
            Protocol:      uint32(ip.Protocol),
            Check:         uint32(ip.Checksum),
            FragOff:       uint32(ip.FragmentOff),
            Id:            uint32(ip.ID),
            Tos:           uint32(ip.TOS),
            Ttl:           uint32(ip.TTL),
            TotLen:        uint32(ip.TotalLen),
        },
        TcpHeader: &pb.TcpHeader{
            SourcePort:      uint32(tcp.Source),
            DestinationPort: uint32(tcp.Dest),
            Seq:             uint32(tcp.Seq),
            AckSeq:          uint32(tcp.AckSeq),
            Flag:            uint32(tcp.Flags),
            Window:          uint32(tcp.Window),
            Check:           uint32(tcp.Check),
            UrgPtr:          uint32(tcp.UrgPtr),
        },
        EthernetHeader: &pb.EthernetHeader{
            EtherType:      uint32(eth.Proto),
            DestinationMac: eth.DestMAC[:],
            SourceMac:      eth.SourceMAC[:],
        },
        PacketNumber: packetNumber,
        RawData:      rawDumpString,
    }

    _, err := client.SendUserData(context.Background(), req)
    return err
}

// intToIPv4 converts ip into a 4-byte net.IP, writing the least significant
// byte of ip as the first octet.
func intToIPv4(ip uint32) net.IP {
    var octets [4]byte
    binary.LittleEndian.PutUint32(octets[:], ip)
    return net.IP(octets[:])
}

// ntohs returns value with its two bytes swapped.
func ntohs(value uint16) uint16 {
    hi := value << 8 // low byte moved up; the old high byte is shifted out
    lo := value >> 8 // high byte moved down
    return hi | lo
}

// extractFlags splits the 16-bit data-offset/flags word of a TCP header into
// its named fields. Bit positions are counted from the least significant bit
// of flags: ns is bit 0, res bits 1-3, doff bits 4-7, and fin, syn, rst, psh,
// ack, urg, ece, cwr are bits 8 through 15 respectively.
//
// The returned map always contains all eleven keys.
func extractFlags(flags uint16) map[string]uint16 {
    layout := [...]struct {
        key   string
        shift uint
        mask  uint16
    }{
        {"ns", 0, 0x1},
        {"res", 1, 0x7},
        {"doff", 4, 0xF},
        {"fin", 8, 0x1},
        {"syn", 9, 0x1},
        {"rst", 10, 0x1},
        {"psh", 11, 0x1},
        {"ack", 12, 0x1},
        {"urg", 13, 0x1},
        {"ece", 14, 0x1},
        {"cwr", 15, 0x1},
    }

    result := make(map[string]uint16)
    for _, field := range layout {
        result[field.key] = (flags >> field.shift) & field.mask
    }
    return result
}
I want to extend this program so that it also detects OpenConnect packets, decapsulates them, and captures the destination address of the inner IP packet.
 

Members online


Top