diff --git a/include/net/udp.h b/include/net/udp.h index 741d888d0fdb4a62f899f74890d3d06a5fc5306f..05990746810eaae7802426ca9527d5fe4dddd995 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -273,6 +273,7 @@ int udp_abort(struct sock *sk, int err); int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len); int udp_push_pending_frames(struct sock *sk); void udp_flush_pending_frames(struct sock *sk); +int udp_cmsg_send(struct sock *sk, struct msghdr *msg, u16 *gso_size); void udp4_hwcsum(struct sk_buff *skb, __be32 src, __be32 dst); int udp_rcv(struct sk_buff *skb); int udp_ioctl(struct sock *sk, int cmd, unsigned long arg); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index bda022c5480bea13d216dc650d0e3369f50f922f..794aeafeb782d17ca865b3f902f591a257bdc868 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -853,6 +853,43 @@ int udp_push_pending_frames(struct sock *sk) } EXPORT_SYMBOL(udp_push_pending_frames); +static int __udp_cmsg_send(struct cmsghdr *cmsg, u16 *gso_size) +{ + switch (cmsg->cmsg_type) { + case UDP_SEGMENT: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(__u16))) + return -EINVAL; + *gso_size = *(__u16 *)CMSG_DATA(cmsg); + return 0; + default: + return -EINVAL; + } +} + +int udp_cmsg_send(struct sock *sk, struct msghdr *msg, u16 *gso_size) +{ + struct cmsghdr *cmsg; + bool need_ip = false; + int err; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + if (cmsg->cmsg_level != SOL_UDP) { + need_ip = true; + continue; + } + + err = __udp_cmsg_send(cmsg, gso_size); + if (err) + return err; + } + + return need_ip; +} +EXPORT_SYMBOL_GPL(udp_cmsg_send); + int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) { struct inet_sock *inet = inet_sk(sk); @@ -941,8 +978,11 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ipc.gso_size = up->gso_size; if (msg->msg_controllen) { - err = ip_cmsg_send(sk, msg, &ipc, sk->sk_family == AF_INET6); - if (unlikely(err)) { + err = udp_cmsg_send(sk, msg, &ipc.gso_size); + if (err > 0) + err = ip_cmsg_send(sk, msg, &ipc, + sk->sk_family == AF_INET6); + if (unlikely(err < 0)) { kfree(ipc.opt); return err; } diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 86b7dd58d4b45271388679b4e1ae120cf16c1639..6acfdd3e442b6a9a865e122c04555ad94d6af5b4 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -1276,7 +1276,10 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) opt->tot_len = sizeof(*opt); ipc6.opt = opt; - err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, &fl6, &ipc6, &sockc); + err = udp_cmsg_send(sk, msg, &ipc6.gso_size); + if (err > 0) + err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, &fl6, + &ipc6, &sockc); if (err < 0) { fl6_sock_release(flowlabel); return err;