diff --git a/youtubeUnblock.c b/youtubeUnblock.c index a194bf8..2fd334b 100644 --- a/youtubeUnblock.c +++ b/youtubeUnblock.c @@ -217,8 +217,8 @@ static int ipv4_frag(struct pkt_buff *pktb, size_t payload_offset, nfq_ip_set_checksum(f1_hdr); nfq_ip_set_checksum(f2_hdr); - *frag1 = pktb_alloc(AF_INET, buff1, f1_dlen, 256); - *frag2 = pktb_alloc(AF_INET, buff2, f2_dlen, 256); + *frag1 = pktb_alloc(AF_INET, buff1, f1_dlen, 0); + *frag2 = pktb_alloc(AF_INET, buff2, f2_dlen, 0); return 0; } @@ -290,8 +290,8 @@ static int tcp4_frag(struct pkt_buff *pktb, size_t payload_offset, nfq_tcp_compute_checksum_ipv4(s1_tcph, s1_hdr); nfq_tcp_compute_checksum_ipv4(s2_tcph, s2_hdr); - *seg1 = pktb_alloc(AF_INET, buff1, s1_dlen, 256); - *seg2 = pktb_alloc(AF_INET, buff2, s2_dlen, 256); + *seg1 = pktb_alloc(AF_INET, buff1, s1_dlen, 0); + *seg2 = pktb_alloc(AF_INET, buff2, s2_dlen, 0); return 0; } @@ -418,25 +418,23 @@ struct verdict { * data Payload data of TCP. * dlen Length of `data`. */ -static struct verdict process_tls_data( - struct pkt_buff *pktb, - struct tcphdr *tcph, - uint8_t *data, +static struct verdict analyze_tls_data( + const uint8_t *data, uint32_t dlen) { struct verdict vrd = {0}; size_t i = 0; - uint8_t *data_end = data + dlen; + const uint8_t *data_end = data + dlen; while (i + 4 < dlen) { - uint8_t *msgData = data + i; + const uint8_t *msgData = data + i; uint8_t tls_content_type = *msgData; uint8_t tls_vmajor = *(msgData + 1); uint8_t tls_vminor = *(msgData + 2); uint16_t message_length = ntohs(*(uint16_t *)(msgData + 3)); - uint8_t *message_length_ptr = msgData + 3; + const uint8_t *message_length_ptr = msgData + 3; if (i + 5 + message_length > dlen) break; @@ -445,7 +443,7 @@ static struct verdict process_tls_data( goto nextMessage; - uint8_t *handshakeProto = msgData + 5; + const uint8_t *handshakeProto = msgData + 5; if (handshakeProto + 1 >= data_end) break; @@ -454,9 +452,9 @@ static struct verdict process_tls_data( if (handshakeType != TLS_HANDSHAKE_TYPE_CLIENT_HELLO) goto nextMessage; - uint8_t *msgPtr = handshakeProto; + const uint8_t *msgPtr = handshakeProto; msgPtr += 1; - uint8_t *handshakeProto_length_ptr = msgPtr + 1; + const uint8_t *handshakeProto_length_ptr = msgPtr + 1; msgPtr += 3 + 2 + 32; if (msgPtr + 1 >= data_end) break; @@ -476,15 +474,15 @@ static struct verdict process_tls_data( if (msgPtr + 2 >= data_end) break; uint16_t extensionsLen = ntohs(*(uint16_t *)msgPtr); - uint8_t *extensionsLen_ptr = msgPtr; + const uint8_t *extensionsLen_ptr = msgPtr; msgPtr += 2; - uint8_t *extensionsPtr = msgPtr; - uint8_t *extensions_end = extensionsPtr + extensionsLen; + const uint8_t *extensionsPtr = msgPtr; + const uint8_t *extensions_end = extensionsPtr + extensionsLen; if (extensions_end > data_end) break; while (extensionsPtr < extensions_end) { - uint8_t *extensionPtr = extensionsPtr; + const uint8_t *extensionPtr = extensionsPtr; if (extensionPtr + 4 >= extensions_end) break; uint16_t extensionType = @@ -493,7 +491,7 @@ static struct verdict process_tls_data( uint16_t extensionLen = ntohs(*(uint16_t *)extensionPtr); - uint8_t *extensionLen_ptr = extensionPtr; + const uint8_t *extensionLen_ptr = extensionPtr; extensionPtr += 2; @@ -503,15 +501,15 @@ static struct verdict process_tls_data( if (extensionType != TLS_EXTENSION_SNI) goto nextExtension; - uint8_t *sni_ext_ptr = extensionPtr; + const uint8_t *sni_ext_ptr = extensionPtr; if (sni_ext_ptr + 2 >= extensions_end) break; uint16_t sni_ext_dlen = ntohs(*(uint16_t *)sni_ext_ptr); - uint8_t *sni_ext_dlen_ptr = sni_ext_ptr; + const uint8_t *sni_ext_dlen_ptr = sni_ext_ptr; sni_ext_ptr += 2; - uint8_t *sni_ext_end = sni_ext_ptr + sni_ext_dlen; + const uint8_t *sni_ext_end = sni_ext_ptr + sni_ext_dlen; if (sni_ext_end >= extensions_end) break; if (sni_ext_ptr + 3 >= sni_ext_end) break; @@ -563,46 +561,34 @@ static int process_packet(const struct packet_data packet) { } const int family = AF_INET; + const uint8_t *raw_payload = packet.payload; + size_t raw_payload_len = packet.payload_len; - struct pkt_buff *pktb = pktb_alloc( - family, - packet.payload, - packet.payload_len, - 256 - ); + if (raw_payload == NULL) return MNL_CB_ERROR; - if (pktb == NULL) { - perror("pktb_alloc"); - return MNL_CB_ERROR; - } + const struct iphdr *ip_header = (const void *)raw_payload; - struct iphdr *ip_header = nfq_ip_get_hdr(pktb); + if (ip_header->version != IPPROTO_IPIP || ip_header->protocol != IPPROTO_TCP) + goto fallback; - if (ip_header == NULL) { - perror("get_ip_hdr"); - goto error; - } + int iph_len = ip_header->ihl * 4; - if (ip_header->protocol != IPPROTO_TCP) { + const struct tcphdr *tcph = (const void *)(raw_payload + iph_len); + if ((const uint8_t *)tcph + 20 > raw_payload + raw_payload_len) { + printf("LZ\n"); goto fallback; } - if (nfq_ip_set_transport_header(pktb, ip_header)) { - perror("set_transport_header\n"); + int tcph_len = tcph->doff * 4; + if ((const uint8_t *)tcph + tcph_len > raw_payload + raw_payload_len) { + printf("LZ\n"); goto fallback; } - struct tcphdr *tcph = nfq_tcp_get_hdr(pktb); - if (tcph == NULL) { - fprintf(stderr, "tcp_get_hdr\n"); - goto fallback; - } + int data_len = ntohs(ip_header->tot_len) - iph_len - tcph_len; + const uint8_t *data = (const uint8_t *)(raw_payload + iph_len + tcph_len); - void *data = nfq_tcp_get_payload(tcph, pktb); - ssize_t data_len = nfq_tcp_get_payload_len(tcph, pktb); - - struct verdict vrd = - process_tls_data(pktb, tcph, data, data_len); + struct verdict vrd = analyze_tls_data(data, data_len); verdnlh = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, config.queue_num); nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_ACCEPT); @@ -617,22 +603,25 @@ static int process_packet(const struct packet_data packet) { fprintf(stderr, "WARNING! Google video packet is too big and may cause issues!\n"); #endif } - // GSO may turn kernel to not compute tcp checksum. - // Also it will never be meaningless to ensure the - // checksum is right. - nfq_tcp_compute_checksum_ipv4(tcph, ip_header); - nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_DROP); - + struct pkt_buff *frag1; struct pkt_buff *frag2; + nfq_nlmsg_verdict_put(verdnlh, packet.id, NF_DROP); #ifdef USE_TCP_SEGMENTATION size_t ipd_offset = vrd.sni_offset; size_t mid_offset = ipd_offset + vrd.sni_len / 2; + + struct pkt_buff *pktb = pktb_alloc( + family, + packet.payload, + packet.payload_len, + 0); if (tcp4_frag(pktb, mid_offset, &frag1, &frag2) < 0) { perror("tcp4_frag"); + pktb_free(pktb); goto fallback; } @@ -642,8 +631,15 @@ static int process_packet(const struct packet_data packet) { pktb_free(frag1); pktb_free(frag2); + pktb_free(pktb); #else + // TODO: Implement compute of tcp checksum + // GSO may turn kernel to not compute the tcp checksum. + // Also it will never be meaningless to ensure the + // checksum is right. + // nfq_tcp_compute_checksum_ipv4(tcph, ip_header); + size_t ipd_offset = ((char *)data - (char *)tcph) + vrd.sni_offset; size_t mid_offset = ipd_offset + vrd.sni_len / 2; mid_offset += 8 - mid_offset % 8; @@ -666,7 +662,7 @@ static int process_packet(const struct packet_data packet) { } - +/* if (pktb_mangled(pktb)) { #ifdef DEBUG printf("Mangled!\n"); @@ -674,6 +670,7 @@ static int process_packet(const struct packet_data packet) { nfq_nlmsg_verdict_put_pkt( verdnlh, pktb_data(pktb), pktb_len(pktb)); } +*/ if (mnl_socket_sendto(config.nl, verdnlh, verdnlh->nlmsg_len) < 0) { perror("mnl_socket_send"); @@ -681,14 +678,11 @@ static int process_packet(const struct packet_data packet) { goto error; } - pktb_free(pktb); return MNL_CB_OK; fallback: - pktb_free(pktb); return fallback_accept_packet(packet.id); error: - pktb_free(pktb); return MNL_CB_ERROR; } @@ -708,8 +702,6 @@ static int queue_cb(const struct nlmsghdr *nlh, void *data) { return MNL_CB_ERROR; } - - ph = mnl_attr_get_payload(attr[NFQA_PACKET_HDR]); packet.id = ntohl(ph->packet_id); @@ -723,9 +715,6 @@ static int queue_cb(const struct nlmsghdr *nlh, void *data) { return fallback_accept_packet(packet.id); } - uint32_t skbinfo = attr[NFQA_SKB_INFO] - ? ntohl(mnl_attr_get_u32(attr[NFQA_SKB_INFO])) : 0; - if (attr[NFQA_MARK] != NULL) { // Skip packets sent by rawsocket to escape infinity loop. if (ntohl(mnl_attr_get_u32(attr[NFQA_MARK])) ==