Merge pull request #3 from Waujito/1-tcp-segmentation

Use TCP Segmentation
This commit is contained in:
Vadim Vetrov 2024-07-23 15:23:16 -07:00 committed by GitHub
commit 5279aab69c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -20,9 +20,11 @@
#include <sys/socket.h> #include <sys/socket.h>
#ifndef NOUSE_GSO #ifndef NOUSE_GSO
#define USE_GSO #define USE_GSO
#endif
#ifndef USE_IP_FRAGMENTATION
#define USE_TCP_SEGMENTATION
#endif #endif
#define RAWSOCKET_MARK 0xfc70 #define RAWSOCKET_MARK 0xfc70
@ -207,8 +209,6 @@ static int ipv4_frag(struct pkt_buff *pktb, size_t payload_offset,
f2_hdr->frag_off = htons(f2_frag_off); f2_hdr->frag_off = htons(f2_frag_off);
f2_hdr->tot_len = htons(f2_dlen); f2_hdr->tot_len = htons(f2_dlen);
*frag1 = pktb_alloc(AF_INET, buff1, f1_dlen, 256);
*frag2 = pktb_alloc(AF_INET, buff2, f2_dlen, 256);
#ifdef DEBUG #ifdef DEBUG
printf("Packet split in portion %zu %zu\n", f1_dlen, f2_dlen); printf("Packet split in portion %zu %zu\n", f1_dlen, f2_dlen);
@ -217,6 +217,97 @@ static int ipv4_frag(struct pkt_buff *pktb, size_t payload_offset,
nfq_ip_set_checksum(f1_hdr); nfq_ip_set_checksum(f1_hdr);
nfq_ip_set_checksum(f2_hdr); nfq_ip_set_checksum(f2_hdr);
*frag1 = pktb_alloc(AF_INET, buff1, f1_dlen, 0);
if (*frag1 == NULL)
return -1;
*frag2 = pktb_alloc(AF_INET, buff2, f2_dlen, 0);
if (*frag2 == NULL) {
pktb_free(*frag1);
return -1;
}
return 0;
}
// split packet to two tcp-on-ipv4 segments.
static int tcp4_frag(struct pkt_buff *pktb, size_t payload_offset,
struct pkt_buff **seg1, struct pkt_buff **seg2) {
uint8_t buff1[MNL_SOCKET_BUFFER_SIZE];
uint8_t buff2[MNL_SOCKET_BUFFER_SIZE];
struct iphdr *hdr = nfq_ip_get_hdr(pktb);
size_t hdr_len = hdr->ihl * 4;
if (hdr == NULL) {errno = EINVAL; return -1;}
if (hdr->protocol != IPPROTO_TCP || !(ntohs(hdr->frag_off) & IP_DF)) {
errno = EINVAL;
return -1;
}
if (nfq_ip_set_transport_header(pktb, hdr))
return -1;
struct tcphdr *tcph = nfq_tcp_get_hdr(pktb);
size_t tcph_len = tcph->doff * 4;
if (tcph == NULL) {
errno = EINVAL;
return -1;
}
uint8_t *payload = nfq_tcp_get_payload(tcph, pktb);
size_t plen = nfq_tcp_get_payload_len(tcph, pktb);
if (hdr == NULL || payload == NULL || plen <= payload_offset) {
errno = EINVAL;
return -1;
}
size_t s1_plen = payload_offset;
size_t s1_dlen = s1_plen + hdr_len + tcph_len;
size_t s2_plen = plen - payload_offset;
size_t s2_dlen = s2_plen + hdr_len + tcph_len;
memcpy(buff1, hdr, hdr_len);
memcpy(buff2, hdr, hdr_len);
memcpy(buff1 + hdr_len, tcph, tcph_len);
memcpy(buff2 + hdr_len, tcph, tcph_len);
memcpy(buff1 + hdr_len + tcph_len, payload, s1_plen);
memcpy(buff2 + hdr_len + tcph_len, payload + payload_offset, s2_plen);
struct iphdr *s1_hdr = (void *)buff1;
struct iphdr *s2_hdr = (void *)buff2;
struct tcphdr *s1_tcph = (void *)(buff1 + hdr_len);
struct tcphdr *s2_tcph = (void *)(buff2 + hdr_len);
s1_hdr->tot_len = htons(s1_dlen);
s2_hdr->tot_len = htons(s2_dlen);
// s2_hdr->id = htons(ntohs(s1_hdr->id) + 1);
s2_tcph->seq = htonl(ntohl(s2_tcph->seq) + payload_offset);
// printf("%zu %du %du\n", payload_offset, ntohs(s1_tcph->seq), ntohs(s2_tcph->seq));
#ifdef DEBUG
printf("Packet split in portion %zu %zu\n", s1_dlen, s2_dlen);
#endif
nfq_tcp_compute_checksum_ipv4(s1_tcph, s1_hdr);
nfq_tcp_compute_checksum_ipv4(s2_tcph, s2_hdr);
*seg1 = pktb_alloc(AF_INET, buff1, s1_dlen, 0);
if (*seg1 == NULL)
return -1;
*seg2 = pktb_alloc(AF_INET, buff2, s2_dlen, 0);
if (*seg2 == NULL) {
pktb_free(*seg1);
return -1;
}
return 0; return 0;
} }
@ -231,7 +322,14 @@ static int send_raw_socket(struct pkt_buff *pktb) {
struct pkt_buff *buff1; struct pkt_buff *buff1;
struct pkt_buff *buff2; struct pkt_buff *buff2;
ipv4_frag(pktb, AVAILABLE_MTU-24, &buff1, &buff2); #ifdef USE_TCP_SEGMENTATION
if (tcp4_frag(pktb, AVAILABLE_MTU-128, &buff1, &buff2) < 0)
return -1;
#else
if (ipv4_frag(pktb, AVAILABLE_MTU-128, &buff1, &buff2) < 0)
return -1;
#endif
int sent = 0; int sent = 0;
int status = send_raw_socket(buff1); int status = send_raw_socket(buff1);
@ -242,22 +340,19 @@ static int send_raw_socket(struct pkt_buff *pktb) {
pktb_free(buff2); pktb_free(buff2);
return status; return status;
} }
pktb_free(buff1);
status = send_raw_socket(buff2); status = send_raw_socket(buff2);
if (status >= 0) sent += status; if (status >= 0) sent += status;
else { else {
pktb_free(buff1);
pktb_free(buff2); pktb_free(buff2);
return status; return status;
} }
pktb_free(buff1);
pktb_free(buff2); pktb_free(buff2);
return sent; return sent;
} }
struct iphdr *iph = nfq_ip_get_hdr(pktb); struct iphdr *iph = nfq_ip_get_hdr(pktb);
if (iph == NULL) if (iph == NULL)
return -1; return -1;
@ -272,6 +367,7 @@ static int send_raw_socket(struct pkt_buff *pktb) {
if (tcph != NULL) { if (tcph != NULL) {
sin_port = tcph->dest; sin_port = tcph->dest;
errno = 0;
} else if (udph != NULL) { } else if (udph != NULL) {
sin_port = udph->dest; sin_port = udph->dest;
} else { } else {
@ -340,25 +436,23 @@ struct verdict {
* data Payload data of TCP. * data Payload data of TCP.
* dlen Length of `data`. * dlen Length of `data`.
*/ */
static struct verdict process_tls_data( static struct verdict analyze_tls_data(
struct pkt_buff *pktb, const uint8_t *data,
struct tcphdr *tcph,
uint8_t *data,
uint32_t dlen) uint32_t dlen)
{ {
struct verdict vrd = {0}; struct verdict vrd = {0};
size_t i = 0; size_t i = 0;
uint8_t *data_end = data + dlen; const uint8_t *data_end = data + dlen;
while (i + 4 < dlen) { while (i + 4 < dlen) {
uint8_t *msgData = data + i; const uint8_t *msgData = data + i;
uint8_t tls_content_type = *msgData; uint8_t tls_content_type = *msgData;
uint8_t tls_vmajor = *(msgData + 1); uint8_t tls_vmajor = *(msgData + 1);
uint8_t tls_vminor = *(msgData + 2); uint8_t tls_vminor = *(msgData + 2);
uint16_t message_length = ntohs(*(uint16_t *)(msgData + 3)); uint16_t message_length = ntohs(*(uint16_t *)(msgData + 3));
uint8_t *message_length_ptr = msgData + 3; const uint8_t *message_length_ptr = msgData + 3;
if (i + 5 + message_length > dlen) break; if (i + 5 + message_length > dlen) break;
@ -367,7 +461,7 @@ static struct verdict process_tls_data(
goto nextMessage; goto nextMessage;
uint8_t *handshakeProto = msgData + 5; const uint8_t *handshakeProto = msgData + 5;
if (handshakeProto + 1 >= data_end) break; if (handshakeProto + 1 >= data_end) break;
@ -376,9 +470,9 @@ static struct verdict process_tls_data(
if (handshakeType != TLS_HANDSHAKE_TYPE_CLIENT_HELLO) if (handshakeType != TLS_HANDSHAKE_TYPE_CLIENT_HELLO)
goto nextMessage; goto nextMessage;
uint8_t *msgPtr = handshakeProto; const uint8_t *msgPtr = handshakeProto;
msgPtr += 1; msgPtr += 1;
uint8_t *handshakeProto_length_ptr = msgPtr + 1; const uint8_t *handshakeProto_length_ptr = msgPtr + 1;
msgPtr += 3 + 2 + 32; msgPtr += 3 + 2 + 32;
if (msgPtr + 1 >= data_end) break; if (msgPtr + 1 >= data_end) break;
@ -398,15 +492,15 @@ static struct verdict process_tls_data(
if (msgPtr + 2 >= data_end) break; if (msgPtr + 2 >= data_end) break;
uint16_t extensionsLen = ntohs(*(uint16_t *)msgPtr); uint16_t extensionsLen = ntohs(*(uint16_t *)msgPtr);
uint8_t *extensionsLen_ptr = msgPtr; const uint8_t *extensionsLen_ptr = msgPtr;
msgPtr += 2; msgPtr += 2;
uint8_t *extensionsPtr = msgPtr; const uint8_t *extensionsPtr = msgPtr;
uint8_t *extensions_end = extensionsPtr + extensionsLen; const uint8_t *extensions_end = extensionsPtr + extensionsLen;
if (extensions_end > data_end) break; if (extensions_end > data_end) break;
while (extensionsPtr < extensions_end) { while (extensionsPtr < extensions_end) {
uint8_t *extensionPtr = extensionsPtr; const uint8_t *extensionPtr = extensionsPtr;
if (extensionPtr + 4 >= extensions_end) break; if (extensionPtr + 4 >= extensions_end) break;
uint16_t extensionType = uint16_t extensionType =
@ -415,7 +509,7 @@ static struct verdict process_tls_data(
uint16_t extensionLen = uint16_t extensionLen =
ntohs(*(uint16_t *)extensionPtr); ntohs(*(uint16_t *)extensionPtr);
uint8_t *extensionLen_ptr = extensionPtr; const uint8_t *extensionLen_ptr = extensionPtr;
extensionPtr += 2; extensionPtr += 2;
@ -425,15 +519,15 @@ static struct verdict process_tls_data(
if (extensionType != TLS_EXTENSION_SNI) if (extensionType != TLS_EXTENSION_SNI)
goto nextExtension; goto nextExtension;
uint8_t *sni_ext_ptr = extensionPtr; const uint8_t *sni_ext_ptr = extensionPtr;
if (sni_ext_ptr + 2 >= extensions_end) break; if (sni_ext_ptr + 2 >= extensions_end) break;
uint16_t sni_ext_dlen = ntohs(*(uint16_t *)sni_ext_ptr); uint16_t sni_ext_dlen = ntohs(*(uint16_t *)sni_ext_ptr);
uint8_t *sni_ext_dlen_ptr = sni_ext_ptr; const uint8_t *sni_ext_dlen_ptr = sni_ext_ptr;
sni_ext_ptr += 2; sni_ext_ptr += 2;
uint8_t *sni_ext_end = sni_ext_ptr + sni_ext_dlen; const uint8_t *sni_ext_end = sni_ext_ptr + sni_ext_dlen;
if (sni_ext_end >= extensions_end) break; if (sni_ext_end >= extensions_end) break;
if (sni_ext_ptr + 3 >= sni_ext_end) break; if (sni_ext_ptr + 3 >= sni_ext_end) break;
@ -485,46 +579,34 @@ static int process_packet(const struct packet_data packet) {
} }
const int family = AF_INET; const int family = AF_INET;
const uint8_t *raw_payload = packet.payload;
size_t raw_payload_len = packet.payload_len;
struct pkt_buff *pktb = pktb_alloc( if (raw_payload == NULL) return MNL_CB_ERROR;
family,
packet.payload,
packet.payload_len,
256
);
if (pktb == NULL) { const struct iphdr *ip_header = (const void *)raw_payload;
perror("pktb_alloc");
return MNL_CB_ERROR;
}
struct iphdr *ip_header = nfq_ip_get_hdr(pktb); if (ip_header->version != IPPROTO_IPIP || ip_header->protocol != IPPROTO_TCP)
goto fallback;
if (ip_header == NULL) { int iph_len = ip_header->ihl * 4;
perror("get_ip_hdr");
goto error;
}
if (ip_header->protocol != IPPROTO_TCP) { const struct tcphdr *tcph = (const void *)(raw_payload + iph_len);
if ((const uint8_t *)tcph + 20 > raw_payload + raw_payload_len) {
printf("LZ\n");
goto fallback; goto fallback;
} }
if (nfq_ip_set_transport_header(pktb, ip_header)) { int tcph_len = tcph->doff * 4;
perror("set_transport_header\n"); if ((const uint8_t *)tcph + tcph_len > raw_payload + raw_payload_len) {
printf("LZ\n");
goto fallback; goto fallback;
} }
struct tcphdr *tcph = nfq_tcp_get_hdr(pktb); int data_len = ntohs(ip_header->tot_len) - iph_len - tcph_len;
if (tcph == NULL) { const uint8_t *data = (const uint8_t *)(raw_payload + iph_len + tcph_len);
fprintf(stderr, "tcp_get_hdr\n");
goto fallback;
}
void *data = nfq_tcp_get_payload(tcph, pktb); struct verdict vrd = analyze_tls_data(data, data_len);
ssize_t data_len = nfq_tcp_get_payload_len(tcph, pktb);
struct verdict vrd =
process_tls_data(pktb, tcph, data, data_len);
verdnlh = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, config.queue_num); verdnlh = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, config.queue_num);
nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_ACCEPT); nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_ACCEPT);
@ -539,20 +621,47 @@ static int process_packet(const struct packet_data packet) {
fprintf(stderr, "WARNING! Google video packet is too big and may cause issues!\n"); fprintf(stderr, "WARNING! Google video packet is too big and may cause issues!\n");
#endif #endif
} }
// GSO may turn kernel to not compute tcp checksum.
// Also it will never be meaningless to ensure the
// checksum is right.
nfq_tcp_compute_checksum_ipv4(tcph, ip_header);
nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_DROP);
size_t ipd_offset = ((char *)data - (char *)tcph) + vrd.sni_offset;
size_t mid_offset = ipd_offset + vrd.sni_len / 2;
mid_offset += 8 - mid_offset % 8;
struct pkt_buff *frag1; struct pkt_buff *frag1;
struct pkt_buff *frag2; struct pkt_buff *frag2;
nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_DROP);
#ifdef USE_TCP_SEGMENTATION
size_t ipd_offset = vrd.sni_offset;
size_t mid_offset = ipd_offset + vrd.sni_len / 2;
struct pkt_buff *pktb = pktb_alloc(
family,
packet.payload,
packet.payload_len,
0);
if (tcp4_frag(pktb, mid_offset, &frag1, &frag2) < 0) {
perror("tcp4_frag");
pktb_free(pktb);
goto fallback;
}
if ((send_raw_socket(frag2) == -1) || (send_raw_socket(frag1) == -1)) {
perror("raw frags send");
}
pktb_free(frag1);
pktb_free(frag2);
pktb_free(pktb);
#else
// TODO: Implement compute of tcp checksum
// GSO may turn kernel to not compute the tcp checksum.
// Also it will never be meaningless to ensure the
// checksum is right.
// nfq_tcp_compute_checksum_ipv4(tcph, ip_header);
size_t ipd_offset = ((char *)data - (char *)tcph) + vrd.sni_offset;
size_t mid_offset = ipd_offset + vrd.sni_len / 2;
mid_offset += 8 - mid_offset % 8;
if (ipv4_frag(pktb, mid_offset, &frag1, &frag2) < 0) { if (ipv4_frag(pktb, mid_offset, &frag1, &frag2) < 0) {
perror("ipv4_frag"); perror("ipv4_frag");
@ -565,10 +674,12 @@ static int process_packet(const struct packet_data packet) {
pktb_free(frag1); pktb_free(frag1);
pktb_free(frag2); pktb_free(frag2);
#endif
} }
/*
if (pktb_mangled(pktb)) { if (pktb_mangled(pktb)) {
#ifdef DEBUG #ifdef DEBUG
printf("Mangled!\n"); printf("Mangled!\n");
@ -576,6 +687,7 @@ static int process_packet(const struct packet_data packet) {
nfq_nlmsg_verdict_put_pkt( nfq_nlmsg_verdict_put_pkt(
verdnlh, pktb_data(pktb), pktb_len(pktb)); verdnlh, pktb_data(pktb), pktb_len(pktb));
} }
*/
if (mnl_socket_sendto(config.nl, verdnlh, verdnlh->nlmsg_len) < 0) { if (mnl_socket_sendto(config.nl, verdnlh, verdnlh->nlmsg_len) < 0) {
perror("mnl_socket_send"); perror("mnl_socket_send");
@ -583,14 +695,11 @@ static int process_packet(const struct packet_data packet) {
goto error; goto error;
} }
pktb_free(pktb);
return MNL_CB_OK; return MNL_CB_OK;
fallback: fallback:
pktb_free(pktb);
return fallback_accept_packet(packet.id); return fallback_accept_packet(packet.id);
error: error:
pktb_free(pktb);
return MNL_CB_ERROR; return MNL_CB_ERROR;
} }
@ -610,8 +719,6 @@ static int queue_cb(const struct nlmsghdr *nlh, void *data) {
return MNL_CB_ERROR; return MNL_CB_ERROR;
} }
ph = mnl_attr_get_payload(attr[NFQA_PACKET_HDR]); ph = mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
packet.id = ntohl(ph->packet_id); packet.id = ntohl(ph->packet_id);
@ -625,9 +732,6 @@ static int queue_cb(const struct nlmsghdr *nlh, void *data) {
return fallback_accept_packet(packet.id); return fallback_accept_packet(packet.id);
} }
uint32_t skbinfo = attr[NFQA_SKB_INFO]
? ntohl(mnl_attr_get_u32(attr[NFQA_SKB_INFO])) : 0;
if (attr[NFQA_MARK] != NULL) { if (attr[NFQA_MARK] != NULL) {
// Skip packets sent by rawsocket to escape infinity loop. // Skip packets sent by rawsocket to escape infinity loop.
if (ntohl(mnl_attr_get_u32(attr[NFQA_MARK])) == if (ntohl(mnl_attr_get_u32(attr[NFQA_MARK])) ==
@ -648,6 +752,12 @@ int main(int argc, const char *argv[])
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
#ifdef USE_TCP_SEGMENTATION
printf("Using TCP segmentation!\n");
#else
printf("Using IP fragmentation!\n");
#endif
if (open_socket()) { if (open_socket()) {
perror("Unable to open socket"); perror("Unable to open socket");
exit(EXIT_FAILURE); exit(EXIT_FAILURE);