Refactoring on_request func

This commit is contained in:
ruti 2023-10-28 16:15:54 +02:00
parent 93fdc80191
commit 828f118afe

118
proxy.c
View File

@ -145,7 +145,17 @@ int auth_socks5(int fd, char *buffer, ssize_t n)
} }
int resp_error(int fd, int e, int flag, int re) int resp_s5_error(int fd, int e)
{
struct s5_rep s5r = {
.ver = 0x05, .code = (uint8_t )e,
.atp = S_ATP_I4
};
return send(fd, (char *)&s5r, sizeof(s5r), 0);
}
int resp_error(int fd, int e, int flag)
{ {
if (flag == FLAG_S4) { if (flag == FLAG_S4) {
struct s4_req s4r = { struct s4_req s4r = {
@ -154,76 +164,65 @@ int resp_error(int fd, int e, int flag, int re)
return send(fd, (char *)&s4r, sizeof(s4r), 0); return send(fd, (char *)&s4r, sizeof(s4r), 0);
} }
else if (flag == FLAG_S5) { else if (flag == FLAG_S5) {
uint8_t se; switch (e) {
if (re) se = (uint8_t )re; case 0: e = S_ER_OK;
else switch (e) {
case 0: se = S_ER_OK;
break; break;
case ECONNREFUSED: case ECONNREFUSED:
se = S_ER_CONN; e = S_ER_CONN;
break; break;
case EHOSTUNREACH: case EHOSTUNREACH:
case ETIMEDOUT: case ETIMEDOUT:
se = S_ER_HOST; e = S_ER_HOST;
break; break;
case ENETUNREACH: case ENETUNREACH:
se = S_ER_NET; e = S_ER_NET;
break; break;
default: se = S_ER_GEN; default: e = S_ER_GEN;
} }
struct s5_rep s5r = { return resp_s5_error(fd, e);
.ver = 0x05, .code = se,
.atp = S_ATP_I4
};
return send(fd, (char *)&s5r, sizeof(s5r), 0);
} }
return 0; return 0;
} }
int handle_socks4(int fd, char *bf, int s4_get_addr(int fd, char *bf,
size_t n, struct sockaddr_ina *dst) size_t n, struct sockaddr_ina *dst)
{ {
if (n < sizeof(struct s4_req) + 1) { if (n < sizeof(struct s4_req) + 1) {
return -1; return -1;
} }
struct s4_req *r = (struct s4_req *)bf; struct s4_req *r = (struct s4_req *)bf;
char er = 0;
if (r->cmd != S_CMD_CONN) { if (r->cmd != S_CMD_CONN) {
er = 1; return -1;
} }
else if (ntohl(r->i4.s_addr) <= 255) do { if (ntohl(r->i4.s_addr) <= 255) {
er = 1; if (!params.resolve || bf[n - 1] != 0) {
if (!params.resolve || bf[n - 1]) return -1;
break;
char *ie = strchr(bf + sizeof(*r), 0);
if (!ie)
break;
int len = (bf + n - ie) - 2;
if (len < 3)
break;
if (resolve(ie + 1, len, dst)) {
fprintf(stderr, "not resolved: %.*s\n", len, ie + 1);
break;
} }
er = 0; char *id_end = strchr(bf + sizeof(*r), 0);
} while (0); if (!id_end) {
return -1;
}
int len = (bf + n - id_end) - 2;
if (len < 3 || len > 255) {
return -1;
}
if (resolve(id_end + 1, len, dst)) {
fprintf(stderr, "not resolved: %.*s\n", len, id_end + 1);
return -1;
}
}
else { else {
dst->in.sin_family = AF_INET; dst->in.sin_family = AF_INET;
dst->in.sin_addr = r->i4; dst->in.sin_addr = r->i4;
} }
if (er) {
if (resp_error(fd, 1, FLAG_S4, 0) < 0)
perror("send");
return -1;
}
dst->in.sin_port = r->port; dst->in.sin_port = r->port;
return 0; return 0;
} }
int s_get_addr(char *buffer, ssize_t n, int s5_get_addr(char *buffer, ssize_t n,
struct sockaddr_ina *addr) struct sockaddr_ina *addr)
{ {
struct s5_req *r = (struct s5_req *)buffer; struct s5_req *r = (struct s5_req *)buffer;
@ -235,6 +234,10 @@ int s_get_addr(char *buffer, ssize_t n,
fprintf(stderr, "ss: bad request\n"); fprintf(stderr, "ss: bad request\n");
return S_ER_GEN; return S_ER_GEN;
} }
if (r->cmd != S_CMD_CONN) {
fprintf(stderr, "ss: unsupported cmd: 0x%x\n", r->cmd);
return S_ER_CMD;
}
switch (r->atp) { switch (r->atp) {
case S_ATP_I4: case S_ATP_I4:
addr->in.sin_family = AF_INET; addr->in.sin_family = AF_INET;
@ -335,7 +338,6 @@ static inline int on_request(struct poolhd *pool, struct eval *val,
char *buffer, size_t bfsize) char *buffer, size_t bfsize)
{ {
struct sockaddr_ina dst = {0}; struct sockaddr_ina dst = {0};
int error = 0, s5e = 0;
ssize_t n = recv(val->fd, buffer, bfsize, 0); ssize_t n = recv(val->fd, buffer, bfsize, 0);
if (n < 1) { if (n < 1) {
@ -354,35 +356,33 @@ static inline int on_request(struct poolhd *pool, struct eval *val,
fprintf(stderr, "ss: request to small\n"); fprintf(stderr, "ss: request to small\n");
return -1; return -1;
} }
struct s5_req *r = (struct s5_req *)buffer; int s5e = s5_get_addr(buffer, n, &dst);
if (!s5e &&
if (r->cmd != S_CMD_CONN) { create_conn(pool, val, &dst)) {
fprintf(stderr, "ss: unsupported cmd: 0x%x\n", r->cmd); s5e = S_ER_GEN;
s5e = S_ER_CMD; }
} if (s5e) {
else { resp_s5_error(val->fd, s5e);
s5e = s_get_addr(buffer, n, &dst); return -1;
if (!s5e) {
error = create_conn(pool, val, &dst);
}
} }
} }
else if (*buffer == S_VER4) { else if (*buffer == S_VER4) {
if (handle_socks4(val->fd, buffer, n, &dst)) { val->flag = FLAG_S4;
int error = s4_get_addr(val->fd, buffer, n, &dst);
if (!error) {
error = create_conn(pool, val, &dst);
}
if (error) {
if (resp_error(val->fd, error, FLAG_S4) < 0)
perror("send");
return -1; return -1;
} }
error = create_conn(pool, val, &dst);
val->flag = FLAG_S4;
} }
else { else {
fprintf(stderr, "ss: invalid version: 0x%x (%lu)\n", *buffer, n); fprintf(stderr, "ss: invalid version: 0x%x (%lu)\n", *buffer, n);
return -1; return -1;
} }
if (error || s5e) {
if (resp_error(val->fd, error ? errno : 0, val->flag, s5e) < 0)
perror("send");
return -1;
}
val->type = EV_IGNORE; val->type = EV_IGNORE;
return 0; return 0;
} }
@ -444,7 +444,7 @@ static inline int on_connect(struct poolhd *pool, struct eval *val,
} }
} }
if (resp_error(val->pair->fd, if (resp_error(val->pair->fd,
error, val->pair->flag, 0) < 0) { error, val->pair->flag) < 0) {
perror("send"); perror("send");
return -1; return -1;
} }