auto-mode option (#127)

* trigger auto if connetion closes/resets after several packets

* send RST if remote connection is reset

* change cache value struct
This commit is contained in:
ruti 2024-09-14 15:47:57 +03:00
parent 193737173b
commit e02ce91363
6 changed files with 147 additions and 54 deletions

View File

@ -78,8 +78,11 @@ struct eval {
struct sockaddr_in6 in6; struct sockaddr_in6 in6;
}; };
ssize_t recv_count; ssize_t recv_count;
unsigned int round_count;
char last_round;
int attempt; int attempt;
char cache; char cache;
char mark; //
}; };
struct poolhd { struct poolhd {

142
extend.c
View File

@ -49,37 +49,29 @@ int mode_add_get(struct sockaddr_ina *dst, int m)
{ {
// m < 0: get, m > 0: set, m == 0: delete // m < 0: get, m > 0: set, m == 0: delete
assert(m >= -1 && m < params.dp_count); assert(m >= -1 && m < params.dp_count);
struct {
uint16_t port;
union {
struct in_addr i4;
struct in6_addr i6;
};
} key = { .port = dst->in.sin_port };
time_t t = 0; time_t t = 0;
struct elem *val = 0; struct elem *val = 0;
char *str = (char *)&dst->in; int len = sizeof(dst->in.sin_port);
int len = 0;
if (dst->sa.sa_family == AF_INET) { if (dst->sa.sa_family == AF_INET) {
len = sizeof(dst->in); len += sizeof(dst->in.sin_addr);
key.i4 = dst->in.sin_addr;
} }
else { else {
len = sizeof(dst->in6) - sizeof(dst->in6.sin6_scope_id); len += sizeof(dst->in6.sin6_addr);
key.i6 = dst->in6.sin6_addr;
} }
len -= sizeof(dst->sa.sa_family);
assert(len > 0);
if (m == 0) { if (m < 0) {
mem_delete(params.mempool, str, len); val = mem_get(params.mempool, (char *)&key, len);
return 0;
}
else if (m > 0) {
time(&t);
val = mem_add(params.mempool, str, len);
if (!val) {
uniperror("mem_add");
return -1;
}
val->m = m;
val->time = t;
return 0;
}
val = mem_get(params.mempool, str, len);
if (!val) { if (!val) {
return -1; return -1;
} }
@ -89,6 +81,27 @@ int mode_add_get(struct sockaddr_ina *dst, int m)
return 0; return 0;
} }
return val->m; return val->m;
}
INIT_ADDR_STR((*dst));
if (m == 0) {
LOG(LOG_S, "delete ip: %s\n", ADDR_STR);
mem_delete(params.mempool, (char *)&key, len);
return 0;
}
else {
LOG(LOG_S, "save ip: %s, m=%d\n", ADDR_STR, m);
time(&t);
val = mem_add(params.mempool, (char *)&key, len);
if (!val) {
uniperror("mem_add");
return -1;
}
val->m = m;
val->time = t;
return 0;
}
} }
@ -188,6 +201,10 @@ int on_torst(struct poolhd *pool, struct eval *val)
{ {
int m = val->pair->attempt + 1; int m = val->pair->attempt + 1;
bool can_reconn = (
val->pair->buff.data && !val->recv_count
);
if (can_reconn || params.auto_level >= 1) {
for (; m < params.dp_count; m++) { for (; m < params.dp_count; m++) {
struct desync_params *dp = &params.dp[m]; struct desync_params *dp = &params.dp[m];
if (!dp->detect) { if (!dp->detect) {
@ -198,11 +215,22 @@ int on_torst(struct poolhd *pool, struct eval *val)
} }
} }
if (m >= params.dp_count) { if (m >= params.dp_count) {
mode_add_get( if (m > 1) mode_add_get(
(struct sockaddr_ina *)&val->in6, 0); (struct sockaddr_ina *)&val->in6, 0);
}
else if (can_reconn)
return reconnect(pool, val, m);
else
mode_add_get(
(struct sockaddr_ina *)&val->in6, m);
}
struct linger l = { .l_onoff = 1 };
if (setsockopt(val->pair->fd, SOL_SOCKET,
SO_LINGER, (char *)&l, sizeof(l)) < 0) {
uniperror("setsockopt SO_LINGER");
return -1; return -1;
} }
return reconnect(pool, val, m); return -1;
} }
@ -210,21 +238,44 @@ int on_fin(struct poolhd *pool, struct eval *val)
{ {
int m = val->pair->attempt + 1; int m = val->pair->attempt + 1;
bool can_reconn = (
val->pair->buff.data && !val->recv_count
);
if (!can_reconn && params.auto_level < 1) {
return -1;
}
bool ssl_err = 0;
if (can_reconn) {
char *req = val->pair->buff.data;
ssize_t qn = val->pair->buff.size;
ssl_err = is_tls_chello(req, qn);
}
else if (val->mark && val->round_count <= 1) {
ssl_err = 1;
}
if (!ssl_err) {
return -1;
}
for (; m < params.dp_count; m++) { for (; m < params.dp_count; m++) {
struct desync_params *dp = &params.dp[m]; struct desync_params *dp = &params.dp[m];
if (!dp->detect) { if (!dp->detect) {
return -1; return -1;
} }
if (!(dp->detect & DETECT_TLS_ERR)) { if (dp->detect & DETECT_TLS_ERR) {
continue; if (can_reconn)
}
char *req = val->pair->buff.data;
ssize_t qn = val->pair->buff.size;
if (!is_tls_chello(req, qn)) {
continue;
}
return reconnect(pool, val, m); return reconnect(pool, val, m);
else {
mode_add_get(
(struct sockaddr_ina *)&val->in6, m);
return -1;
}
}
}
if (m > 1) { // delete
mode_add_get(
(struct sockaddr_ina *)&val->in6, 0);
} }
return -1; return -1;
} }
@ -279,20 +330,25 @@ int on_tunnel_check(struct poolhd *pool, struct eval *val,
assert(!out); assert(!out);
ssize_t n = recv(val->fd, buffer, bfsize, 0); ssize_t n = recv(val->fd, buffer, bfsize, 0);
if (n < 1) { if (n < 1) {
if (n) uniperror("recv"); if (!n) {
return on_fin(pool, val);
}
uniperror("recv");
switch (get_e()) { switch (get_e()) {
case ECONNRESET: case ECONNRESET:
case ECONNREFUSED: case ECONNREFUSED:
case ETIMEDOUT: case ETIMEDOUT:
return on_torst(pool, val); return on_torst(pool, val);
} }
return on_fin(pool, val); return -1;
} }
// //
if (on_response(pool, val, buffer, n) == 0) { if (on_response(pool, val, buffer, n) == 0) {
return 0; return 0;
} }
val->recv_count += n; val->recv_count += n;
val->round_count = 1;
val->last_round = 1;
struct eval *pair = val->pair; struct eval *pair = val->pair;
ssize_t sn = send(pair->fd, buffer, n, 0); ssize_t sn = send(pair->fd, buffer, n, 0);
@ -300,9 +356,12 @@ int on_tunnel_check(struct poolhd *pool, struct eval *val,
uniperror("send"); uniperror("send");
return -1; return -1;
} }
if (params.auto_level > 0 && params.dp_count > 1) {
val->mark = is_tls_chello(pair->buff.data, pair->buff.size);
}
to_tunnel(pair); to_tunnel(pair);
if (params.timeout && if (params.timeout && params.auto_level < 1 &&
set_timeout(val->fd, 0)) { set_timeout(val->fd, 0)) {
return -1; return -1;
} }
@ -315,15 +374,7 @@ int on_tunnel_check(struct poolhd *pool, struct eval *val,
if (!pair->cache) { if (!pair->cache) {
return 0; return 0;
} }
struct sockaddr_ina *addr = (struct sockaddr_ina *)&val->in6; return mode_add_get((struct sockaddr_ina *)&val->in6, m);
if (m == 0) {
LOG(LOG_S, "delete ip: m=%d\n", m);
} else {
INIT_ADDR_STR((*addr));
LOG(LOG_S, "save ip: %s, m=%d\n", ADDR_STR, m);
}
return mode_add_get(addr, m);
} }
@ -385,6 +436,7 @@ int on_desync(struct poolhd *pool, struct eval *val,
} }
val->buff.size += n; val->buff.size += n;
val->recv_count += n; val->recv_count += n;
val->round_count = 1;
val->buff.data = realloc(val->buff.data, val->buff.size); val->buff.data = realloc(val->buff.data, val->buff.size);
if (val->buff.data == 0) { if (val->buff.data == 0) {

View File

@ -19,6 +19,10 @@ int on_desync(struct poolhd *pool, struct eval *val,
ssize_t udp_hook(struct eval *val, ssize_t udp_hook(struct eval *val,
char *buffer, size_t bfsize, ssize_t n, struct sockaddr_ina *dst); char *buffer, size_t bfsize, ssize_t n, struct sockaddr_ina *dst);
int on_torst(struct poolhd *pool, struct eval *val);
int on_fin(struct poolhd *pool, struct eval *val);
#ifdef __linux__ #ifdef __linux__
int protect(int conn_fd, const char *path); int protect(int conn_fd, const char *path);
#else #else

13
main.c
View File

@ -54,7 +54,8 @@ struct params params = {
.laddr = { .laddr = {
.sin6_family = AF_INET .sin6_family = AF_INET
}, },
.debug = 0 .debug = 0,
.auto_level = 0
}; };
@ -77,6 +78,7 @@ const char help_text[] = {
#endif #endif
" -A, --auto <t,r,s,n> Try desync params after this option\n" " -A, --auto <t,r,s,n> Try desync params after this option\n"
" Detect: torst,redirect,ssl_err,none\n" " Detect: torst,redirect,ssl_err,none\n"
" -L, --auto-mode <0|1> 1 - handle trigger after several packets\n"
" -u, --cache-ttl <sec> Lifetime of cached desync params for IP\n" " -u, --cache-ttl <sec> Lifetime of cached desync params for IP\n"
#ifdef TIMEOUT_SUPPORT #ifdef TIMEOUT_SUPPORT
" -T, --timeout <sec> Timeout waiting for response, after which trigger auto\n" " -T, --timeout <sec> Timeout waiting for response, after which trigger auto\n"
@ -131,6 +133,7 @@ const struct option options[] = {
{"tfo ", 0, 0, 'F'}, {"tfo ", 0, 0, 'F'},
#endif #endif
{"auto", 1, 0, 'A'}, {"auto", 1, 0, 'A'},
{"auto-mode", 1, 0, 'L'},
{"cache-ttl", 1, 0, 'u'}, {"cache-ttl", 1, 0, 'u'},
#ifdef TIMEOUT_SUPPORT #ifdef TIMEOUT_SUPPORT
{"timeout", 1, 0, 'T'}, {"timeout", 1, 0, 'T'},
@ -561,6 +564,14 @@ int main(int argc, char **argv)
params.tfo = 1; params.tfo = 1;
break; break;
case 'L':
val = strtol(optarg, &end, 0);
if (val < 0 || val > 1 || *end)
invalid = 1;
else
params.auto_level = val;
break;
case 'A': case 'A':
dp = add((void *)&params.dp, &params.dp_count, dp = add((void *)&params.dp, &params.dp_count,
sizeof(struct desync_params)); sizeof(struct desync_params));

View File

@ -96,6 +96,7 @@ struct params {
char tfo; char tfo;
unsigned int timeout; unsigned int timeout;
int auto_level;
long cache_ttl; long cache_ttl;
char ipv6; char ipv6;
char resolve; char resolve;

28
proxy.c
View File

@ -667,11 +667,30 @@ int on_tunnel(struct poolhd *pool, struct eval *val,
if (n < 0 && get_e() == EAGAIN) { if (n < 0 && get_e() == EAGAIN) {
break; break;
} }
if (n < 1) { if (n == 0) {
if (n) uniperror("recv"); if (val->flag != FLAG_CONN)
val = val->pair;
on_fin(pool, val);
return -1;
}
if (n < 0) {
uniperror("recv");
switch (get_e()) {
case ECONNRESET:
case ETIMEDOUT:
if (val->flag == FLAG_CONN)
on_torst(pool, val);
else
on_fin(pool, val->pair);
}
return -1; return -1;
} }
val->recv_count += n; val->recv_count += n;
if (!val->last_round) {
val->round_count++;
val->last_round = 1;
pair->last_round = 0;
}
ssize_t sn = send(pair->fd, buffer, n, 0); ssize_t sn = send(pair->fd, buffer, n, 0);
if (sn != n) { if (sn != n) {
@ -893,7 +912,10 @@ static inline int on_connect(struct poolhd *pool, struct eval *val, int e)
void close_conn(struct poolhd *pool, struct eval *val) void close_conn(struct poolhd *pool, struct eval *val)
{ {
LOG(LOG_S, "close: fds=%d,%d\n", val->fd, val->pair ? val->pair->fd : -1); LOG(LOG_S, "close: fds=%d,%d, recv: %zd,%zd, rounds: %d,%d\n",
val->fd, val->pair ? val->pair->fd : -1,
val->recv_count, val->pair ? val->pair->recv_count : 0,
val->round_count, val->pair ? val->pair->round_count : 0);
del_event(pool, val); del_event(pool, val);
} }