diff --git a/include/linux/udp.h b/include/linux/udp.h index ee3277593222cf314f9b030eec6ec09d0c38c4e4..247cfdcc4b08bbf377ff5819ebd02683806b0c83 100644 --- a/include/linux/udp.h +++ b/include/linux/udp.h @@ -49,11 +49,7 @@ struct udp_sock { unsigned int corkflag; /* Cork is required */ __u8 encap_type; /* Is this an Encapsulation socket? */ unsigned char no_check6_tx:1,/* Send zero UDP6 checksums on TX? */ - no_check6_rx:1,/* Allow zero UDP6 checksums on RX? */ - convert_csum:1;/* On receive, convert checksum - * unnecessary to checksum complete - * if possible. - */ + no_check6_rx:1;/* Allow zero UDP6 checksums on RX? */ /* * Following member retains the information to create a UDP header * when the socket is uncorked. @@ -102,16 +98,6 @@ static inline bool udp_get_no_check6_rx(struct sock *sk) return udp_sk(sk)->no_check6_rx; } -static inline void udp_set_convert_csum(struct sock *sk, bool val) -{ - udp_sk(sk)->convert_csum = val; -} - -static inline bool udp_get_convert_csum(struct sock *sk) -{ - return udp_sk(sk)->convert_csum; -} - #define udp_portaddr_for_each_entry(__sk, node, list) \ hlist_nulls_for_each_entry(__sk, node, list, __sk_common.skc_portaddr_node) diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h index a829b77523cf3f28704dbc593ebf112802886951..eb16c7beed1e9570168d1ccbc5a7cd9cf31dd52d 100644 --- a/include/net/inet_sock.h +++ b/include/net/inet_sock.h @@ -16,7 +16,7 @@ #ifndef _INET_SOCK_H #define _INET_SOCK_H - +#include #include #include #include @@ -184,6 +184,7 @@ struct inet_sock { mc_all:1, nodefrag:1; __u8 rcv_tos; + __u8 convert_csum; int uc_index; int mc_index; __be32 mc_addr; @@ -194,6 +195,16 @@ struct inet_sock { #define IPCORK_OPT 1 /* ip-options has been held in ipcork.opt */ #define IPCORK_ALLFRAG 2 /* always fragment (for ipv6 for now) */ +/* cmsg flags for inet */ +#define IP_CMSG_PKTINFO BIT(0) +#define IP_CMSG_TTL BIT(1) +#define IP_CMSG_TOS BIT(2) +#define IP_CMSG_RECVOPTS BIT(3) +#define IP_CMSG_RETOPTS BIT(4) +#define IP_CMSG_PASSSEC BIT(5) +#define IP_CMSG_ORIGDSTADDR BIT(6) +#define IP_CMSG_CHECKSUM BIT(7) + static inline struct inet_sock *inet_sk(const struct sock *sk) { return (struct inet_sock *)sk; @@ -250,4 +261,20 @@ static inline __u8 inet_sk_flowi_flags(const struct sock *sk) return flags; } +static inline void inet_inc_convert_csum(struct sock *sk) +{ + inet_sk(sk)->convert_csum++; +} + +static inline void inet_dec_convert_csum(struct sock *sk) +{ + if (inet_sk(sk)->convert_csum > 0) + inet_sk(sk)->convert_csum--; +} + +static inline bool inet_get_convert_csum(struct sock *sk) +{ + return !!inet_sk(sk)->convert_csum; +} + #endif /* _INET_SOCK_H */ diff --git a/include/net/ip.h b/include/net/ip.h index 0bb620702929e7ad3b48f7aa40e5c73df3638141..0e5a0bae187f6ae81c32f30042e08b2780390237 100644 --- a/include/net/ip.h +++ b/include/net/ip.h @@ -537,7 +537,7 @@ int ip_options_rcv_srr(struct sk_buff *skb); */ void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb); -void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb); +void ip_cmsg_recv_offset(struct msghdr *msg, struct sk_buff *skb, int offset); int ip_cmsg_send(struct net *net, struct msghdr *msg, struct ipcm_cookie *ipc, bool allow_ipv6); int ip_setsockopt(struct sock *sk, int level, int optname, char __user *optval, @@ -557,6 +557,11 @@ void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, __be16 port, void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 dport, u32 info); +static inline void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb) +{ + ip_cmsg_recv_offset(msg, skb, 0); +} + bool icmp_global_allow(void); extern int sysctl_icmp_msgs_per_sec; extern int sysctl_icmp_msgs_burst; diff --git a/include/uapi/linux/in.h b/include/uapi/linux/in.h index c33a65e3d62c85d104d38ab082d997df13cd4c0b..589ced069e8a1a68a9b1c9336517d66a675a75d2 100644 --- a/include/uapi/linux/in.h +++ b/include/uapi/linux/in.h @@ -109,6 +109,7 @@ struct in_addr { #define IP_MINTTL 21 #define IP_NODEFRAG 22 +#define IP_CHECKSUM 23 /* IP_MTU_DISCOVER values */ #define IP_PMTUDISC_DONT 0 /* Never send DF frames */ diff --git a/net/ipv4/fou.c b/net/ipv4/fou.c index b986298a7ba39908290ccd808a24947d776b91b4..2197c36f722fa0ae83b3dc4975a4fecbbf01497f 100644 --- a/net/ipv4/fou.c +++ b/net/ipv4/fou.c @@ -490,7 +490,7 @@ static int fou_create(struct net *net, struct fou_cfg *cfg, sk->sk_user_data = fou; fou->sock = sock; - udp_set_convert_csum(sk, true); + inet_inc_convert_csum(sk); sk->sk_allocation = GFP_ATOMIC; diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c index 8a89c738b7a3b43407293f521bd6d7e009ee7c80..a317797b3cd020472f3490a82bf7c749b08b8b9e 100644 --- a/net/ipv4/ip_sockglue.c +++ b/net/ipv4/ip_sockglue.c @@ -37,6 +37,7 @@ #include #include #include +#include #if IS_ENABLED(CONFIG_IPV6) #include #endif @@ -45,14 +46,6 @@ #include #include -#define IP_CMSG_PKTINFO 1 -#define IP_CMSG_TTL 2 -#define IP_CMSG_TOS 4 -#define IP_CMSG_RECVOPTS 8 -#define IP_CMSG_RETOPTS 16 -#define IP_CMSG_PASSSEC 32 -#define IP_CMSG_ORIGDSTADDR 64 - /* * SOL_IP control messages. */ @@ -104,6 +97,20 @@ static void ip_cmsg_recv_retopts(struct msghdr *msg, struct sk_buff *skb) put_cmsg(msg, SOL_IP, IP_RETOPTS, opt->optlen, opt->__data); } +static void ip_cmsg_recv_checksum(struct msghdr *msg, struct sk_buff *skb, + int offset) +{ + __wsum csum = skb->csum; + + if (skb->ip_summed != CHECKSUM_COMPLETE) + return; + + if (offset != 0) + csum = csum_sub(csum, csum_partial(skb->data, offset, 0)); + + put_cmsg(msg, SOL_IP, IP_CHECKSUM, sizeof(__wsum), &csum); +} + static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb) { char *secdata; @@ -144,47 +151,73 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) put_cmsg(msg, SOL_IP, IP_ORIGDSTADDR, sizeof(sin), &sin); } -void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb) +void ip_cmsg_recv_offset(struct msghdr *msg, struct sk_buff *skb, + int offset) { struct inet_sock *inet = inet_sk(skb->sk); unsigned int flags = inet->cmsg_flags; /* Ordered by supposed usage frequency */ - if (flags & 1) + if (flags & IP_CMSG_PKTINFO) { ip_cmsg_recv_pktinfo(msg, skb); - if ((flags >>= 1) == 0) - return; - if (flags & 1) + flags &= ~IP_CMSG_PKTINFO; + if (!flags) + return; + } + + if (flags & IP_CMSG_TTL) { ip_cmsg_recv_ttl(msg, skb); - if ((flags >>= 1) == 0) - return; - if (flags & 1) + flags &= ~IP_CMSG_TTL; + if (!flags) + return; + } + + if (flags & IP_CMSG_TOS) { ip_cmsg_recv_tos(msg, skb); - if ((flags >>= 1) == 0) - return; - if (flags & 1) + flags &= ~IP_CMSG_TOS; + if (!flags) + return; + } + + if (flags & IP_CMSG_RECVOPTS) { ip_cmsg_recv_opts(msg, skb); - if ((flags >>= 1) == 0) - return; - if (flags & 1) + flags &= ~IP_CMSG_RECVOPTS; + if (!flags) + return; + } + + if (flags & IP_CMSG_RETOPTS) { ip_cmsg_recv_retopts(msg, skb); - if ((flags >>= 1) == 0) - return; - if (flags & 1) + flags &= ~IP_CMSG_RETOPTS; + if (!flags) + return; + } + + if (flags & IP_CMSG_PASSSEC) { ip_cmsg_recv_security(msg, skb); - if ((flags >>= 1) == 0) - return; - if (flags & 1) + flags &= ~IP_CMSG_PASSSEC; + if (!flags) + return; + } + + if (flags & IP_CMSG_ORIGDSTADDR) { ip_cmsg_recv_dstaddr(msg, skb); + flags &= ~IP_CMSG_ORIGDSTADDR; + if (!flags) + return; + } + + if (flags & IP_CMSG_CHECKSUM) + ip_cmsg_recv_checksum(msg, skb, offset); } -EXPORT_SYMBOL(ip_cmsg_recv); +EXPORT_SYMBOL(ip_cmsg_recv_offset); int ip_cmsg_send(struct net *net, struct msghdr *msg, struct ipcm_cookie *ipc, bool allow_ipv6) @@ -522,6 +555,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, case IP_MULTICAST_ALL: case IP_MULTICAST_LOOP: case IP_RECVORIGDSTADDR: + case IP_CHECKSUM: if (optlen >= sizeof(int)) { if (get_user(val, (int __user *) optval)) return -EFAULT; @@ -619,6 +653,19 @@ static int do_ip_setsockopt(struct sock *sk, int level, else inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR; break; + case IP_CHECKSUM: + if (val) { + if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) { + inet_inc_convert_csum(sk); + inet->cmsg_flags |= IP_CMSG_CHECKSUM; + } + } else { + if (inet->cmsg_flags & IP_CMSG_CHECKSUM) { + inet_dec_convert_csum(sk); + inet->cmsg_flags &= ~IP_CMSG_CHECKSUM; + } + } + break; case IP_TOS: /* This sets both TOS and Precedence */ if (sk->sk_type == SOCK_STREAM) { val &= ~INET_ECN_MASK; @@ -1222,6 +1269,9 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, case IP_RECVORIGDSTADDR: val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0; break; + case IP_CHECKSUM: + val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0; + break; case IP_TOS: val = inet->tos; break; diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 13b4dcf86ef610d1fcc1b26f7f69f5a6bbd31686..97ef1f8b7be81ed7d06c599b4158db0507afde44 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1329,7 +1329,7 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg, *addr_len = sizeof(*sin); } if (inet->cmsg_flags) - ip_cmsg_recv(msg, skb); + ip_cmsg_recv_offset(msg, skb, sizeof(struct udphdr)); err = copied; if (flags & MSG_TRUNC) @@ -1806,7 +1806,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, if (sk != NULL) { int ret; - if (udp_sk(sk)->convert_csum && uh->check && !IS_UDPLITE(sk)) + if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk)) skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check, inet_compute_pseudo); diff --git a/net/ipv4/udp_tunnel.c b/net/ipv4/udp_tunnel.c index 1671263e5fa0eae2e6a7ebad40f912cae002bc53..9996e63ed304888e471e6280b3f9db1324be9334 100644 --- a/net/ipv4/udp_tunnel.c +++ b/net/ipv4/udp_tunnel.c @@ -63,7 +63,7 @@ void setup_udp_tunnel_sock(struct net *net, struct socket *sock, inet_sk(sk)->mc_loop = 0; /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */ - udp_set_convert_csum(sk, true); + inet_inc_convert_csum(sk); rcu_assign_sk_user_data(sk, cfg->sk_user_data); diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 189dc4ae3ecac1b140a7208c4b6de0b956e0b710..e41f017cd479c04ad5876ca2e978138af181a9d9 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -909,7 +909,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, goto csum_error; } - if (udp_sk(sk)->convert_csum && uh->check && !IS_UDPLITE(sk)) + if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk)) skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check, ip6_compute_pseudo);