diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 5dd5569f89bf5aa15bcce26af05bd28037907afe..4ee4fe4368474b0c5362218c356c708f8098d091 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -765,7 +765,7 @@ static struct sk_psock *sk_psock_from_strp(struct strparser *strp) return container_of(parser, struct sk_psock, parser); } -static void sk_psock_skb_redirect(struct sk_buff *skb) +static int sk_psock_skb_redirect(struct sk_buff *skb) { struct sk_psock *psock_other; struct sock *sk_other; @@ -776,7 +776,7 @@ static void sk_psock_skb_redirect(struct sk_buff *skb) */ if (unlikely(!sk_other)) { kfree_skb(skb); - return; + return -EIO; } psock_other = sk_psock(sk_other); /* This error indicates the socket is being torn down or had another @@ -786,11 +786,12 @@ static void sk_psock_skb_redirect(struct sk_buff *skb) if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) { kfree_skb(skb); - return; + return -EIO; } skb_queue_tail(&psock_other->ingress_skb, skb); schedule_work(&psock_other->work); + return 0; } static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict) @@ -826,15 +827,16 @@ int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb) } EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read); -static void sk_psock_verdict_apply(struct sk_psock *psock, - struct sk_buff *skb, int verdict) +static int sk_psock_verdict_apply(struct sk_psock *psock, struct sk_buff *skb, + int verdict) { struct tcp_skb_cb *tcp; struct sock *sk_other; - int err = -EIO; + int err = 0; switch (verdict) { case __SK_PASS: + err = -EIO; sk_other = psock->sk; if (sock_flag(sk_other, SOCK_DEAD) || !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { @@ -859,13 +861,15 @@ static void sk_psock_verdict_apply(struct sk_psock *psock, } break; case __SK_REDIRECT: - sk_psock_skb_redirect(skb); + err = sk_psock_skb_redirect(skb); break; case __SK_DROP: default: out_free: kfree_skb(skb); } + + return err; } static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) @@ -967,7 +971,8 @@ static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb, ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); skb->sk = NULL; } - sk_psock_verdict_apply(psock, skb, ret); + if (sk_psock_verdict_apply(psock, skb, ret) < 0) + len = 0; out: rcu_read_unlock(); return len;