diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index ef7031f8a304f524c3651d4c83e8053a2f0f9e8e..14d61bba0b79bd0e8cc48c24fc84970546537a4f 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -358,17 +358,22 @@ static inline void sk_psock_update_proto(struct sock *sk, static inline void sk_psock_restore_proto(struct sock *sk, struct sk_psock *psock) { - sk->sk_write_space = psock->saved_write_space; + sk->sk_prot->unhash = psock->saved_unhash; if (psock->sk_proto) { struct inet_connection_sock *icsk = inet_csk(sk); bool has_ulp = !!icsk->icsk_ulp_data; - if (has_ulp) - tcp_update_ulp(sk, psock->sk_proto); - else + if (has_ulp) { + tcp_update_ulp(sk, psock->sk_proto, + psock->saved_write_space); + } else { sk->sk_prot = psock->sk_proto; + sk->sk_write_space = psock->saved_write_space; + } psock->sk_proto = NULL; + } else { + sk->sk_write_space = psock->saved_write_space; } } diff --git a/include/net/tcp.h b/include/net/tcp.h index e460ea7f767ba627972a63a974cae80357808366..e6f48384dc7119cae21ef46c74532e0f2577d8a2 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -2147,7 +2147,8 @@ struct tcp_ulp_ops { /* initialize ulp */ int (*init)(struct sock *sk); /* update ulp */ - void (*update)(struct sock *sk, struct proto *p); + void (*update)(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)); /* cleanup ulp */ void (*release)(struct sock *sk); /* diagnostic */ @@ -2162,7 +2163,8 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type); int tcp_set_ulp(struct sock *sk, const char *name); void tcp_get_available_ulp(char *buf, size_t len); void tcp_cleanup_ulp(struct sock *sk); -void tcp_update_ulp(struct sock *sk, struct proto *p); +void tcp_update_ulp(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)); #define MODULE_ALIAS_TCP_ULP(name) \ __MODULE_INFO(alias, alias_userspace, name); \ diff --git a/net/core/filter.c b/net/core/filter.c index d22d108fc6e3901807951c58b278fba4cc125cca..538f6a735a19f017df8e10149cb578107ddc8cbb 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2231,10 +2231,10 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, /* First find the starting scatterlist element */ i = msg->sg.start; do { + offset += len; len = sk_msg_elem(msg, i)->length; if (start < offset + len) break; - offset += len; sk_msg_iter_var_next(i); } while (i != msg->sg.end); @@ -2346,7 +2346,7 @@ BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, u32, len, u64, flags) { struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge; - u32 new, i = 0, l, space, copy = 0, offset = 0; + u32 new, i = 0, l = 0, space, copy = 0, offset = 0; u8 *raw, *to, *from; struct page *page; @@ -2356,11 +2356,11 @@ BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, /* First find the starting scatterlist element */ i = msg->sg.start; do { + offset += l; l = sk_msg_elem(msg, i)->length; if (start < offset + l) break; - offset += l; sk_msg_iter_var_next(i); } while (i != msg->sg.end); @@ -2415,6 +2415,7 @@ BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, sk_msg_iter_var_next(i); sg_unmark_end(psge); + sg_unmark_end(&rsge); sk_msg_iter_next(msg, end); } @@ -2506,7 +2507,7 @@ static void sk_msg_shift_right(struct sk_msg *msg, int i) BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start, u32, len, u64, flags) { - u32 i = 0, l, space, offset = 0; + u32 i = 0, l = 0, space, offset = 0; u64 last = start + len; int pop; @@ -2516,11 +2517,11 @@ BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start, /* First find the starting scatterlist element */ i = msg->sg.start; do { + offset += l; l = sk_msg_elem(msg, i)->length; if (start < offset + l) break; - offset += l; sk_msg_iter_var_next(i); } while (i != msg->sg.end); diff --git a/net/core/skmsg.c b/net/core/skmsg.c index ded2d52276786d2845d0b82d6d0f95ccf76992cc..3866d7e20c07af5b6439af315c1d4fc6ecbf0132 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -594,6 +594,8 @@ EXPORT_SYMBOL_GPL(sk_psock_destroy); void sk_psock_drop(struct sock *sk, struct sk_psock *psock) { + sock_owned_by_me(sk); + sk_psock_cork_free(psock); sk_psock_zap_ingress(psock); diff --git a/net/core/sock_map.c b/net/core/sock_map.c index eb114ee419b65a9c9dc282bcfe3f03ea23f49727..8998e356f423247d96cefcab054be37b1051b5b4 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -241,8 +241,11 @@ static void sock_map_free(struct bpf_map *map) struct sock *sk; sk = xchg(psk, NULL); - if (sk) + if (sk) { + lock_sock(sk); sock_map_unref(sk, psk); + release_sock(sk); + } } raw_spin_unlock_bh(&stab->lock); rcu_read_unlock(); @@ -862,7 +865,9 @@ static void sock_hash_free(struct bpf_map *map) raw_spin_lock_bh(&bucket->lock); hlist_for_each_entry_safe(elem, node, &bucket->head, node) { hlist_del_rcu(&elem->node); + lock_sock(elem->sk); sock_map_unref(elem->sk, elem); + release_sock(elem->sk); } raw_spin_unlock_bh(&bucket->lock); } diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index e6b08b5a089559ae400336c3e4fee0dc1752a91e..8a01428f80c1c4b23e2ae0568efb2aa0e0f71ccb 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -315,10 +315,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, */ delta = msg->sg.size; psock->eval = sk_psock_msg_verdict(sk, psock, msg); - if (msg->sg.size < delta) - delta -= msg->sg.size; - else - delta = 0; + delta -= msg->sg.size; } if (msg->cork_bytes && diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c index 12ab5db2b71cb5781fd9517d944138e1cecbd5dc..38d3ad1411611ef6002a9843ec7699a860f1485a 100644 --- a/net/ipv4/tcp_ulp.c +++ b/net/ipv4/tcp_ulp.c @@ -99,17 +99,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen) rcu_read_unlock(); } -void tcp_update_ulp(struct sock *sk, struct proto *proto) +void tcp_update_ulp(struct sock *sk, struct proto *proto, + void (*write_space)(struct sock *sk)) { struct inet_connection_sock *icsk = inet_csk(sk); if (!icsk->icsk_ulp_ops) { + sk->sk_write_space = write_space; sk->sk_prot = proto; return; } if (icsk->icsk_ulp_ops->update) - icsk->icsk_ulp_ops->update(sk, proto); + icsk->icsk_ulp_ops->update(sk, proto, write_space); } void tcp_cleanup_ulp(struct sock *sk) diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index dac24c7aa7d4bb45b3188dae3fb1f6030cd3c299..94774c0e5ff32706dc1a36ebc3133dadde218501 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -732,15 +732,19 @@ static int tls_init(struct sock *sk) return rc; } -static void tls_update(struct sock *sk, struct proto *p) +static void tls_update(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)) { struct tls_context *ctx; ctx = tls_get_ctx(sk); - if (likely(ctx)) + if (likely(ctx)) { + ctx->sk_write_space = write_space; ctx->sk_proto = p; - else + } else { sk->sk_prot = p; + sk->sk_write_space = write_space; + } } static int tls_get_info(const struct sock *sk, struct sk_buff *skb) diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index c6803a82b769b15a117fc444ad98861a25c16f28..159d49dab403861e78d7d174eac10c8fd57028a3 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -682,12 +682,32 @@ static int tls_push_record(struct sock *sk, int flags, split_point = msg_pl->apply_bytes; split = split_point && split_point < msg_pl->sg.size; + if (unlikely((!split && + msg_pl->sg.size + + prot->overhead_size > msg_en->sg.size) || + (split && + split_point + + prot->overhead_size > msg_en->sg.size))) { + split = true; + split_point = msg_en->sg.size; + } if (split) { rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, split_point, prot->overhead_size, &orig_end); if (rc < 0) return rc; + /* This can happen if above tls_split_open_record allocates + * a single large encryption buffer instead of two smaller + * ones. In this case adjust pointers and continue without + * split. + */ + if (!msg_pl->sg.size) { + tls_merge_open_record(sk, rec, tmp, orig_end); + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + split = false; + } sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); } @@ -709,6 +729,12 @@ static int tls_push_record(struct sock *sk, int flags, sg_mark_end(sk_msg_elem(msg_pl, i)); } + if (msg_pl->sg.end < msg_pl->sg.start) { + sg_chain(&msg_pl->sg.data[msg_pl->sg.start], + MAX_SKB_FRAGS - msg_pl->sg.start + 1, + msg_pl->sg.data); + } + i = msg_pl->sg.start; sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); @@ -783,10 +809,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, if (psock->eval == __SK_NONE) { delta = msg->sg.size; psock->eval = sk_psock_msg_verdict(sk, psock, msg); - if (delta < msg->sg.size) - delta -= msg->sg.size; - else - delta = 0; + delta -= msg->sg.size; } if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && !enospc && !full_record) {