diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h index 15d0df9434667922894936216d17bf88120d60cc..007381270ff8d7d92cb05dd31e432983edd4281a 100644 --- a/include/linux/skbuff.h +++ b/include/linux/skbuff.h @@ -2949,7 +2949,12 @@ int skb_copy_datagram_from_iter(struct sk_buff *skb, int offset, struct iov_iter *from, int len); int zerocopy_sg_from_iter(struct sk_buff *skb, struct iov_iter *frm); void skb_free_datagram(struct sock *sk, struct sk_buff *skb); -void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb); +void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len); +static inline void skb_free_datagram_locked(struct sock *sk, + struct sk_buff *skb) +{ + __skb_free_datagram_locked(sk, skb, 0); +} int skb_kill_datagram(struct sock *sk, struct sk_buff *skb, unsigned int flags); int skb_copy_bits(const struct sk_buff *skb, int offset, void *to, int len); int skb_store_bits(struct sk_buff *skb, int offset, const void *from, int len); diff --git a/include/net/sock.h b/include/net/sock.h index 310c4367ea83f6c67c63270eac46f6558db735f7..1decb7a22261d6a269dee15f50d4984a24c532d5 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -457,28 +457,32 @@ struct sock { #define SK_CAN_REUSE 1 #define SK_FORCE_REUSE 2 +int sk_set_peek_off(struct sock *sk, int val); + static inline int sk_peek_offset(struct sock *sk, int flags) { - if ((flags & MSG_PEEK) && (sk->sk_peek_off >= 0)) - return sk->sk_peek_off; - else - return 0; + if (unlikely(flags & MSG_PEEK)) { + s32 off = READ_ONCE(sk->sk_peek_off); + if (off >= 0) + return off; + } + + return 0; } static inline void sk_peek_offset_bwd(struct sock *sk, int val) { - if (sk->sk_peek_off >= 0) { - if (sk->sk_peek_off >= val) - sk->sk_peek_off -= val; - else - sk->sk_peek_off = 0; + s32 off = READ_ONCE(sk->sk_peek_off); + + if (unlikely(off >= 0)) { + off = max_t(s32, off - val, 0); + WRITE_ONCE(sk->sk_peek_off, off); } } static inline void sk_peek_offset_fwd(struct sock *sk, int val) { - if (sk->sk_peek_off >= 0) - sk->sk_peek_off += val; + sk_peek_offset_bwd(sk, -val); } /* @@ -1862,6 +1866,7 @@ void sk_reset_timer(struct sock *sk, struct timer_list *timer, void sk_stop_timer(struct sock *sk, struct timer_list *timer); +int __sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); int sock_queue_err_skb(struct sock *sk, struct sk_buff *skb); diff --git a/include/net/udp.h b/include/net/udp.h index d870ec1611c4e011256d6b301c8c6655f719e5cf..a0b0da97164c03a9ad6c796247181e242254a682 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -158,6 +158,15 @@ static inline __sum16 udp_v4_check(int len, __be32 saddr, void udp_set_csum(bool nocheck, struct sk_buff *skb, __be32 saddr, __be32 daddr, int len); +static inline void udp_csum_pull_header(struct sk_buff *skb) +{ + if (skb->ip_summed == CHECKSUM_NONE) + skb->csum = csum_partial(udp_hdr(skb), sizeof(struct udphdr), + skb->csum); + skb_pull_rcsum(skb, sizeof(struct udphdr)); + UDP_SKB_CB(skb)->cscov -= sizeof(struct udphdr); +} + struct sk_buff **udp_gro_receive(struct sk_buff **head, struct sk_buff *skb, struct udphdr *uh); int udp_gro_complete(struct sk_buff *skb, int nhoff); diff --git a/net/core/datagram.c b/net/core/datagram.c index fa9dc6450b08137078d5ecae2ebe01801083a17c..b7de71f8d5d3a5fa947fa7306fa6812e9f166da5 100644 --- a/net/core/datagram.c +++ b/net/core/datagram.c @@ -301,16 +301,19 @@ void skb_free_datagram(struct sock *sk, struct sk_buff *skb) } EXPORT_SYMBOL(skb_free_datagram); -void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb) +void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len) { bool slow; if (likely(atomic_read(&skb->users) == 1)) smp_rmb(); - else if (likely(!atomic_dec_and_test(&skb->users))) + else if (likely(!atomic_dec_and_test(&skb->users))) { + sk_peek_offset_bwd(sk, len); return; + } slow = lock_sock_fast(sk); + sk_peek_offset_bwd(sk, len); skb_orphan(skb); sk_mem_reclaim_partial(sk); unlock_sock_fast(sk, slow); @@ -318,7 +321,7 @@ void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb) /* skb is now orphaned, can be freed outside of locked section */ __kfree_skb(skb); } -EXPORT_SYMBOL(skb_free_datagram_locked); +EXPORT_SYMBOL(__skb_free_datagram_locked); /** * skb_kill_datagram - Free a datagram skbuff forcibly diff --git a/net/core/sock.c b/net/core/sock.c index 2f517ea56786127ce2e5c12684a35d9833f68d9b..2ce76e82857ff9cab856b8dad6d3139a2fdbbcf3 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -402,9 +402,8 @@ static void sock_disable_timestamp(struct sock *sk, unsigned long flags) } -int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +int __sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) { - int err; unsigned long flags; struct sk_buff_head *list = &sk->sk_receive_queue; @@ -414,10 +413,6 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) return -ENOMEM; } - err = sk_filter(sk, skb); - if (err) - return err; - if (!sk_rmem_schedule(sk, skb, skb->truesize)) { atomic_inc(&sk->sk_drops); return -ENOBUFS; @@ -440,6 +435,18 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) sk->sk_data_ready(sk); return 0; } +EXPORT_SYMBOL(__sock_queue_rcv_skb); + +int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + int err; + + err = sk_filter(sk, skb); + if (err) + return err; + + return __sock_queue_rcv_skb(sk, skb); +} EXPORT_SYMBOL(sock_queue_rcv_skb); int sk_receive_skb(struct sock *sk, struct sk_buff *skb, const int nested) @@ -2180,6 +2187,15 @@ void __sk_mem_reclaim(struct sock *sk, int amount) } EXPORT_SYMBOL(__sk_mem_reclaim); +int sk_set_peek_off(struct sock *sk, int val) +{ + if (val < 0) + return -EINVAL; + + sk->sk_peek_off = val; + return 0; +} +EXPORT_SYMBOL_GPL(sk_set_peek_off); /* * Set of default routines for initialising struct proto_ops when diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 9e481992dbaef2d8bd7e9c7b870add4d4330fd99..a38b9910af6014734a3daac37dd7914a13b424a9 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -948,6 +948,7 @@ const struct proto_ops inet_dgram_ops = { .recvmsg = inet_recvmsg, .mmap = sock_no_mmap, .sendpage = inet_sendpage, + .set_peek_off = sk_set_peek_off, #ifdef CONFIG_COMPAT .compat_setsockopt = compat_sock_common_setsockopt, .compat_getsockopt = compat_sock_common_getsockopt, diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 355bdb221057c75b9a6fda1a490f0885401505f9..d80312ddbb8acdbeccad78aec090eff712b1d876 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1294,7 +1294,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); struct sk_buff *skb; unsigned int ulen, copied; - int peeked, off = 0; + int peeked, peeking, off; int err; int is_udplite = IS_UDPLITE(sk); bool checksum_valid = false; @@ -1304,15 +1304,16 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, return ip_recv_error(sk, msg, len, addr_len); try_again: + peeking = off = sk_peek_offset(sk, flags); skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0), &peeked, &off, &err); if (!skb) - goto out; + return err; - ulen = skb->len - sizeof(struct udphdr); + ulen = skb->len; copied = len; - if (copied > ulen) - copied = ulen; + if (copied > ulen - off) + copied = ulen - off; else if (copied < ulen) msg->msg_flags |= MSG_TRUNC; @@ -1322,18 +1323,16 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, * coverage checksum (UDP-Lite), do it before the copy. */ - if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) { + if (copied < ulen || UDP_SKB_CB(skb)->partial_cov || peeking) { checksum_valid = !udp_lib_checksum_complete(skb); if (!checksum_valid) goto csum_copy_err; } if (checksum_valid || skb_csum_unnecessary(skb)) - err = skb_copy_datagram_msg(skb, sizeof(struct udphdr), - msg, copied); + err = skb_copy_datagram_msg(skb, off, msg, copied); else { - err = skb_copy_and_csum_datagram_msg(skb, sizeof(struct udphdr), - msg); + err = skb_copy_and_csum_datagram_msg(skb, off, msg); if (err == -EINVAL) goto csum_copy_err; @@ -1346,7 +1345,8 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, UDP_INC_STATS_USER(sock_net(sk), UDP_MIB_INERRORS, is_udplite); } - goto out_free; + skb_free_datagram_locked(sk, skb); + return err; } if (!peeked) @@ -1370,9 +1370,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, if (flags & MSG_TRUNC) err = ulen; -out_free: - skb_free_datagram_locked(sk, skb); -out: + __skb_free_datagram_locked(sk, skb, peeking ? -err : err); return err; csum_copy_err: @@ -1500,7 +1498,7 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) sk_incoming_cpu_update(sk); } - rc = sock_queue_rcv_skb(sk, skb); + rc = __sock_queue_rcv_skb(sk, skb); if (rc < 0) { int is_udplite = IS_UDPLITE(sk); @@ -1616,10 +1614,14 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) } } - if (rcu_access_pointer(sk->sk_filter) && - udp_lib_checksum_complete(skb)) - goto csum_error; + if (rcu_access_pointer(sk->sk_filter)) { + if (udp_lib_checksum_complete(skb)) + goto csum_error; + if (sk_filter(sk, skb)) + goto drop; + } + udp_csum_pull_header(skb); if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) { UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS, is_udplite); diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index b11c37cfd67c81c24115898e8694fd1bfb7684a1..2b78aad0d52f774799b467822319cd84fb4c723e 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -561,6 +561,7 @@ const struct proto_ops inet6_dgram_ops = { .recvmsg = inet_recvmsg, /* ok */ .mmap = sock_no_mmap, .sendpage = sock_no_sendpage, + .set_peek_off = sk_set_peek_off, #ifdef CONFIG_COMPAT .compat_setsockopt = compat_sock_common_setsockopt, .compat_getsockopt = compat_sock_common_getsockopt, diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 78a7dfd127072e9205f8cc65fda51ea982b920a1..87bd7aff88b4b330015acef3c1b297cf091ffa41 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -357,7 +357,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, struct inet_sock *inet = inet_sk(sk); struct sk_buff *skb; unsigned int ulen, copied; - int peeked, off = 0; + int peeked, peeking, off; int err; int is_udplite = IS_UDPLITE(sk); bool checksum_valid = false; @@ -371,15 +371,16 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, return ipv6_recv_rxpmtu(sk, msg, len, addr_len); try_again: + peeking = off = sk_peek_offset(sk, flags); skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0), &peeked, &off, &err); if (!skb) - goto out; + return err; - ulen = skb->len - sizeof(struct udphdr); + ulen = skb->len; copied = len; - if (copied > ulen) - copied = ulen; + if (copied > ulen - off) + copied = ulen - off; else if (copied < ulen) msg->msg_flags |= MSG_TRUNC; @@ -391,17 +392,16 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, * coverage checksum (UDP-Lite), do it before the copy. */ - if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) { + if (copied < ulen || UDP_SKB_CB(skb)->partial_cov || peeking) { checksum_valid = !udp_lib_checksum_complete(skb); if (!checksum_valid) goto csum_copy_err; } if (checksum_valid || skb_csum_unnecessary(skb)) - err = skb_copy_datagram_msg(skb, sizeof(struct udphdr), - msg, copied); + err = skb_copy_datagram_msg(skb, off, msg, copied); else { - err = skb_copy_and_csum_datagram_msg(skb, sizeof(struct udphdr), msg); + err = skb_copy_and_csum_datagram_msg(skb, off, msg); if (err == -EINVAL) goto csum_copy_err; } @@ -418,7 +418,8 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, UDP_MIB_INERRORS, is_udplite); } - goto out_free; + skb_free_datagram_locked(sk, skb); + return err; } if (!peeked) { if (is_udp4) @@ -466,9 +467,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (flags & MSG_TRUNC) err = ulen; -out_free: - skb_free_datagram_locked(sk, skb); -out: + __skb_free_datagram_locked(sk, skb, peeking ? -err : err); return err; csum_copy_err: @@ -554,7 +553,7 @@ static int __udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) sk_incoming_cpu_update(sk); } - rc = sock_queue_rcv_skb(sk, skb); + rc = __sock_queue_rcv_skb(sk, skb); if (rc < 0) { int is_udplite = IS_UDPLITE(sk); @@ -648,8 +647,11 @@ int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) if (rcu_access_pointer(sk->sk_filter)) { if (udp_lib_checksum_complete(skb)) goto csum_error; + if (sk_filter(sk, skb)) + goto drop; } + udp_csum_pull_header(skb); if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) { UDP6_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS, is_udplite);