提交 fd1914b2 编写于 作者: D David S. Miller

Merge branch 'tcp-fast-so_reuseport'

Craig Gallek says:

====================
Faster SO_REUSEPORT for TCP

This patch series complements an earlier series (6a5ef90c)
which added faster SO_REUSEPORT lookup for UDP sockets by
extending the feature to TCP sockets.  It uses the same
array-based data structure which allows for socket selection
after finding the first listening socket that matches an incoming
packet.  Prior to this feature, every socket in the reuseport
group needed to be found and examined before a selection could be
made.

With this series the SO_ATTACH_REUSEPORT_CBPF and
SO_ATTACH_REUSEPORT_EBPF socket options now work for TCP sockets
as well.  The test at the end of the series includes an example of
how to use these options to select a reuseport socket based on the
cpu core id handling the incoming packet.

There are several refactoring patches that precede the feature
implementation.  Only the last two patches in this series
should result in any behavioral changes.

v4
- Fix build issue when compiling IPv6 as a module.  This required
  moving the ipv6_rcv_saddr_equal into an object that is included as a
  built-in object.  I included this change in the second patch which
  adds inet6_hash since that is where ipv6_rcv_saddr_equal will
  later be called from non-module code.

v3:
- Another warning in the first patch caught by a build bot.  Return 0 in
  the no-op UDP hash function.

v2:
- In the first patched I missed a couple of hash functions that should now be
  returning int instead of void.  I missed these the first time through as it
  only generated a warning and not an error :\
====================
Signed-off-by: NDavid S. Miller <davem@davemloft.net>
......@@ -29,9 +29,14 @@ static inline struct tcphdr *tcp_hdr(const struct sk_buff *skb)
return (struct tcphdr *)skb_transport_header(skb);
}
static inline unsigned int __tcp_hdrlen(const struct tcphdr *th)
{
return th->doff * 4;
}
static inline unsigned int tcp_hdrlen(const struct sk_buff *skb)
{
return tcp_hdr(skb)->doff * 4;
return __tcp_hdrlen(tcp_hdr(skb));
}
static inline struct tcphdr *inner_tcp_hdr(const struct sk_buff *skb)
......
......@@ -87,6 +87,8 @@ int __ipv6_get_lladdr(struct inet6_dev *idev, struct in6_addr *addr,
u32 banned_flags);
int ipv6_get_lladdr(struct net_device *dev, struct in6_addr *addr,
u32 banned_flags);
int ipv4_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard);
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard);
void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr);
......
......@@ -53,6 +53,7 @@ struct sock *__inet6_lookup_established(struct net *net,
struct sock *inet6_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr,
const __be16 sport,
const struct in6_addr *daddr,
......@@ -60,6 +61,7 @@ struct sock *inet6_lookup_listener(struct net *net,
static inline struct sock *__inet6_lookup(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr,
const __be16 sport,
const struct in6_addr *daddr,
......@@ -71,12 +73,12 @@ static inline struct sock *__inet6_lookup(struct net *net,
if (sk)
return sk;
return inet6_lookup_listener(net, hashinfo, saddr, sport,
return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
daddr, hnum, dif);
}
static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb,
struct sk_buff *skb, int doff,
const __be16 sport,
const __be16 dport,
int iif)
......@@ -86,16 +88,19 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
if (sk)
return sk;
return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo,
&ipv6_hdr(skb)->saddr, sport,
return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
doff, &ipv6_hdr(skb)->saddr, sport,
&ipv6_hdr(skb)->daddr, ntohs(dport),
iif);
}
struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr, const __be16 sport,
const struct in6_addr *daddr, const __be16 dport,
const int dif);
int inet6_hash(struct sock *sk);
#endif /* IS_ENABLED(CONFIG_IPV6) */
#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif) \
......
......@@ -207,12 +207,16 @@ void inet_hashinfo_init(struct inet_hashinfo *h);
bool inet_ehash_insert(struct sock *sk, struct sock *osk);
bool inet_ehash_nolisten(struct sock *sk, struct sock *osk);
void __inet_hash(struct sock *sk, struct sock *osk);
void inet_hash(struct sock *sk);
int __inet_hash(struct sock *sk, struct sock *osk,
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard));
int inet_hash(struct sock *sk);
void inet_unhash(struct sock *sk);
struct sock *__inet_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const __be32 saddr, const __be16 sport,
const __be32 daddr,
const unsigned short hnum,
......@@ -220,10 +224,11 @@ struct sock *__inet_lookup_listener(struct net *net,
static inline struct sock *inet_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
__be32 saddr, __be16 sport,
__be32 daddr, __be16 dport, int dif)
{
return __inet_lookup_listener(net, hashinfo, saddr, sport,
return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
daddr, ntohs(dport), dif);
}
......@@ -299,6 +304,7 @@ static inline struct sock *
static inline struct sock *__inet_lookup(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const __be32 saddr, const __be16 sport,
const __be32 daddr, const __be16 dport,
const int dif)
......@@ -307,12 +313,13 @@ static inline struct sock *__inet_lookup(struct net *net,
struct sock *sk = __inet_lookup_established(net, hashinfo,
saddr, sport, daddr, hnum, dif);
return sk ? : __inet_lookup_listener(net, hashinfo, saddr, sport,
daddr, hnum, dif);
return sk ? : __inet_lookup_listener(net, hashinfo, skb, doff, saddr,
sport, daddr, hnum, dif);
}
static inline struct sock *inet_lookup(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const __be32 saddr, const __be16 sport,
const __be32 daddr, const __be16 dport,
const int dif)
......@@ -320,7 +327,8 @@ static inline struct sock *inet_lookup(struct net *net,
struct sock *sk;
local_bh_disable();
sk = __inet_lookup(net, hashinfo, saddr, sport, daddr, dport, dif);
sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
dport, dif);
local_bh_enable();
return sk;
......@@ -328,6 +336,7 @@ static inline struct sock *inet_lookup(struct net *net,
static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb,
int doff,
const __be16 sport,
const __be16 dport)
{
......@@ -337,8 +346,8 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
if (sk)
return sk;
else
return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo,
iph->saddr, sport,
return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
doff, iph->saddr, sport,
iph->daddr, dport, inet_iif(skb));
}
......
......@@ -51,7 +51,7 @@ void pn_sock_init(void);
struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *sa);
void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb);
void phonet_get_local_port_range(int *min, int *max);
void pn_sock_hash(struct sock *sk);
int pn_sock_hash(struct sock *sk);
void pn_sock_unhash(struct sock *sk);
int pn_sock_get_port(struct sock *sk, unsigned short sport);
......
......@@ -65,7 +65,7 @@ struct pingfakehdr {
};
int ping_get_port(struct sock *sk, unsigned short ident);
void ping_hash(struct sock *sk);
int ping_hash(struct sock *sk);
void ping_unhash(struct sock *sk);
int ping_init_sock(struct sock *sk);
......
......@@ -57,7 +57,7 @@ int raw_seq_open(struct inode *ino, struct file *file,
#endif
void raw_hash_sk(struct sock *sk);
int raw_hash_sk(struct sock *sk);
void raw_unhash_sk(struct sock *sk);
struct raw_sock {
......
......@@ -984,7 +984,7 @@ struct proto {
void (*release_cb)(struct sock *sk);
/* Keeping track of sk's, looking them up, and port selection methods. */
void (*hash)(struct sock *sk);
int (*hash)(struct sock *sk);
void (*unhash)(struct sock *sk);
void (*rehash)(struct sock *sk);
int (*get_port)(struct sock *sk, unsigned short snum);
......@@ -1194,10 +1194,10 @@ static inline void sock_prot_inuse_add(struct net *net, struct proto *prot,
/* With per-bucket locks this operation is not-atomic, so that
* this version is not worse.
*/
static inline void __sk_prot_rehash(struct sock *sk)
static inline int __sk_prot_rehash(struct sock *sk)
{
sk->sk_prot->unhash(sk);
sk->sk_prot->hash(sk);
return sk->sk_prot->hash(sk);
}
void sk_prot_clear_portaddr_nulls(struct sock *sk, int size);
......
......@@ -177,9 +177,10 @@ static inline struct udphdr *udp_gro_udphdr(struct sk_buff *skb)
}
/* hash routines shared between UDPv4/6 and UDP-Litev4/6 */
static inline void udp_lib_hash(struct sock *sk)
static inline int udp_lib_hash(struct sock *sk)
{
BUG();
return 0;
}
void udp_lib_unhash(struct sock *sk);
......
......@@ -1181,7 +1181,7 @@ static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk)
if (bpf_prog_size(prog->len) > sysctl_optmem_max)
return -ENOMEM;
if (sk_unhashed(sk)) {
if (sk_unhashed(sk) && sk->sk_reuseport) {
err = reuseport_alloc(sk);
if (err)
return err;
......
......@@ -1531,6 +1531,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
newsk = NULL;
goto out;
}
RCU_INIT_POINTER(newsk->sk_reuseport_cb, NULL);
newsk->sk_err = 0;
newsk->sk_priority = 0;
......
......@@ -802,7 +802,7 @@ static int dccp_v4_rcv(struct sk_buff *skb)
}
lookup:
sk = __inet_lookup_skb(&dccp_hashinfo, skb,
sk = __inet_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
dh->dccph_sport, dh->dccph_dport);
if (!sk) {
dccp_pr_debug("failed to look up flow ID in table and "
......
......@@ -668,7 +668,7 @@ static int dccp_v6_rcv(struct sk_buff *skb)
DCCP_SKB_CB(skb)->dccpd_ack_seq = dccp_hdr_ack_seq(skb);
lookup:
sk = __inet6_lookup_skb(&dccp_hashinfo, skb,
sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
dh->dccph_sport, dh->dccph_dport,
inet6_iif(skb));
if (!sk) {
......@@ -993,7 +993,7 @@ static struct proto dccp_v6_prot = {
.sendmsg = dccp_sendmsg,
.recvmsg = dccp_recvmsg,
.backlog_rcv = dccp_v6_do_rcv,
.hash = inet_hash,
.hash = inet6_hash,
.unhash = inet_unhash,
.accept = inet_csk_accept,
.get_port = inet_csk_get_port,
......
......@@ -182,12 +182,14 @@ static int ieee802154_sock_ioctl(struct socket *sock, unsigned int cmd,
static HLIST_HEAD(raw_head);
static DEFINE_RWLOCK(raw_lock);
static void raw_hash(struct sock *sk)
static int raw_hash(struct sock *sk)
{
write_lock_bh(&raw_lock);
sk_add_node(sk, &raw_head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
write_unlock_bh(&raw_lock);
return 0;
}
static void raw_unhash(struct sock *sk)
......@@ -462,12 +464,14 @@ static inline struct dgram_sock *dgram_sk(const struct sock *sk)
return container_of(sk, struct dgram_sock, sk);
}
static void dgram_hash(struct sock *sk)
static int dgram_hash(struct sock *sk)
{
write_lock_bh(&dgram_lock);
sk_add_node(sk, &dgram_head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
write_unlock_bh(&dgram_lock);
return 0;
}
static void dgram_unhash(struct sock *sk)
......@@ -1026,8 +1030,13 @@ static int ieee802154_create(struct net *net, struct socket *sock,
/* Checksums on by default */
sock_set_flag(sk, SOCK_ZAPPED);
if (sk->sk_prot->hash)
sk->sk_prot->hash(sk);
if (sk->sk_prot->hash) {
rc = sk->sk_prot->hash(sk);
if (rc) {
sk_common_release(sk);
goto out;
}
}
if (sk->sk_prot->init) {
rc = sk->sk_prot->init(sk);
......
......@@ -370,7 +370,11 @@ static int inet_create(struct net *net, struct socket *sock, int protocol,
*/
inet->inet_sport = htons(inet->inet_num);
/* Add to protocol hash chains. */
sk->sk_prot->hash(sk);
err = sk->sk_prot->hash(sk);
if (err) {
sk_common_release(sk);
goto out;
}
}
if (sk->sk_prot->init) {
......@@ -1142,8 +1146,7 @@ static int inet_sk_reselect_saddr(struct sock *sk)
* Besides that, it does not check for connection
* uniqueness. Wait for troubles.
*/
__sk_prot_rehash(sk);
return 0;
return __sk_prot_rehash(sk);
}
int inet_sk_rebuild_header(struct sock *sk)
......
......@@ -24,6 +24,7 @@
#include <net/tcp_states.h>
#include <net/xfrm.h>
#include <net/tcp.h>
#include <net/sock_reuseport.h>
#ifdef INET_CSK_DEBUG
const char inet_csk_timer_bug_msg[] = "inet_csk BUG: unknown timer value\n";
......@@ -67,6 +68,7 @@ int inet_csk_bind_conflict(const struct sock *sk,
if ((!reuse || !sk2->sk_reuse ||
sk2->sk_state == TCP_LISTEN) &&
(!reuseport || !sk2->sk_reuseport ||
rcu_access_pointer(sk->sk_reuseport_cb) ||
(sk2->sk_state != TCP_TIME_WAIT &&
!uid_eq(uid, sock_i_uid(sk2))))) {
......@@ -132,6 +134,7 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
sk->sk_state != TCP_LISTEN) ||
(tb->fastreuseport > 0 &&
sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(tb->fastuid, uid))) &&
(tb->num_owners < smallest_size || smallest_size == -1)) {
smallest_size = tb->num_owners;
......@@ -193,15 +196,18 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
if (((tb->fastreuse > 0 &&
sk->sk_reuse && sk->sk_state != TCP_LISTEN) ||
(tb->fastreuseport > 0 &&
sk->sk_reuseport && uid_eq(tb->fastuid, uid))) &&
smallest_size == -1) {
sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(tb->fastuid, uid))) && smallest_size == -1) {
goto success;
} else {
ret = 1;
if (inet_csk(sk)->icsk_af_ops->bind_conflict(sk, tb, true)) {
if (((sk->sk_reuse && sk->sk_state != TCP_LISTEN) ||
(tb->fastreuseport > 0 &&
sk->sk_reuseport && uid_eq(tb->fastuid, uid))) &&
sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(tb->fastuid, uid))) &&
smallest_size != -1 && --attempts >= 0) {
spin_unlock(&head->lock);
goto again;
......@@ -734,6 +740,7 @@ int inet_csk_listen_start(struct sock *sk, int backlog)
{
struct inet_connection_sock *icsk = inet_csk(sk);
struct inet_sock *inet = inet_sk(sk);
int err = -EADDRINUSE;
reqsk_queue_alloc(&icsk->icsk_accept_queue);
......@@ -751,13 +758,14 @@ int inet_csk_listen_start(struct sock *sk, int backlog)
inet->inet_sport = htons(inet->inet_num);
sk_dst_reset(sk);
sk->sk_prot->hash(sk);
err = sk->sk_prot->hash(sk);
if (likely(!err))
return 0;
}
sk->sk_state = TCP_CLOSE;
return -EADDRINUSE;
return err;
}
EXPORT_SYMBOL_GPL(inet_csk_listen_start);
......
......@@ -357,18 +357,18 @@ struct sock *inet_diag_find_one_icsk(struct net *net,
struct sock *sk;
if (req->sdiag_family == AF_INET)
sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
req->id.idiag_dport, req->id.idiag_src[0],
req->id.idiag_sport, req->id.idiag_if);
#if IS_ENABLED(CONFIG_IPV6)
else if (req->sdiag_family == AF_INET6) {
if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
sk = inet_lookup(net, hashinfo, req->id.idiag_dst[3],
sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
req->id.idiag_dport, req->id.idiag_src[3],
req->id.idiag_sport, req->id.idiag_if);
else
sk = inet6_lookup(net, hashinfo,
sk = inet6_lookup(net, hashinfo, NULL, 0,
(struct in6_addr *)req->id.idiag_dst,
req->id.idiag_dport,
(struct in6_addr *)req->id.idiag_src,
......
......@@ -20,10 +20,12 @@
#include <linux/wait.h>
#include <linux/vmalloc.h>
#include <net/addrconf.h>
#include <net/inet_connection_sock.h>
#include <net/inet_hashtables.h>
#include <net/secure_seq.h>
#include <net/ip.h>
#include <net/sock_reuseport.h>
static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
const __u16 lport, const __be32 faddr,
......@@ -205,6 +207,7 @@ static inline int compute_score(struct sock *sk, struct net *net,
struct sock *__inet_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const __be32 saddr, __be16 sport,
const __be32 daddr, const unsigned short hnum,
const int dif)
......@@ -214,6 +217,7 @@ struct sock *__inet_lookup_listener(struct net *net,
unsigned int hash = inet_lhashfn(net, hnum);
struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
int score, hiscore, matches = 0, reuseport = 0;
bool select_ok = true;
u32 phash = 0;
rcu_read_lock();
......@@ -229,6 +233,15 @@ struct sock *__inet_lookup_listener(struct net *net,
if (reuseport) {
phash = inet_ehashfn(net, daddr, hnum,
saddr, sport);
if (select_ok) {
struct sock *sk2;
sk2 = reuseport_select_sock(sk, phash,
skb, doff);
if (sk2) {
result = sk2;
goto found;
}
}
matches = 1;
}
} else if (score == hiscore && reuseport) {
......@@ -246,11 +259,13 @@ struct sock *__inet_lookup_listener(struct net *net,
if (get_nulls_value(node) != hash + LISTENING_NULLS_BASE)
goto begin;
if (result) {
found:
if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
result = NULL;
else if (unlikely(compute_score(result, net, hnum, daddr,
dif) < hiscore)) {
sock_put(result);
select_ok = false;
goto begin;
}
}
......@@ -449,32 +464,74 @@ bool inet_ehash_nolisten(struct sock *sk, struct sock *osk)
}
EXPORT_SYMBOL_GPL(inet_ehash_nolisten);
void __inet_hash(struct sock *sk, struct sock *osk)
static int inet_reuseport_add_sock(struct sock *sk,
struct inet_listen_hashbucket *ilb,
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard))
{
struct sock *sk2;
struct hlist_nulls_node *node;
kuid_t uid = sock_i_uid(sk);
sk_nulls_for_each_rcu(sk2, node, &ilb->head) {
if (sk2 != sk &&
sk2->sk_family == sk->sk_family &&
ipv6_only_sock(sk2) == ipv6_only_sock(sk) &&
sk2->sk_bound_dev_if == sk->sk_bound_dev_if &&
sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
saddr_same(sk, sk2, false))
return reuseport_add_sock(sk, sk2);
}
/* Initial allocation may have already happened via setsockopt */
if (!rcu_access_pointer(sk->sk_reuseport_cb))
return reuseport_alloc(sk);
return 0;
}
int __inet_hash(struct sock *sk, struct sock *osk,
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard))
{
struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
struct inet_listen_hashbucket *ilb;
int err = 0;
if (sk->sk_state != TCP_LISTEN) {
inet_ehash_nolisten(sk, osk);
return;
return 0;
}
WARN_ON(!sk_unhashed(sk));
ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
spin_lock(&ilb->lock);
if (sk->sk_reuseport) {
err = inet_reuseport_add_sock(sk, ilb, saddr_same);
if (err)
goto unlock;
}
__sk_nulls_add_node_rcu(sk, &ilb->head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
unlock:
spin_unlock(&ilb->lock);
return err;
}
EXPORT_SYMBOL(__inet_hash);
void inet_hash(struct sock *sk)
int inet_hash(struct sock *sk)
{
int err = 0;
if (sk->sk_state != TCP_CLOSE) {
local_bh_disable();
__inet_hash(sk, NULL);
err = __inet_hash(sk, NULL, ipv4_rcv_saddr_equal);
local_bh_enable();
}
return err;
}
EXPORT_SYMBOL_GPL(inet_hash);
......@@ -493,6 +550,8 @@ void inet_unhash(struct sock *sk)
lock = inet_ehash_lockp(hashinfo, sk->sk_hash);
spin_lock_bh(lock);
if (rcu_access_pointer(sk->sk_reuseport_cb))
reuseport_detach_sock(sk);
done = __sk_nulls_del_node_init_rcu(sk);
if (done)
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
......
......@@ -145,10 +145,12 @@ int ping_get_port(struct sock *sk, unsigned short ident)
}
EXPORT_SYMBOL_GPL(ping_get_port);
void ping_hash(struct sock *sk)
int ping_hash(struct sock *sk)
{
pr_debug("ping_hash(sk->port=%u)\n", inet_sk(sk)->inet_num);
BUG(); /* "Please do not press this button again." */
return 0;
}
void ping_unhash(struct sock *sk)
......
......@@ -93,7 +93,7 @@ static struct raw_hashinfo raw_v4_hashinfo = {
.lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock),
};
void raw_hash_sk(struct sock *sk)
int raw_hash_sk(struct sock *sk)
{
struct raw_hashinfo *h = sk->sk_prot->h.raw_hash;
struct hlist_head *head;
......@@ -104,6 +104,8 @@ void raw_hash_sk(struct sock *sk)
sk_add_node(sk, head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
write_unlock_bh(&h->lock);
return 0;
}
EXPORT_SYMBOL_GPL(raw_hash_sk);
......
......@@ -637,8 +637,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
* Incoming packet is checked with md5 hash with finding key,
* no RST generated if md5 hash doesn't match.
*/
sk1 = __inet_lookup_listener(net,
&tcp_hashinfo, ip_hdr(skb)->saddr,
sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0,
ip_hdr(skb)->saddr,
th->source, ip_hdr(skb)->daddr,
ntohs(th->source), inet_iif(skb));
/* don't send rst if it can't find key */
......@@ -1581,7 +1581,8 @@ int tcp_v4_rcv(struct sk_buff *skb)
TCP_SKB_CB(skb)->sacked = 0;
lookup:
sk = __inet_lookup_skb(&tcp_hashinfo, skb, th->source, th->dest);
sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source,
th->dest);
if (!sk)
goto no_tcp_socket;
......@@ -1695,7 +1696,8 @@ int tcp_v4_rcv(struct sk_buff *skb)
switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) {
case TCP_TW_SYN: {
struct sock *sk2 = inet_lookup_listener(dev_net(skb->dev),
&tcp_hashinfo,
&tcp_hashinfo, skb,
__tcp_hdrlen(th),
iph->saddr, th->source,
iph->daddr, th->dest,
inet_iif(skb));
......
......@@ -356,7 +356,7 @@ EXPORT_SYMBOL(udp_lib_get_port);
* match_wildcard == false: addresses must be exactly the same, i.e.
* 0.0.0.0 only equals to 0.0.0.0
*/
static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2,
int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2,
bool match_wildcard)
{
struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
......
......@@ -235,7 +235,11 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol,
* creation time automatically shares.
*/
inet->inet_sport = htons(inet->inet_num);
sk->sk_prot->hash(sk);
err = sk->sk_prot->hash(sk);
if (err) {
sk_common_release(sk);
goto out;
}
}
if (sk->sk_prot->init) {
err = sk->sk_prot->init(sk);
......
......@@ -26,6 +26,7 @@
#include <net/ip6_route.h>
#include <net/sock.h>
#include <net/inet6_connection_sock.h>
#include <net/sock_reuseport.h>
int inet6_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb, bool relax)
......@@ -48,6 +49,7 @@ int inet6_csk_bind_conflict(const struct sock *sk,
if ((!reuse || !sk2->sk_reuse ||
sk2->sk_state == TCP_LISTEN) &&
(!reuseport || !sk2->sk_reuseport ||
rcu_access_pointer(sk->sk_reuseport_cb) ||
(sk2->sk_state != TCP_TIME_WAIT &&
!uid_eq(uid,
sock_i_uid((struct sock *)sk2))))) {
......
......@@ -17,11 +17,13 @@
#include <linux/module.h>
#include <linux/random.h>
#include <net/addrconf.h>
#include <net/inet_connection_sock.h>
#include <net/inet_hashtables.h>
#include <net/inet6_hashtables.h>
#include <net/secure_seq.h>
#include <net/ip.h>
#include <net/sock_reuseport.h>
u32 inet6_ehashfn(const struct net *net,
const struct in6_addr *laddr, const u16 lport,
......@@ -121,7 +123,9 @@ static inline int compute_score(struct sock *sk, struct net *net,
}
struct sock *inet6_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo, const struct in6_addr *saddr,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr,
const __be16 sport, const struct in6_addr *daddr,
const unsigned short hnum, const int dif)
{
......@@ -129,6 +133,7 @@ struct sock *inet6_lookup_listener(struct net *net,
const struct hlist_nulls_node *node;
struct sock *result;
int score, hiscore, matches = 0, reuseport = 0;
bool select_ok = true;
u32 phash = 0;
unsigned int hash = inet_lhashfn(net, hnum);
struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
......@@ -146,6 +151,15 @@ struct sock *inet6_lookup_listener(struct net *net,
if (reuseport) {
phash = inet6_ehashfn(net, daddr, hnum,
saddr, sport);
if (select_ok) {
struct sock *sk2;
sk2 = reuseport_select_sock(sk, phash,
skb, doff);
if (sk2) {
result = sk2;
goto found;
}
}
matches = 1;
}
} else if (score == hiscore && reuseport) {
......@@ -163,11 +177,13 @@ struct sock *inet6_lookup_listener(struct net *net,
if (get_nulls_value(node) != hash + LISTENING_NULLS_BASE)
goto begin;
if (result) {
found:
if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
result = NULL;
else if (unlikely(compute_score(result, net, hnum, daddr,
dif) < hiscore)) {
sock_put(result);
select_ok = false;
goto begin;
}
}
......@@ -177,6 +193,7 @@ struct sock *inet6_lookup_listener(struct net *net,
EXPORT_SYMBOL_GPL(inet6_lookup_listener);
struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr, const __be16 sport,
const struct in6_addr *daddr, const __be16 dport,
const int dif)
......@@ -184,7 +201,8 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
struct sock *sk;
local_bh_disable();
sk = __inet6_lookup(net, hashinfo, saddr, sport, daddr, ntohs(dport), dif);
sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
ntohs(dport), dif);
local_bh_enable();
return sk;
......@@ -274,3 +292,59 @@ int inet6_hash_connect(struct inet_timewait_death_row *death_row,
__inet6_check_established);
}
EXPORT_SYMBOL_GPL(inet6_hash_connect);
int inet6_hash(struct sock *sk)
{
if (sk->sk_state != TCP_CLOSE) {
local_bh_disable();
__inet_hash(sk, NULL, ipv6_rcv_saddr_equal);
local_bh_enable();
}
return 0;
}
EXPORT_SYMBOL_GPL(inet6_hash);
/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
* only, and any IPv4 addresses if not IPv6 only
* match_wildcard == false: addresses must be exactly the same, i.e.
* IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
* and 0.0.0.0 equals to 0.0.0.0 only
*/
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard)
{
const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
int sk2_ipv6only = inet_v6_ipv6only(sk2);
int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
/* if both are mapped, treat as IPv4 */
if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
if (!sk2_ipv6only) {
if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
return 1;
if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
return match_wildcard;
}
return 0;
}
if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
return 1;
if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
!(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
return 1;
if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
!(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
return 1;
if (sk2_rcv_saddr6 &&
ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6))
return 1;
return 0;
}
EXPORT_SYMBOL_GPL(ipv6_rcv_saddr_equal);
......@@ -866,7 +866,8 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
* no RST generated if md5 hash doesn't match.
*/
sk1 = inet6_lookup_listener(dev_net(skb_dst(skb)->dev),
&tcp_hashinfo, &ipv6h->saddr,
&tcp_hashinfo, NULL, 0,
&ipv6h->saddr,
th->source, &ipv6h->daddr,
ntohs(th->source), tcp_v6_iif(skb));
if (!sk1)
......@@ -1375,8 +1376,8 @@ static int tcp_v6_rcv(struct sk_buff *skb)
hdr = ipv6_hdr(skb);
lookup:
sk = __inet6_lookup_skb(&tcp_hashinfo, skb, th->source, th->dest,
inet6_iif(skb));
sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th),
th->source, th->dest, inet6_iif(skb));
if (!sk)
goto no_tcp_socket;
......@@ -1500,6 +1501,7 @@ static int tcp_v6_rcv(struct sk_buff *skb)
struct sock *sk2;
sk2 = inet6_lookup_listener(dev_net(skb->dev), &tcp_hashinfo,
skb, __tcp_hdrlen(th),
&ipv6_hdr(skb)->saddr, th->source,
&ipv6_hdr(skb)->daddr,
ntohs(th->dest), tcp_v6_iif(skb));
......@@ -1865,7 +1867,7 @@ struct proto tcpv6_prot = {
.sendpage = tcp_sendpage,
.backlog_rcv = tcp_v6_do_rcv,
.release_cb = tcp_release_cb,
.hash = inet_hash,
.hash = inet6_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
.enter_memory_pressure = tcp_enter_memory_pressure,
......
......@@ -37,6 +37,7 @@
#include <linux/slab.h>
#include <asm/uaccess.h>
#include <net/addrconf.h>
#include <net/ndisc.h>
#include <net/protocol.h>
#include <net/transp_v6.h>
......@@ -77,49 +78,6 @@ static u32 udp6_ehashfn(const struct net *net,
udp_ipv6_hash_secret + net_hash_mix(net));
}
/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
* only, and any IPv4 addresses if not IPv6 only
* match_wildcard == false: addresses must be exactly the same, i.e.
* IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
* and 0.0.0.0 equals to 0.0.0.0 only
*/
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard)
{
const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
int sk2_ipv6only = inet_v6_ipv6only(sk2);
int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
/* if both are mapped, treat as IPv4 */
if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
if (!sk2_ipv6only) {
if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
return 1;
if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
return match_wildcard;
}
return 0;
}
if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
return 1;
if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
!(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
return 1;
if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
!(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
return 1;
if (sk2_rcv_saddr6 &&
ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6))
return 1;
return 0;
}
static u32 udp6_portaddr_hash(const struct net *net,
const struct in6_addr *addr6,
unsigned int port)
......
......@@ -25,6 +25,7 @@
#include <net/udp.h>
#include <net/inet_common.h>
#include <net/inet_hashtables.h>
#include <net/inet6_hashtables.h>
#include <net/tcp_states.h>
#include <net/protocol.h>
#include <net/xfrm.h>
......@@ -718,7 +719,7 @@ static struct proto l2tp_ip6_prot = {
.sendmsg = l2tp_ip6_sendmsg,
.recvmsg = l2tp_ip6_recvmsg,
.backlog_rcv = l2tp_ip6_backlog_recv,
.hash = inet_hash,
.hash = inet6_hash,
.unhash = inet_unhash,
.obj_size = sizeof(struct l2tp_ip6_sock),
#ifdef CONFIG_COMPAT
......
......@@ -105,19 +105,24 @@ tproxy_laddr4(struct sk_buff *skb, __be32 user_laddr, __be32 daddr)
* belonging to established connections going through that one.
*/
static inline struct sock *
nf_tproxy_get_sock_v4(struct net *net, const u8 protocol,
nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, void *hp,
const u8 protocol,
const __be32 saddr, const __be32 daddr,
const __be16 sport, const __be16 dport,
const struct net_device *in,
const enum nf_tproxy_lookup_t lookup_type)
{
struct sock *sk;
struct tcphdr *tcph;
switch (protocol) {
case IPPROTO_TCP:
switch (lookup_type) {
case NFT_LOOKUP_LISTENER:
sk = inet_lookup_listener(net, &tcp_hashinfo,
tcph = hp;
sk = inet_lookup_listener(net, &tcp_hashinfo, skb,
ip_hdrlen(skb) +
__tcp_hdrlen(tcph),
saddr, sport,
daddr, dport,
in->ifindex);
......@@ -169,19 +174,23 @@ nf_tproxy_get_sock_v4(struct net *net, const u8 protocol,
#ifdef XT_TPROXY_HAVE_IPV6
static inline struct sock *
nf_tproxy_get_sock_v6(struct net *net, const u8 protocol,
nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp,
const u8 protocol,
const struct in6_addr *saddr, const struct in6_addr *daddr,
const __be16 sport, const __be16 dport,
const struct net_device *in,
const enum nf_tproxy_lookup_t lookup_type)
{
struct sock *sk;
struct tcphdr *tcph;
switch (protocol) {
case IPPROTO_TCP:
switch (lookup_type) {
case NFT_LOOKUP_LISTENER:
sk = inet6_lookup_listener(net, &tcp_hashinfo,
tcph = hp;
sk = inet6_lookup_listener(net, &tcp_hashinfo, skb,
thoff + __tcp_hdrlen(tcph),
saddr, sport,
daddr, ntohs(dport),
in->ifindex);
......@@ -267,7 +276,7 @@ tproxy_handle_time_wait4(struct net *net, struct sk_buff *skb,
* to a listener socket if there's one */
struct sock *sk2;
sk2 = nf_tproxy_get_sock_v4(net, iph->protocol,
sk2 = nf_tproxy_get_sock_v4(net, skb, hp, iph->protocol,
iph->saddr, laddr ? laddr : iph->daddr,
hp->source, lport ? lport : hp->dest,
skb->dev, NFT_LOOKUP_LISTENER);
......@@ -305,7 +314,7 @@ tproxy_tg4(struct net *net, struct sk_buff *skb, __be32 laddr, __be16 lport,
* addresses, this happens if the redirect already happened
* and the current packet belongs to an already established
* connection */
sk = nf_tproxy_get_sock_v4(net, iph->protocol,
sk = nf_tproxy_get_sock_v4(net, skb, hp, iph->protocol,
iph->saddr, iph->daddr,
hp->source, hp->dest,
skb->dev, NFT_LOOKUP_ESTABLISHED);
......@@ -321,7 +330,7 @@ tproxy_tg4(struct net *net, struct sk_buff *skb, __be32 laddr, __be16 lport,
else if (!sk)
/* no, there's no established connection, check if
* there's a listener on the redirected addr/port */
sk = nf_tproxy_get_sock_v4(net, iph->protocol,
sk = nf_tproxy_get_sock_v4(net, skb, hp, iph->protocol,
iph->saddr, laddr,
hp->source, lport,
skb->dev, NFT_LOOKUP_LISTENER);
......@@ -429,7 +438,7 @@ tproxy_handle_time_wait6(struct sk_buff *skb, int tproto, int thoff,
* to a listener socket if there's one */
struct sock *sk2;
sk2 = nf_tproxy_get_sock_v6(par->net, tproto,
sk2 = nf_tproxy_get_sock_v6(par->net, skb, thoff, hp, tproto,
&iph->saddr,
tproxy_laddr6(skb, &tgi->laddr.in6, &iph->daddr),
hp->source,
......@@ -472,7 +481,7 @@ tproxy_tg6_v1(struct sk_buff *skb, const struct xt_action_param *par)
* addresses, this happens if the redirect already happened
* and the current packet belongs to an already established
* connection */
sk = nf_tproxy_get_sock_v6(par->net, tproto,
sk = nf_tproxy_get_sock_v6(par->net, skb, thoff, hp, tproto,
&iph->saddr, &iph->daddr,
hp->source, hp->dest,
par->in, NFT_LOOKUP_ESTABLISHED);
......@@ -487,8 +496,8 @@ tproxy_tg6_v1(struct sk_buff *skb, const struct xt_action_param *par)
else if (!sk)
/* no there's no established connection, check if
* there's a listener on the redirected addr/port */
sk = nf_tproxy_get_sock_v6(par->net, tproto,
&iph->saddr, laddr,
sk = nf_tproxy_get_sock_v6(par->net, skb, thoff, hp,
tproto, &iph->saddr, laddr,
hp->source, lport,
par->in, NFT_LOOKUP_LISTENER);
......
......@@ -112,14 +112,15 @@ extract_icmp4_fields(const struct sk_buff *skb,
* box.
*/
static struct sock *
xt_socket_get_sock_v4(struct net *net, const u8 protocol,
xt_socket_get_sock_v4(struct net *net, struct sk_buff *skb, const int doff,
const u8 protocol,
const __be32 saddr, const __be32 daddr,
const __be16 sport, const __be16 dport,
const struct net_device *in)
{
switch (protocol) {
case IPPROTO_TCP:
return __inet_lookup(net, &tcp_hashinfo,
return __inet_lookup(net, &tcp_hashinfo, skb, doff,
saddr, sport, daddr, dport,
in->ifindex);
case IPPROTO_UDP:
......@@ -148,6 +149,8 @@ static struct sock *xt_socket_lookup_slow_v4(struct net *net,
const struct net_device *indev)
{
const struct iphdr *iph = ip_hdr(skb);
struct sk_buff *data_skb = NULL;
int doff = 0;
__be32 uninitialized_var(daddr), uninitialized_var(saddr);
__be16 uninitialized_var(dport), uninitialized_var(sport);
u8 uninitialized_var(protocol);
......@@ -169,6 +172,10 @@ static struct sock *xt_socket_lookup_slow_v4(struct net *net,
sport = hp->source;
daddr = iph->daddr;
dport = hp->dest;
data_skb = (struct sk_buff *)skb;
doff = iph->protocol == IPPROTO_TCP ?
ip_hdrlen(skb) + __tcp_hdrlen((struct tcphdr *)hp) :
ip_hdrlen(skb) + sizeof(*hp);
} else if (iph->protocol == IPPROTO_ICMP) {
if (extract_icmp4_fields(skb, &protocol, &saddr, &daddr,
......@@ -198,8 +205,8 @@ static struct sock *xt_socket_lookup_slow_v4(struct net *net,
}
#endif
return xt_socket_get_sock_v4(net, protocol, saddr, daddr,
sport, dport, indev);
return xt_socket_get_sock_v4(net, data_skb, doff, protocol, saddr,
daddr, sport, dport, indev);
}
static bool
......@@ -318,14 +325,15 @@ extract_icmp6_fields(const struct sk_buff *skb,
}
static struct sock *
xt_socket_get_sock_v6(struct net *net, const u8 protocol,
xt_socket_get_sock_v6(struct net *net, struct sk_buff *skb, int doff,
const u8 protocol,
const struct in6_addr *saddr, const struct in6_addr *daddr,
const __be16 sport, const __be16 dport,
const struct net_device *in)
{
switch (protocol) {
case IPPROTO_TCP:
return inet6_lookup(net, &tcp_hashinfo,
return inet6_lookup(net, &tcp_hashinfo, skb, doff,
saddr, sport, daddr, dport,
in->ifindex);
case IPPROTO_UDP:
......@@ -343,6 +351,8 @@ static struct sock *xt_socket_lookup_slow_v6(struct net *net,
__be16 uninitialized_var(dport), uninitialized_var(sport);
const struct in6_addr *daddr = NULL, *saddr = NULL;
struct ipv6hdr *iph = ipv6_hdr(skb);
struct sk_buff *data_skb = NULL;
int doff = 0;
int thoff = 0, tproto;
tproto = ipv6_find_hdr(skb, &thoff, -1, NULL, NULL);
......@@ -362,6 +372,10 @@ static struct sock *xt_socket_lookup_slow_v6(struct net *net,
sport = hp->source;
daddr = &iph->daddr;
dport = hp->dest;
data_skb = (struct sk_buff *)skb;
doff = tproto == IPPROTO_TCP ?
thoff + __tcp_hdrlen((struct tcphdr *)hp) :
thoff + sizeof(*hp);
} else if (tproto == IPPROTO_ICMPV6) {
struct ipv6hdr ipv6_var;
......@@ -373,7 +387,7 @@ static struct sock *xt_socket_lookup_slow_v6(struct net *net,
return NULL;
}
return xt_socket_get_sock_v6(net, tproto, saddr, daddr,
return xt_socket_get_sock_v6(net, data_skb, doff, tproto, saddr, daddr,
sport, dport, indev);
}
......
......@@ -140,13 +140,15 @@ void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb)
rcu_read_unlock();
}
void pn_sock_hash(struct sock *sk)
int pn_sock_hash(struct sock *sk)
{
struct hlist_head *hlist = pn_hash_list(pn_sk(sk)->sobject);
mutex_lock(&pnsocks.lock);
sk_add_node_rcu(sk, hlist);
mutex_unlock(&pnsocks.lock);
return 0;
}
EXPORT_SYMBOL(pn_sock_hash);
......@@ -200,7 +202,7 @@ static int pn_socket_bind(struct socket *sock, struct sockaddr *addr, int len)
pn->resource = spn->spn_resource;
/* Enable RX on the socket */
sk->sk_prot->hash(sk);
err = sk->sk_prot->hash(sk);
out_port:
mutex_unlock(&port_mutex);
out:
......
......@@ -6101,9 +6101,10 @@ static int sctp_getsockopt(struct sock *sk, int level, int optname,
return retval;
}
static void sctp_hash(struct sock *sk)
static int sctp_hash(struct sock *sk)
{
/* STUB */
return 0;
}
static void sctp_unhash(struct sock *sk)
......
......@@ -2,3 +2,4 @@ socket
psock_fanout
psock_tpacket
reuseport_bpf
reuseport_bpf_cpu
......@@ -4,7 +4,7 @@ CFLAGS = -Wall -O2 -g
CFLAGS += -I../../../../usr/include/
NET_PROGS = socket psock_fanout psock_tpacket reuseport_bpf
NET_PROGS = socket psock_fanout psock_tpacket reuseport_bpf reuseport_bpf_cpu
all: $(NET_PROGS)
%: %.c
......
......@@ -9,10 +9,12 @@
#include <errno.h>
#include <error.h>
#include <fcntl.h>
#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/unistd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
......@@ -169,10 +171,16 @@ static void build_recv_group(const struct test_params p, int fd[], uint16_t mod,
if (bind(fd[i], addr, sockaddr_size()))
error(1, errno, "failed to bind recv socket %d", i);
if (p.protocol == SOCK_STREAM)
if (p.protocol == SOCK_STREAM) {
opt = 4;
if (setsockopt(fd[i], SOL_TCP, TCP_FASTOPEN, &opt,
sizeof(opt)))
error(1, errno,
"failed to set TCP_FASTOPEN on %d", i);
if (listen(fd[i], p.recv_socks * 10))
error(1, errno, "failed to listen on socket");
}
}
free(addr);
}
......@@ -189,10 +197,8 @@ static void send_from(struct test_params p, uint16_t sport, char *buf,
if (bind(fd, saddr, sockaddr_size()))
error(1, errno, "failed to bind send socket");
if (connect(fd, daddr, sockaddr_size()))
error(1, errno, "failed to connect");
if (send(fd, buf, len, 0) < 0)
if (sendto(fd, buf, len, MSG_FASTOPEN, daddr, sockaddr_size()) < 0)
error(1, errno, "failed to send message");
close(fd);
......@@ -260,7 +266,7 @@ static void test_recv_order(const struct test_params p, int fd[], int mod)
}
}
static void test_reuseport_ebpf(const struct test_params p)
static void test_reuseport_ebpf(struct test_params p)
{
int i, fd[p.recv_socks];
......@@ -268,6 +274,7 @@ static void test_reuseport_ebpf(const struct test_params p)
build_recv_group(p, fd, p.recv_socks, attach_ebpf);
test_recv_order(p, fd, p.recv_socks);
p.send_port_min += p.recv_socks * 2;
fprintf(stderr, "Reprograming, testing mod %zd...\n", p.recv_socks / 2);
attach_ebpf(fd[0], p.recv_socks / 2);
test_recv_order(p, fd, p.recv_socks / 2);
......@@ -276,7 +283,7 @@ static void test_reuseport_ebpf(const struct test_params p)
close(fd[i]);
}
static void test_reuseport_cbpf(const struct test_params p)
static void test_reuseport_cbpf(struct test_params p)
{
int i, fd[p.recv_socks];
......@@ -284,6 +291,7 @@ static void test_reuseport_cbpf(const struct test_params p)
build_recv_group(p, fd, p.recv_socks, attach_cbpf);
test_recv_order(p, fd, p.recv_socks);
p.send_port_min += p.recv_socks * 2;
fprintf(stderr, "Reprograming, testing mod %zd...\n", p.recv_socks / 2);
attach_cbpf(fd[0], p.recv_socks / 2);
test_recv_order(p, fd, p.recv_socks / 2);
......@@ -377,7 +385,7 @@ static void test_filter_no_reuseport(const struct test_params p)
static void test_filter_without_bind(void)
{
int fd1, fd2;
int fd1, fd2, opt = 1;
fprintf(stderr, "Testing filter add without bind...\n");
fd1 = socket(AF_INET, SOCK_DGRAM, 0);
......@@ -386,6 +394,10 @@ static void test_filter_without_bind(void)
fd2 = socket(AF_INET, SOCK_DGRAM, 0);
if (fd2 < 0)
error(1, errno, "failed to create socket 2");
if (setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)))
error(1, errno, "failed to set SO_REUSEPORT on socket 1");
if (setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)))
error(1, errno, "failed to set SO_REUSEPORT on socket 2");
attach_ebpf(fd1, 10);
attach_cbpf(fd2, 10);
......@@ -394,6 +406,32 @@ static void test_filter_without_bind(void)
close(fd2);
}
void enable_fastopen(void)
{
int fd = open("/proc/sys/net/ipv4/tcp_fastopen", 0);
int rw_mask = 3; /* bit 1: client side; bit-2 server side */
int val, size;
char buf[16];
if (fd < 0)
error(1, errno, "Unable to open tcp_fastopen sysctl");
if (read(fd, buf, sizeof(buf)) <= 0)
error(1, errno, "Unable to read tcp_fastopen sysctl");
val = atoi(buf);
close(fd);
if ((val & rw_mask) != rw_mask) {
fd = open("/proc/sys/net/ipv4/tcp_fastopen", O_RDWR);
if (fd < 0)
error(1, errno,
"Unable to open tcp_fastopen sysctl for writing");
val |= rw_mask;
size = snprintf(buf, 16, "%d", val);
if (write(fd, buf, size) <= 0)
error(1, errno, "Unable to write tcp_fastopen sysctl");
close(fd);
}
}
int main(void)
{
......@@ -506,6 +544,71 @@ int main(void)
.recv_port = 8007,
.send_port_min = 9100});
/* TCP fastopen is required for the TCP tests */
enable_fastopen();
fprintf(stderr, "---- IPv4 TCP ----\n");
test_reuseport_ebpf((struct test_params) {
.recv_family = AF_INET,
.send_family = AF_INET,
.protocol = SOCK_STREAM,
.recv_socks = 10,
.recv_port = 8008,
.send_port_min = 9120});
test_reuseport_cbpf((struct test_params) {
.recv_family = AF_INET,
.send_family = AF_INET,
.protocol = SOCK_STREAM,
.recv_socks = 10,
.recv_port = 8009,
.send_port_min = 9160});
test_extra_filter((struct test_params) {
.recv_family = AF_INET,
.protocol = SOCK_STREAM,
.recv_port = 8010});
test_filter_no_reuseport((struct test_params) {
.recv_family = AF_INET,
.protocol = SOCK_STREAM,
.recv_port = 8011});
fprintf(stderr, "---- IPv6 TCP ----\n");
test_reuseport_ebpf((struct test_params) {
.recv_family = AF_INET6,
.send_family = AF_INET6,
.protocol = SOCK_STREAM,
.recv_socks = 10,
.recv_port = 8012,
.send_port_min = 9200});
test_reuseport_cbpf((struct test_params) {
.recv_family = AF_INET6,
.send_family = AF_INET6,
.protocol = SOCK_STREAM,
.recv_socks = 10,
.recv_port = 8013,
.send_port_min = 9240});
test_extra_filter((struct test_params) {
.recv_family = AF_INET6,
.protocol = SOCK_STREAM,
.recv_port = 8014});
test_filter_no_reuseport((struct test_params) {
.recv_family = AF_INET6,
.protocol = SOCK_STREAM,
.recv_port = 8015});
fprintf(stderr, "---- IPv6 TCP w/ mapped IPv4 ----\n");
test_reuseport_ebpf((struct test_params) {
.recv_family = AF_INET6,
.send_family = AF_INET,
.protocol = SOCK_STREAM,
.recv_socks = 10,
.recv_port = 8016,
.send_port_min = 9320});
test_reuseport_cbpf((struct test_params) {
.recv_family = AF_INET6,
.send_family = AF_INET,
.protocol = SOCK_STREAM,
.recv_socks = 10,
.recv_port = 8017,
.send_port_min = 9360});
test_filter_without_bind();
......
/*
* Test functionality of BPF filters with SO_REUSEPORT. This program creates
* an SO_REUSEPORT receiver group containing one socket per CPU core. It then
* creates a BPF program that will select a socket from this group based
* on the core id that receives the packet. The sending code artificially
* moves itself to run on different core ids and sends one message from
* each core. Since these packets are delivered over loopback, they should
* arrive on the same core that sent them. The receiving code then ensures
* that the packet was received on the socket for the corresponding core id.
* This entire process is done for several different core id permutations
* and for each IPv4/IPv6 and TCP/UDP combination.
*/
#define _GNU_SOURCE
#include <arpa/inet.h>
#include <errno.h>
#include <error.h>
#include <linux/filter.h>
#include <linux/in.h>
#include <linux/unistd.h>
#include <sched.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
static const int PORT = 8888;
static void build_rcv_group(int *rcv_fd, size_t len, int family, int proto)
{
struct sockaddr_storage addr;
struct sockaddr_in *addr4;
struct sockaddr_in6 *addr6;
size_t i;
int opt;
switch (family) {
case AF_INET:
addr4 = (struct sockaddr_in *)&addr;
addr4->sin_family = AF_INET;
addr4->sin_addr.s_addr = htonl(INADDR_ANY);
addr4->sin_port = htons(PORT);
break;
case AF_INET6:
addr6 = (struct sockaddr_in6 *)&addr;
addr6->sin6_family = AF_INET6;
addr6->sin6_addr = in6addr_any;
addr6->sin6_port = htons(PORT);
break;
default:
error(1, 0, "Unsupported family %d", family);
}
for (i = 0; i < len; ++i) {
rcv_fd[i] = socket(family, proto, 0);
if (rcv_fd[i] < 0)
error(1, errno, "failed to create receive socket");
opt = 1;
if (setsockopt(rcv_fd[i], SOL_SOCKET, SO_REUSEPORT, &opt,
sizeof(opt)))
error(1, errno, "failed to set SO_REUSEPORT");
if (bind(rcv_fd[i], (struct sockaddr *)&addr, sizeof(addr)))
error(1, errno, "failed to bind receive socket");
if (proto == SOCK_STREAM && listen(rcv_fd[i], len * 10))
error(1, errno, "failed to listen on receive port");
}
}
static void attach_bpf(int fd)
{
struct sock_filter code[] = {
/* A = raw_smp_processor_id() */
{ BPF_LD | BPF_W | BPF_ABS, 0, 0, SKF_AD_OFF + SKF_AD_CPU },
/* return A */
{ BPF_RET | BPF_A, 0, 0, 0 },
};
struct sock_fprog p = {
.len = 2,
.filter = code,
};
if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF, &p, sizeof(p)))
error(1, errno, "failed to set SO_ATTACH_REUSEPORT_CBPF");
}
static void send_from_cpu(int cpu_id, int family, int proto)
{
struct sockaddr_storage saddr, daddr;
struct sockaddr_in *saddr4, *daddr4;
struct sockaddr_in6 *saddr6, *daddr6;
cpu_set_t cpu_set;
int fd;
switch (family) {
case AF_INET:
saddr4 = (struct sockaddr_in *)&saddr;
saddr4->sin_family = AF_INET;
saddr4->sin_addr.s_addr = htonl(INADDR_ANY);
saddr4->sin_port = 0;
daddr4 = (struct sockaddr_in *)&daddr;
daddr4->sin_family = AF_INET;
daddr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
daddr4->sin_port = htons(PORT);
break;
case AF_INET6:
saddr6 = (struct sockaddr_in6 *)&saddr;
saddr6->sin6_family = AF_INET6;
saddr6->sin6_addr = in6addr_any;
saddr6->sin6_port = 0;
daddr6 = (struct sockaddr_in6 *)&daddr;
daddr6->sin6_family = AF_INET6;
daddr6->sin6_addr = in6addr_loopback;
daddr6->sin6_port = htons(PORT);
break;
default:
error(1, 0, "Unsupported family %d", family);
}
memset(&cpu_set, 0, sizeof(cpu_set));
CPU_SET(cpu_id, &cpu_set);
if (sched_setaffinity(0, sizeof(cpu_set), &cpu_set) < 0)
error(1, errno, "failed to pin to cpu");
fd = socket(family, proto, 0);
if (fd < 0)
error(1, errno, "failed to create send socket");
if (bind(fd, (struct sockaddr *)&saddr, sizeof(saddr)))
error(1, errno, "failed to bind send socket");
if (connect(fd, (struct sockaddr *)&daddr, sizeof(daddr)))
error(1, errno, "failed to connect send socket");
if (send(fd, "a", 1, 0) < 0)
error(1, errno, "failed to send message");
close(fd);
}
static
void receive_on_cpu(int *rcv_fd, int len, int epfd, int cpu_id, int proto)
{
struct epoll_event ev;
int i, fd;
char buf[8];
i = epoll_wait(epfd, &ev, 1, -1);
if (i < 0)
error(1, errno, "epoll_wait failed");
if (proto == SOCK_STREAM) {
fd = accept(ev.data.fd, NULL, NULL);
if (fd < 0)
error(1, errno, "failed to accept");
i = recv(fd, buf, sizeof(buf), 0);
close(fd);
} else {
i = recv(ev.data.fd, buf, sizeof(buf), 0);
}
if (i < 0)
error(1, errno, "failed to recv");
for (i = 0; i < len; ++i)
if (ev.data.fd == rcv_fd[i])
break;
if (i == len)
error(1, 0, "failed to find socket");
fprintf(stderr, "send cpu %d, receive socket %d\n", cpu_id, i);
if (cpu_id != i)
error(1, 0, "cpu id/receive socket mismatch");
}
static void test(int *rcv_fd, int len, int family, int proto)
{
struct epoll_event ev;
int epfd, cpu;
build_rcv_group(rcv_fd, len, family, proto);
attach_bpf(rcv_fd[0]);
epfd = epoll_create(1);
if (epfd < 0)
error(1, errno, "failed to create epoll");
for (cpu = 0; cpu < len; ++cpu) {
ev.events = EPOLLIN;
ev.data.fd = rcv_fd[cpu];
if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fd[cpu], &ev))
error(1, errno, "failed to register sock epoll");
}
/* Forward iterate */
for (cpu = 0; cpu < len; ++cpu) {
send_from_cpu(cpu, family, proto);
receive_on_cpu(rcv_fd, len, epfd, cpu, proto);
}
/* Reverse iterate */
for (cpu = len - 1; cpu >= 0; --cpu) {
send_from_cpu(cpu, family, proto);
receive_on_cpu(rcv_fd, len, epfd, cpu, proto);
}
/* Even cores */
for (cpu = 0; cpu < len; cpu += 2) {
send_from_cpu(cpu, family, proto);
receive_on_cpu(rcv_fd, len, epfd, cpu, proto);
}
/* Odd cores */
for (cpu = 1; cpu < len; cpu += 2) {
send_from_cpu(cpu, family, proto);
receive_on_cpu(rcv_fd, len, epfd, cpu, proto);
}
close(epfd);
for (cpu = 0; cpu < len; ++cpu)
close(rcv_fd[cpu]);
}
int main(void)
{
int *rcv_fd, cpus;
cpus = sysconf(_SC_NPROCESSORS_ONLN);
if (cpus <= 0)
error(1, errno, "failed counting cpus");
rcv_fd = calloc(cpus, sizeof(int));
if (!rcv_fd)
error(1, 0, "failed to allocate array");
fprintf(stderr, "---- IPv4 UDP ----\n");
test(rcv_fd, cpus, AF_INET, SOCK_DGRAM);
fprintf(stderr, "---- IPv6 UDP ----\n");
test(rcv_fd, cpus, AF_INET6, SOCK_DGRAM);
fprintf(stderr, "---- IPv4 TCP ----\n");
test(rcv_fd, cpus, AF_INET, SOCK_STREAM);
fprintf(stderr, "---- IPv6 TCP ----\n");
test(rcv_fd, cpus, AF_INET6, SOCK_STREAM);
free(rcv_fd);
fprintf(stderr, "SUCCESS\n");
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部