diff --git a/include/net/tcp.h b/include/net/tcp.h index 9598871485ce3d7d36f96be9643d94e7ca85cfe0..051dc5c2802d3296f8b49d4b40911c0c22345262 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -829,7 +829,7 @@ struct tcp_congestion_ops { /* hook for packet ack accounting (optional) */ void (*pkts_acked)(struct sock *sk, u32 num_acked, s32 rtt_us); /* get info for inet_diag (optional) */ - void (*get_info)(struct sock *sk, u32 ext, struct sk_buff *skb); + int (*get_info)(struct sock *sk, u32 ext, struct sk_buff *skb); char name[TCP_CA_NAME_MAX]; struct module *owner; diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 70e8b3c308ec3bd5c3069d87558b8b24d644ed52..bb77ebdae3b31bcacbfc4fdb350d4282e400e2e8 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -111,6 +111,7 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, const struct nlmsghdr *unlh) { const struct inet_sock *inet = inet_sk(sk); + const struct tcp_congestion_ops *ca_ops; const struct inet_diag_handler *handler; int ext = req->idiag_ext; struct inet_diag_msg *r; @@ -208,16 +209,31 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, info = nla_data(attr); } - if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) - if (nla_put_string(skb, INET_DIAG_CONG, - icsk->icsk_ca_ops->name) < 0) + if (ext & (1 << (INET_DIAG_CONG - 1))) { + int err = 0; + + rcu_read_lock(); + ca_ops = READ_ONCE(icsk->icsk_ca_ops); + if (ca_ops) + err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name); + rcu_read_unlock(); + if (err < 0) goto errout; + } handler->idiag_get_info(sk, r, info); - if (sk->sk_state < TCP_TIME_WAIT && - icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info) - icsk->icsk_ca_ops->get_info(sk, ext, skb); + if (sk->sk_state < TCP_TIME_WAIT) { + int err = 0; + + rcu_read_lock(); + ca_ops = READ_ONCE(icsk->icsk_ca_ops); + if (ca_ops && ca_ops->get_info) + err = ca_ops->get_info(sk, ext, skb); + rcu_read_unlock(); + if (err < 0) + goto errout; + } out: nlmsg_end(skb, nlh); diff --git a/net/ipv4/tcp_dctcp.c b/net/ipv4/tcp_dctcp.c index b504371af74209362e03bbae0770bcbafc3e9340..4376016f7fa5cf84a3114d1551da623442c9c713 100644 --- a/net/ipv4/tcp_dctcp.c +++ b/net/ipv4/tcp_dctcp.c @@ -277,7 +277,7 @@ static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) } } -static void dctcp_get_info(struct sock *sk, u32 ext, struct sk_buff *skb) +static int dctcp_get_info(struct sock *sk, u32 ext, struct sk_buff *skb) { const struct dctcp *ca = inet_csk_ca(sk); @@ -297,8 +297,9 @@ static void dctcp_get_info(struct sock *sk, u32 ext, struct sk_buff *skb) info.dctcp_ab_tot = ca->acked_bytes_total; } - nla_put(skb, INET_DIAG_DCTCPINFO, sizeof(info), &info); + return nla_put(skb, INET_DIAG_DCTCPINFO, sizeof(info), &info); } + return 0; } static struct tcp_congestion_ops dctcp __read_mostly = { diff --git a/net/ipv4/tcp_illinois.c b/net/ipv4/tcp_illinois.c index 1d5a30a90adf6194d3b2d2215276f52c0a3f4632..67476f085e4843dacaeb1b4d6b71ecc160c99b97 100644 --- a/net/ipv4/tcp_illinois.c +++ b/net/ipv4/tcp_illinois.c @@ -300,8 +300,7 @@ static u32 tcp_illinois_ssthresh(struct sock *sk) } /* Extract info for Tcp socket info provided via netlink. */ -static void tcp_illinois_info(struct sock *sk, u32 ext, - struct sk_buff *skb) +static int tcp_illinois_info(struct sock *sk, u32 ext, struct sk_buff *skb) { const struct illinois *ca = inet_csk_ca(sk); @@ -318,8 +317,9 @@ static void tcp_illinois_info(struct sock *sk, u32 ext, do_div(t, info.tcpv_rttcnt); info.tcpv_rtt = t; } - nla_put(skb, INET_DIAG_VEGASINFO, sizeof(info), &info); + return nla_put(skb, INET_DIAG_VEGASINFO, sizeof(info), &info); } + return 0; } static struct tcp_congestion_ops tcp_illinois __read_mostly = { diff --git a/net/ipv4/tcp_vegas.c b/net/ipv4/tcp_vegas.c index a6afde666ab1284ea53c24ce80b2c287ff42073b..c71a1b8f7bde3082a6520128bb6f47d3081de8ac 100644 --- a/net/ipv4/tcp_vegas.c +++ b/net/ipv4/tcp_vegas.c @@ -286,7 +286,7 @@ static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) } /* Extract info for Tcp socket info provided via netlink. */ -void tcp_vegas_get_info(struct sock *sk, u32 ext, struct sk_buff *skb) +int tcp_vegas_get_info(struct sock *sk, u32 ext, struct sk_buff *skb) { const struct vegas *ca = inet_csk_ca(sk); if (ext & (1 << (INET_DIAG_VEGASINFO - 1))) { @@ -297,8 +297,9 @@ void tcp_vegas_get_info(struct sock *sk, u32 ext, struct sk_buff *skb) .tcpv_minrtt = ca->minRTT, }; - nla_put(skb, INET_DIAG_VEGASINFO, sizeof(info), &info); + return nla_put(skb, INET_DIAG_VEGASINFO, sizeof(info), &info); } + return 0; } EXPORT_SYMBOL_GPL(tcp_vegas_get_info); diff --git a/net/ipv4/tcp_vegas.h b/net/ipv4/tcp_vegas.h index 0531b99d8637bac1c4f19537a351f060fffa56ed..e8a6b33cc61dd7c4d58ea7d3ebd1dd1f35041de9 100644 --- a/net/ipv4/tcp_vegas.h +++ b/net/ipv4/tcp_vegas.h @@ -19,6 +19,6 @@ void tcp_vegas_init(struct sock *sk); void tcp_vegas_state(struct sock *sk, u8 ca_state); void tcp_vegas_pkts_acked(struct sock *sk, u32 cnt, s32 rtt_us); void tcp_vegas_cwnd_event(struct sock *sk, enum tcp_ca_event event); -void tcp_vegas_get_info(struct sock *sk, u32 ext, struct sk_buff *skb); +int tcp_vegas_get_info(struct sock *sk, u32 ext, struct sk_buff *skb); #endif /* __TCP_VEGAS_H */ diff --git a/net/ipv4/tcp_westwood.c b/net/ipv4/tcp_westwood.c index bb63fba47d47837b615c6bc6b730e36eead6328f..b3c57cceb9907fe9d79f33f33369a96b647b5f26 100644 --- a/net/ipv4/tcp_westwood.c +++ b/net/ipv4/tcp_westwood.c @@ -256,8 +256,7 @@ static void tcp_westwood_event(struct sock *sk, enum tcp_ca_event event) } /* Extract info for Tcp socket info provided via netlink. */ -static void tcp_westwood_info(struct sock *sk, u32 ext, - struct sk_buff *skb) +static int tcp_westwood_info(struct sock *sk, u32 ext, struct sk_buff *skb) { const struct westwood *ca = inet_csk_ca(sk); @@ -268,8 +267,9 @@ static void tcp_westwood_info(struct sock *sk, u32 ext, .tcpv_minrtt = jiffies_to_usecs(ca->rtt_min), }; - nla_put(skb, INET_DIAG_VEGASINFO, sizeof(info), &info); + return nla_put(skb, INET_DIAG_VEGASINFO, sizeof(info), &info); } + return 0; } static struct tcp_congestion_ops tcp_westwood __read_mostly = {