diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 8d95f4c66275317099a8eeb06ea95a5ae04cf2fb..79fb7c1be8fdf33d060eeaf6f22585c2e10fe4fa 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -1845,6 +1845,7 @@ static inline bool bpf_map_is_dev_bound(struct bpf_map *map) struct bpf_map *bpf_map_offload_map_alloc(union bpf_attr *attr); void bpf_map_offload_map_free(struct bpf_map *map); +void sock_map_destroy(struct sock *sk); #else static inline int bpf_prog_offload_init(struct bpf_prog *prog, union bpf_attr *attr) diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 1899a1c8421c076a5bd2510b60c91fc62e04a944..a83885c5bb86cbb2b86d80cf512d2c15ec1c1922 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -105,6 +105,7 @@ struct sk_psock { spinlock_t link_lock; refcount_t refcnt; void (*saved_unhash)(struct sock *sk); + void (*saved_destroy)(struct sock *sk); void (*saved_close)(struct sock *sk, long timeout); void (*saved_write_space)(struct sock *sk); struct proto *sk_proto; diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 622d93d56953d2e4f1f653a93b0fb5347825f271..925863fab5bd06c73b665a120c8d3070020750e3 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -616,6 +616,7 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node) psock->eval = __SK_NONE; psock->sk_proto = prot; psock->saved_unhash = prot->unhash; + psock->saved_destroy = prot->destroy; psock->saved_close = prot->close; psock->saved_write_space = sk->sk_write_space; diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 28e518da4bcbdc00e668e42150ee79c0000fd468..85df06298c98b0aa7a8e76e98a17659ae53661b4 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -1560,6 +1560,29 @@ void sock_map_unhash(struct sock *sk) saved_unhash(sk); } +void sock_map_destroy(struct sock *sk) +{ + void (*saved_destroy)(struct sock *sk); + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock_get(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + if (sk->sk_prot->destroy) + sk->sk_prot->destroy(sk); + return; + } + + saved_destroy = psock->saved_destroy; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock, true); + sk_psock_put(sk, psock); + saved_destroy(sk); +} +EXPORT_SYMBOL_GPL(sock_map_destroy); + void sock_map_close(struct sock *sk, long timeout) { void (*saved_close)(struct sock *sk, long timeout); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index afeaf35194de6a4deb62e70add700af2a03bdc35..1a8f129fb30955314c879c8781ec2dccfe929471 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -577,6 +577,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], struct proto *base) { prot[TCP_BPF_BASE] = *base; + prot[TCP_BPF_BASE].destroy = sock_map_destroy; prot[TCP_BPF_BASE].close = sock_map_close; prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;