From 2673e5f9095b2563828c1fbcbf8ba965949b7ae0 Mon Sep 17 00:00:00 2001 From: ruti <> Date: Thu, 26 Dec 2024 22:54:59 +0300 Subject: [PATCH] --ipset, optimize --hosts --- extend.c | 62 +++++++++++++++++++--------- main.c | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++----- mpool.c | 116 +++++++++++++++++++++++++++++++++++---------------- mpool.h | 25 +++++++---- params.h | 1 + 5 files changed, 253 insertions(+), 74 deletions(-) diff --git a/extend.c b/extend.c index e255d6f..e8de005 100644 --- a/extend.c +++ b/extend.c @@ -77,7 +77,7 @@ static int cache_get(const struct sockaddr_ina *dst) uint8_t key[KEY_SIZE] = { 0 }; int len = serialize_addr(dst, key, sizeof(key)); - struct elem *val = mem_get(params.mempool, (char *)key, len); + struct elem_i *val = (struct elem_i *)mem_get(params.mempool, (char *)key, len); if (!val) { return -1; } @@ -105,10 +105,17 @@ static int cache_add(const struct sockaddr_ina *dst, int m) } LOG(LOG_S, "save ip: %s, m=%d\n", ADDR_STR, m); time_t t = time(0); - - struct elem *val = mem_add(params.mempool, (char *)key, len); + + char *key_d = malloc(len); + if (!key_d) { + return -1; + } + memcpy(key_d, key, len); + + struct elem_i *val = (struct elem_i *)mem_add(params.mempool, key_d, len, sizeof(struct elem_i)); if (!val) { uniperror("mem_add"); + free(key_d); return -1; } val->m = m; @@ -116,6 +123,7 @@ static int cache_add(const struct sockaddr_ina *dst, int m) return 0; } +static bool check_l34(struct desync_params *dp, int st, const struct sockaddr_in6 *dst); int connect_hook(struct poolhd *pool, struct eval *val, const struct sockaddr_ina *dst, int next) @@ -176,19 +184,30 @@ static bool check_host( if (len <= 0) { return 0; } - char *e = host + len; - for (; host < e; host++) { - if (mem_get(hosts, host, e - host)) { - return 1; - } - if (!(host = memchr(host, '.', e - host))) { - return 0; - } + struct elem *v = mem_get(hosts, host, len); + return v && v->len <= len; +} + + +static bool check_ip( + struct mphdr *ipset, const struct sockaddr_in6 *addr) +{ + const struct sockaddr_ina *dst = (const struct sockaddr_ina *)addr; + + int len = sizeof(dst->in.sin_addr); + char *data = (char *)&dst->in.sin_addr; + + if (dst->sa.sa_family == AF_INET6) { + len = sizeof(dst->in6.sin6_addr); + data = (char *)&dst->in6.sin6_addr; + } + if (mem_get(ipset, data, len * 8)) { + return 1; } return 0; } - + static bool check_proto_tcp(int proto, const char *buffer, ssize_t n) { if (!(proto & ~IS_IPV4)) { @@ -206,12 +225,12 @@ static bool check_proto_tcp(int proto, const char *buffer, ssize_t n) } -static bool check_l34(int proto, const uint16_t *pf, int st, const struct sockaddr_in6 *dst) +static bool check_l34(struct desync_params *dp, int st, const struct sockaddr_in6 *dst) { - if ((proto & IS_UDP) && (st != SOCK_DGRAM)) { + if ((dp->proto & IS_UDP) && (st != SOCK_DGRAM)) { return 0; } - if (proto & IS_IPV4) { + if (dp->proto & IS_IPV4) { static const char *pat = "\0\0\0\0\0\0\0\0\0\0\xff\xff"; if (dst->sin6_family != AF_INET @@ -219,8 +238,11 @@ static bool check_l34(int proto, const uint16_t *pf, int st, const struct sockad return 0; } } - if (pf[0] && - (dst->sin6_port < pf[0] || dst->sin6_port > pf[1])) { + if (dp->pf[0] && + (dst->sin6_port < dp->pf[0] || dst->sin6_port > dp->pf[1])) { + return 0; + } + if (dp->ipset && !check_ip(dp->ipset, dst)) { return 0; } return 1; @@ -340,8 +362,8 @@ static int setup_conn(struct eval *client, const char *buffer, ssize_t n) if (!m) for (; m < params.dp_count; m++) { struct desync_params *dp = ¶ms.dp[m]; if (!dp->detect - && (check_l34(dp->proto, dp->pf, SOCK_STREAM, &client->pair->in6) - && check_proto_tcp(dp->proto, buffer, n)) + && check_l34(dp, SOCK_STREAM, &client->pair->in6) + && check_proto_tcp(dp->proto, buffer, n) && (!dp->hosts || check_host(dp->hosts, buffer, n))) { break; } @@ -544,7 +566,7 @@ ssize_t udp_hook(struct eval *val, for (; m < params.dp_count; m++) { struct desync_params *dp = ¶ms.dp[m]; if (!dp->detect - && check_l34(dp->proto, dp->pf, SOCK_DGRAM, &dst->in6)) { + && check_l34(dp, SOCK_DGRAM, &dst->in6)) { break; } } diff --git a/main.c b/main.c index ec54857..1822fef 100644 --- a/main.c +++ b/main.c @@ -90,6 +90,7 @@ const static char help_text[] = { #endif " -K, --proto Protocol whitelist: tls,http,udp,ipv4\n" " -H, --hosts Hosts whitelist, filename or :string\n" + " -j, --ipset IP whitelist\n" " -V, --pf Ports range whitelist\n" " -R, --round Number of request to which desync will be applied\n" " -s, --split Position format: offset[:repeats:skip][+flag1[flag2]]\n" @@ -177,6 +178,7 @@ const struct option options[] = { {"drop-sack", 0, 0, 'Y'}, {"protect-path", 1, 0, 'P'}, // #endif + {"ipset", 1, 0, 'j'}, {0} }; @@ -297,7 +299,71 @@ static inline int lower_char(char *cl) struct mphdr *parse_hosts(char *buffer, size_t size) { - struct mphdr *hdr = mem_pool(1); + struct mphdr *hdr = mem_pool(1, CMP_HOST); + if (!hdr) { + return 0; + } + size_t num = 0; + bool drop = 0; + char *end = buffer + size; + char *e = buffer, *s = buffer; + + for (; e <= end; e++) { + if (e != end && *e != ' ' && *e != '\n' && *e != '\r') { + if (lower_char(e)) { + drop = 1; + } + continue; + } + if (s == e) { + s++; + continue; + } + if (!drop) { + if (!mem_add(hdr, s, e - s, sizeof(struct elem))) { + mem_destroy(hdr); + return 0; + } + } + else { + LOG(LOG_E, "invalid host: num: %zd \"%.*s\"\n", num + 1, (int )(e - s), s); + drop = 0; + } + num++; + s = e + 1; + } + LOG(LOG_S, "hosts count: %zd\n", hdr->count); + return hdr; +} + + +static int parse_ip(char *out, char *str, size_t size) +{ + long bits = 0; + char *sep = memchr(str, '/', size); + if (sep) { + bits = strtol(sep + 1, 0, 10); + if (bits <= 0) { + return 0; + } + *sep = 0; + } + int len = 4; + + if (inet_pton(AF_INET, str, out) <= 0) { + if (inet_pton(AF_INET6, str, out) <= 0) { + return 0; + } + else len = 16; + } + if (!bits || bits > len * 8) bits = len * 8; + return (int )bits; +} + + +struct mphdr *parse_ipset(char *buffer, size_t size) +{ + struct mphdr *hdr = mem_pool(0, CMP_BITS); if (!hdr) { return 0; } @@ -307,22 +373,34 @@ struct mphdr *parse_hosts(char *buffer, size_t size) for (; e <= end; e++) { if (e != end && *e != ' ' && *e != '\n' && *e != '\r') { - if (lower_char(e)) { - LOG(LOG_E, "invalid host: num: %zd (%.*s)\n", num + 1, (int )(e - s + 1), s); - } continue; } if (s == e) { s++; continue; } - if (mem_add(hdr, s, e - s) == 0) { - free(hdr); - return 0; - } + char ip[e - s + 1]; + ip[e - s] = 0; + memcpy(ip, s, e - s); + num++; s = e + 1; + + char *ip_raw = malloc(16); + int bits = parse_ip(ip_raw, ip, sizeof(ip)); + if (bits <= 0) { + LOG(LOG_E, "invalid ip: num: %zd\n", num); + free(ip_raw); + continue; + } + struct elem *elem = mem_add(hdr, ip_raw, bits, sizeof(struct elem)); + if (!elem) { + free(ip_raw); + mem_destroy(hdr); + return 0; + } } + LOG(LOG_S, "ip count: %zd\n", hdr->count); return hdr; } @@ -514,6 +592,10 @@ void clear_params(void) mem_destroy(s.hosts); s.hosts = 0; } + if (s.ipset != 0) { + mem_destroy(s.ipset); + s.hosts = 0; + } } free(params.dp); params.dp = 0; @@ -666,7 +748,7 @@ int main(int argc, char **argv) break; case 'A': - if (!(dp->hosts || dp->proto || dp->pf[0] || dp->detect)) { + if (!(dp->hosts || dp->proto || dp->pf[0] || dp->detect || dp->ipset)) { all_limited = 0; } dp = add((void *)¶ms.dp, ¶ms.dp_count, @@ -760,12 +842,31 @@ int main(int argc, char **argv) } dp->hosts = parse_hosts(dp->file_ptr, dp->file_size); if (!dp->hosts) { - perror("parse_hosts"); + uniperror("parse_hosts"); clear_params(); return -1; } break; + case 'j':; + if (dp->ipset) { + continue; + } + ssize_t size; + char *data = ftob(optarg, &size); + if (!data) { + uniperror("read/parse"); + invalid = 1; + continue; + } + dp->ipset = parse_ipset(data, size); + if (!dp->ipset) { + uniperror("parse_ipset"); + invalid = 1; + } + free(data); + break; + case 's': case 'd': case 'o': @@ -997,7 +1098,7 @@ int main(int argc, char **argv) return -1; } } - params.mempool = mem_pool(0); + params.mempool = mem_pool(0, CMP_BYTES); if (!params.mempool) { uniperror("mem_pool"); clear_params(); diff --git a/mpool.c b/mpool.c index 50fe43e..01e1af4 100644 --- a/mpool.c +++ b/mpool.c @@ -1,86 +1,130 @@ #include "mpool.h" +#include #include #include +#include - -static inline int scmp(const struct elem *p, const struct elem *q) +static int bit_cmp(const struct elem *p, const struct elem *q) { - if (p->len != q ->len) { + int len = q->len < p->len ? q->len : p->len; + int df = len % 8, bytes = len / 8; + int cmp = memcmp(p->data, q->data, bytes); + + if (cmp || !df) { + return cmp; + } + uint8_t c1 = p->data[bytes] >> (8 - df); + uint8_t c2 = q->data[bytes] >> (8 - df); + if (c1 != c2) { + if (c1 < c2) return -1; + else return 1; + } + return 0; +} + +static int byte_cmp(const struct elem *p, const struct elem *q) +{ + if (p->len != q->len) { return p->len < q->len ? -1 : 1; } return memcmp(p->data, q->data, p->len); } + +static int host_cmp(const struct elem *p, const struct elem *q) +{ + int len = q->len < p->len ? q->len : p->len; + char *pd = p->data + p->len, *qd = q->data + q->len; + while (len-- > 0) { + if (*--pd != *--qd) { + return *pd < *qd ? -1 : 1; + } + } + if (p->len == q->len + || (p->len > q->len ? pd[-1] : qd[-1]) == '.') + return 0; + + return p->len > q->len ? 1 : -1; +} + +static int scmp(const struct elem *p, const struct elem *q) +{ + switch (p->cmp_type) { + case CMP_BITS: + return bit_cmp(p, q); + case CMP_HOST: + return host_cmp(p, q); + default: + return byte_cmp(p, q); + } +} + KAVL_INIT(my, struct elem, head, scmp) -struct mphdr *mem_pool(bool cst) +struct mphdr *mem_pool(bool is_static, unsigned char cmp_type) { struct mphdr *hdr = calloc(sizeof(struct mphdr), 1); if (hdr) { - hdr->stat = cst; + hdr->static_data = is_static; + hdr->cmp_type = cmp_type; } return hdr; } -struct elem *mem_get(struct mphdr *hdr, char *str, int len) +struct elem *mem_get(const struct mphdr *hdr, const char *str, int len) { - struct { - int len; - char *data; - } temp = { .len = len, .data = str }; - - return kavl_find(my, hdr->root, (struct elem *)&temp, 0); + struct elem temp = { + .cmp_type = hdr->cmp_type, + .len = len, .data = (char *)str + }; + return kavl_find(my, hdr->root, &temp, 0); } -struct elem *mem_add(struct mphdr *hdr, char *str, int len) +struct elem *mem_add(struct mphdr *hdr, char *str, int len, size_t struct_size) { - struct elem *v, *e = calloc(sizeof(struct elem), 1); + struct elem *v, *e = calloc(struct_size, 1); if (!e) { return 0; } e->len = len; - if (!hdr->stat) { - e->data = malloc(len); - if (!e->data) { - free(e); - return 0; - } - memcpy(e->data, str, len); - } - else { - e->data = str; - } + e->cmp_type = hdr->cmp_type; + e->data = str; + v = kavl_insert(my, &hdr->root, e, 0); + while (e != v && e->len < v->len) { + mem_delete(hdr, v->data, v->len); + v = kavl_insert(my, &hdr->root, e, 0); + } if (e != v) { - if (!hdr->stat) { + if (!hdr->static_data) free(e->data); - } free(e); } + else hdr->count++; return v; } -void mem_delete(struct mphdr *hdr, char *str, int len) +void mem_delete(struct mphdr *hdr, const char *str, int len) { - struct { - int len; - char *data; - } temp = { .len = len, .data = str }; - - struct elem *e = kavl_erase(my, &hdr->root, (struct elem *)&temp, 0); + struct elem temp = { + .cmp_type = hdr->cmp_type, + .len = len, .data = (char *)str + }; + struct elem *e = kavl_erase(my, &hdr->root, &temp, 0); if (!e) { return; } - if (!hdr->stat) { + if (!hdr->static_data) { free(e->data); e->data = 0; } free(e); + hdr->count--; } @@ -91,7 +135,7 @@ void mem_destroy(struct mphdr *hdr) if (!e) { break; } - if (!hdr->stat && e->data) { + if (!hdr->static_data) { free(e->data); } e->data = 0; diff --git a/mpool.h b/mpool.h index c209dc0..440219e 100644 --- a/mpool.h +++ b/mpool.h @@ -5,26 +5,37 @@ #include #include "kavl.h" +#define CMP_BYTES 0 +#define CMP_BITS 1 +#define CMP_HOST 2 + struct elem { int len; char *data; - int m; - time_t time; + unsigned char cmp_type; KAVL_HEAD(struct elem) head; }; +struct elem_i { + struct elem i; + int m; + time_t time; +}; + struct mphdr { - bool stat; + bool static_data; + unsigned char cmp_type; + size_t count; struct elem *root; }; -struct mphdr *mem_pool(bool cst); +struct mphdr *mem_pool(bool is_static, unsigned char cmp_type); -struct elem *mem_get(struct mphdr *hdr, char *str, int len); +struct elem *mem_get(const struct mphdr *hdr, const char *str, int len); -struct elem *mem_add(struct mphdr *hdr, char *str, int len); +struct elem *mem_add(struct mphdr *hdr, char *str, int len, size_t ssize); -void mem_delete(struct mphdr *hdr, char *str, int len); +void mem_delete(struct mphdr *hdr, const char *str, int len); void mem_destroy(struct mphdr *hdr); diff --git a/params.h b/params.h index 902b384..a43d336 100644 --- a/params.h +++ b/params.h @@ -88,6 +88,7 @@ struct desync_params { int proto; int detect; struct mphdr *hosts; + struct mphdr *ipset; uint16_t pf[2]; int rounds[2];