core/socket: rework SocketPeer refcounting

Make functions and definitions that don't need to be shared local to
socket.c.
This commit is contained in:
Zbigniew Jędrzejewski-Szmek
2016-08-04 21:42:23 -04:00
parent 9a73653c3e
commit 166cf510c2
2 changed files with 100 additions and 101 deletions

View File

@@ -59,6 +59,13 @@
#include "user-util.h"
#include "in-addr-util.h"
struct SocketPeer {
unsigned n_ref;
Socket *socket;
union sockaddr_union peer;
};
static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
[SOCKET_DEAD] = UNIT_INACTIVE,
[SOCKET_START_PRE] = UNIT_ACTIVATING,
@@ -78,9 +85,6 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata);
static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata);
SocketPeer *socket_peer_new(void);
int socket_find_peer(Socket *s, int fd, SocketPeer **p);
static void socket_init(Unit *u) {
Socket *s = SOCKET(u);
@@ -482,10 +486,11 @@ static void peer_address_hash_func(const void *p, struct siphash *state) {
const SocketPeer *s = p;
assert(s);
assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
if (s->peer.sa.sa_family == AF_INET)
siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state);
else if (s->peer.sa.sa_family == AF_INET6)
else
siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state);
}
@@ -503,8 +508,7 @@ static int peer_address_compare_func(const void *a, const void *b) {
case AF_INET6:
return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr));
}
return -1;
assert_not_reached("Black sheep in the family!");
}
const struct hash_ops peer_address_hash_ops = {
@@ -537,6 +541,87 @@ static int socket_load(Unit *u) {
return socket_verify(s);
}
static SocketPeer *socket_peer_new(void) {
SocketPeer *p;
p = new0(SocketPeer, 1);
if (!p)
return NULL;
p->n_ref = 1;
return p;
}
SocketPeer *socket_peer_ref(SocketPeer *p) {
if (!p)
return NULL;
assert(p->n_ref > 0);
p->n_ref++;
return p;
}
SocketPeer *socket_peer_unref(SocketPeer *p) {
if (!p)
return NULL;
assert(p->n_ref > 0);
p->n_ref--;
if (p->n_ref > 0)
return NULL;
if (p->socket)
set_remove(p->socket->peers_by_address, p);
return mfree(p);
}
static int socket_acquire_peer(Socket *s, int fd, SocketPeer **p) {
_cleanup_(socket_peer_unrefp) SocketPeer *remote = NULL;
SocketPeer sa = {}, *i;
socklen_t salen = sizeof(sa.peer);
int r;
assert(fd >= 0);
assert(s);
r = getpeername(fd, &sa.peer.sa, &salen);
if (r < 0)
return log_error_errno(errno, "getpeername failed: %m");
if (!IN_SET(sa.peer.sa.sa_family, AF_INET, AF_INET6)) {
*p = NULL;
return 0;
}
i = set_get(s->peers_by_address, &sa);
if (i) {
*p = socket_peer_ref(i);
return 1;
}
remote = socket_peer_new();
if (!remote)
return log_oom();
remote->peer = sa.peer;
r = set_put(s->peers_by_address, remote);
if (r < 0)
return r;
remote->socket = s;
*p = remote;
remote = NULL;
return 1;
}
_const_ static const char* listen_lookup(int family, int type) {
if (family == AF_NETLINK)
@@ -2102,22 +2187,22 @@ static void socket_enter_running(Socket *s, int cfd) {
Service *service;
if (s->n_connections >= s->max_connections) {
log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.", s->n_connections);
log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.",
s->n_connections);
safe_close(cfd);
return;
}
if (s->max_connections_per_source > 0) {
r = socket_find_peer(s, cfd, &p);
r = socket_acquire_peer(s, cfd, &p);
if (r < 0) {
safe_close(cfd);
return;
}
if (p->n_ref > s->max_connections_per_source) {
log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref);
} else if (r > 0 && p->n_ref > s->max_connections_per_source) {
log_unit_warning(UNIT(s),
"Too many incoming connections (%u) from source, refusing connection attempt.",
p->n_ref);
safe_close(cfd);
p = NULL;
return;
}
}
@@ -2163,10 +2248,8 @@ static void socket_enter_running(Socket *s, int cfd) {
cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */
s->n_connections++;
if (s->max_connections_per_source > 0) {
service->peer = socket_peer_ref(p);
p = NULL;
}
service->peer = p; /* Pass ownership of the peer reference */
p = NULL;
r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL);
if (r < 0) {
@@ -2662,83 +2745,6 @@ _pure_ static bool socket_check_gc(Unit *u) {
return s->n_connections > 0;
}
SocketPeer *socket_peer_new(void) {
SocketPeer *p;
p = new0(SocketPeer, 1);
if (!p)
return NULL;
p->n_ref = 1;
return p;
}
SocketPeer *socket_peer_ref(SocketPeer *p) {
if (!p)
return NULL;
assert(p->n_ref > 0);
p->n_ref++;
return p;
}
SocketPeer *socket_peer_unref(SocketPeer *p) {
if (!p)
return NULL;
assert(p->n_ref > 0);
p->n_ref--;
if (p->n_ref > 0)
return NULL;
if (p->socket)
set_remove(p->socket->peers_by_address, p);
free(p);
return NULL;
}
int socket_find_peer(Socket *s, int fd, SocketPeer **p) {
_cleanup_free_ SocketPeer *remote = NULL;
SocketPeer sa, *i;
socklen_t salen = sizeof(sa.peer);
int r;
assert(fd >= 0);
assert(s);
r = getpeername(fd, &sa.peer.sa, &salen);
if (r < 0)
return log_error_errno(errno, "getpeername failed: %m");
i = set_get(s->peers_by_address, &sa);
if (i) {
*p = i;
return 1;
}
remote = socket_peer_new();
if (!remote)
return log_oom();
memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union));
remote->socket = s;
r = set_put(s->peers_by_address, remote);
if (r < 0)
return r;
*p = remote;
remote = NULL;
return 0;
}
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
SocketPort *p = userdata;
int cfd = -1;

View File

@@ -168,13 +168,6 @@ struct Socket {
RateLimit trigger_limit;
};
struct SocketPeer {
unsigned n_ref;
Socket *socket;
union sockaddr_union peer;
};
SocketPeer *socket_peer_ref(SocketPeer *p);
SocketPeer *socket_peer_unref(SocketPeer *p);