diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 153b6dec9b6ae1da1bd98bf1f2bdbc21629c58df..48f4b645193b7d8ec3882bbc73ddb07e212a069c 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -278,7 +278,8 @@ static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start) static inline struct sk_psock *sk_psock(const struct sock *sk) { - return rcu_dereference_sk_user_data(sk); + return __rcu_dereference_sk_user_data_with_flags(sk, + SK_USER_DATA_PSOCK); } static inline void sk_psock_set_state(struct sk_psock *psock, diff --git a/include/net/sock.h b/include/net/sock.h index a7273b28918846233d93665f0fcd4b0a18b90d23..05a1bbdf58054d149f1ff8be8a4ffae827bb9be9 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -545,14 +545,26 @@ enum sk_pacing { SK_PACING_FQ = 2, }; -/* Pointer stored in sk_user_data might not be suitable for copying - * when cloning the socket. For instance, it can point to a reference - * counted object. sk_user_data bottom bit is set if pointer must not - * be copied. +/* flag bits in sk_user_data + * + * - SK_USER_DATA_NOCOPY: Pointer stored in sk_user_data might + * not be suitable for copying when cloning the socket. For instance, + * it can point to a reference counted object. sk_user_data bottom + * bit is set if pointer must not be copied. + * + * - SK_USER_DATA_BPF: Mark whether sk_user_data field is + * managed/owned by a BPF reuseport array. This bit should be set + * when sk_user_data's sk is added to the bpf's reuseport_array. + * + * - SK_USER_DATA_PSOCK: Mark whether pointer stored in + * sk_user_data points to psock type. This bit should be set + * when sk_user_data is assigned to a psock object. */ #define SK_USER_DATA_NOCOPY 1UL -#define SK_USER_DATA_BPF 2UL /* Managed by BPF */ -#define SK_USER_DATA_PTRMASK ~(SK_USER_DATA_NOCOPY | SK_USER_DATA_BPF) +#define SK_USER_DATA_BPF 2UL +#define SK_USER_DATA_PSOCK 4UL +#define SK_USER_DATA_PTRMASK ~(SK_USER_DATA_NOCOPY | SK_USER_DATA_BPF |\ + SK_USER_DATA_PSOCK) /** * sk_user_data_is_nocopy - Test if sk_user_data pointer must not be copied @@ -565,24 +577,40 @@ static inline bool sk_user_data_is_nocopy(const struct sock *sk) #define __sk_user_data(sk) ((*((void __rcu **)&(sk)->sk_user_data))) +/** + * __rcu_dereference_sk_user_data_with_flags - return the pointer + * only if argument flags all has been set in sk_user_data. Otherwise + * return NULL + * + * @sk: socket + * @flags: flag bits + */ +static inline void * +__rcu_dereference_sk_user_data_with_flags(const struct sock *sk, + uintptr_t flags) +{ + uintptr_t sk_user_data = (uintptr_t)rcu_dereference(__sk_user_data(sk)); + + WARN_ON_ONCE(flags & SK_USER_DATA_PTRMASK); + + if ((sk_user_data & flags) == flags) + return (void *)(sk_user_data & SK_USER_DATA_PTRMASK); + return NULL; +} + #define rcu_dereference_sk_user_data(sk) \ + __rcu_dereference_sk_user_data_with_flags(sk, 0) +#define __rcu_assign_sk_user_data_with_flags(sk, ptr, flags) \ ({ \ - void *__tmp = rcu_dereference(__sk_user_data((sk))); \ - (void *)((uintptr_t)__tmp & SK_USER_DATA_PTRMASK); \ -}) -#define rcu_assign_sk_user_data(sk, ptr) \ -({ \ - uintptr_t __tmp = (uintptr_t)(ptr); \ - WARN_ON_ONCE(__tmp & ~SK_USER_DATA_PTRMASK); \ - rcu_assign_pointer(__sk_user_data((sk)), __tmp); \ -}) -#define rcu_assign_sk_user_data_nocopy(sk, ptr) \ -({ \ - uintptr_t __tmp = (uintptr_t)(ptr); \ - WARN_ON_ONCE(__tmp & ~SK_USER_DATA_PTRMASK); \ + uintptr_t __tmp1 = (uintptr_t)(ptr), \ + __tmp2 = (uintptr_t)(flags); \ + WARN_ON_ONCE(__tmp1 & ~SK_USER_DATA_PTRMASK); \ + WARN_ON_ONCE(__tmp2 & SK_USER_DATA_PTRMASK); \ rcu_assign_pointer(__sk_user_data((sk)), \ - __tmp | SK_USER_DATA_NOCOPY); \ + __tmp1 | __tmp2); \ }) +#define rcu_assign_sk_user_data(sk, ptr) \ + __rcu_assign_sk_user_data_with_flags(sk, ptr, 0) static inline struct net *sock_net(const struct sock *sk) diff --git a/kernel/bpf/reuseport_array.c b/kernel/bpf/reuseport_array.c index e2618fb5870e758ca7c9787882ce9b11c823f730..85fa9dbfa8bf88be026494b1a663912b968b1c26 100644 --- a/kernel/bpf/reuseport_array.c +++ b/kernel/bpf/reuseport_array.c @@ -21,14 +21,11 @@ static struct reuseport_array *reuseport_array(struct bpf_map *map) /* The caller must hold the reuseport_lock */ void bpf_sk_reuseport_detach(struct sock *sk) { - uintptr_t sk_user_data; + struct sock __rcu **socks; write_lock_bh(&sk->sk_callback_lock); - sk_user_data = (uintptr_t)sk->sk_user_data; - if (sk_user_data & SK_USER_DATA_BPF) { - struct sock __rcu **socks; - - socks = (void *)(sk_user_data & SK_USER_DATA_PTRMASK); + socks = __rcu_dereference_sk_user_data_with_flags(sk, SK_USER_DATA_BPF); + if (socks) { WRITE_ONCE(sk->sk_user_data, NULL); /* * Do not move this NULL assignment outside of diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 81627892bdd44693da2e942007447c685aa25bc8..57e942a6431af98bb635378d97895717e4e9dbac 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -739,7 +739,9 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node) sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); refcount_set(&psock->refcnt, 1); - rcu_assign_sk_user_data_nocopy(sk, psock); + __rcu_assign_sk_user_data_with_flags(sk, psock, + SK_USER_DATA_NOCOPY | + SK_USER_DATA_PSOCK); sock_hold(sk); out: