diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h index 7ff588ca6817afaae545382ae39a8126d9caf946..b3c28a9dfbf10e85dbf9de9b15c0bfb69e13723f 100644 --- a/include/net/inet6_hashtables.h +++ b/include/net/inet6_hashtables.h @@ -96,6 +96,8 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, const struct in6_addr *saddr, const __be16 sport, const struct in6_addr *daddr, const __be16 dport, const int dif); + +int inet6_hash(struct sock *sk); #endif /* IS_ENABLED(CONFIG_IPV6) */ #define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif) \ diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c index 9c6d0508e63a2ab7f13105cd91ea8c027c4a7557..90a8269b28d02e705e75fe054573cde446adcdf9 100644 --- a/net/dccp/ipv6.c +++ b/net/dccp/ipv6.c @@ -993,7 +993,7 @@ static struct proto dccp_v6_prot = { .sendmsg = dccp_sendmsg, .recvmsg = dccp_recvmsg, .backlog_rcv = dccp_v6_do_rcv, - .hash = inet_hash, + .hash = inet6_hash, .unhash = inet_unhash, .accept = inet_csk_accept, .get_port = inet_csk_get_port, diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c index 21ace5a2bf7c27f7e70a212cf0d94432d029112d..072653dd9c983ae06e50a62b3e5c87fdb736178f 100644 --- a/net/ipv6/inet6_hashtables.c +++ b/net/ipv6/inet6_hashtables.c @@ -274,3 +274,59 @@ int inet6_hash_connect(struct inet_timewait_death_row *death_row, __inet6_check_established); } EXPORT_SYMBOL_GPL(inet6_hash_connect); + +int inet6_hash(struct sock *sk) +{ + if (sk->sk_state != TCP_CLOSE) { + local_bh_disable(); + __inet_hash(sk, NULL); + local_bh_enable(); + } + + return 0; +} +EXPORT_SYMBOL_GPL(inet6_hash); + +/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6 + * only, and any IPv4 addresses if not IPv6 only + * match_wildcard == false: addresses must be exactly the same, i.e. + * IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY, + * and 0.0.0.0 equals to 0.0.0.0 only + */ +int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, + bool match_wildcard) +{ + const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2); + int sk2_ipv6only = inet_v6_ipv6only(sk2); + int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); + int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED; + + /* if both are mapped, treat as IPv4 */ + if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) { + if (!sk2_ipv6only) { + if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr) + return 1; + if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr) + return match_wildcard; + } + return 0; + } + + if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY) + return 1; + + if (addr_type2 == IPV6_ADDR_ANY && match_wildcard && + !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED)) + return 1; + + if (addr_type == IPV6_ADDR_ANY && match_wildcard && + !(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED)) + return 1; + + if (sk2_rcv_saddr6 && + ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6)) + return 1; + + return 0; +} +EXPORT_SYMBOL_GPL(ipv6_rcv_saddr_equal); diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index 006396e31cb0dcedf359c5aaa068ab3d28f7501c..d72bcfb326d8f2e8434458790f148aa30b183796 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -1865,7 +1865,7 @@ struct proto tcpv6_prot = { .sendpage = tcp_sendpage, .backlog_rcv = tcp_v6_do_rcv, .release_cb = tcp_release_cb, - .hash = inet_hash, + .hash = inet6_hash, .unhash = inet_unhash, .get_port = inet_csk_get_port, .enter_memory_pressure = tcp_enter_memory_pressure, diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 22e28a44e3c88c6e52ff59ea1c7dd8a3e964e701..ac4e7e03dded9d637f2835c6782c0b406c21b1cf 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -37,6 +37,7 @@ #include #include +#include #include #include #include @@ -77,49 +78,6 @@ static u32 udp6_ehashfn(const struct net *net, udp_ipv6_hash_secret + net_hash_mix(net)); } -/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6 - * only, and any IPv4 addresses if not IPv6 only - * match_wildcard == false: addresses must be exactly the same, i.e. - * IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY, - * and 0.0.0.0 equals to 0.0.0.0 only - */ -int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, - bool match_wildcard) -{ - const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2); - int sk2_ipv6only = inet_v6_ipv6only(sk2); - int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); - int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED; - - /* if both are mapped, treat as IPv4 */ - if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) { - if (!sk2_ipv6only) { - if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr) - return 1; - if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr) - return match_wildcard; - } - return 0; - } - - if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY) - return 1; - - if (addr_type2 == IPV6_ADDR_ANY && match_wildcard && - !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED)) - return 1; - - if (addr_type == IPV6_ADDR_ANY && match_wildcard && - !(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED)) - return 1; - - if (sk2_rcv_saddr6 && - ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6)) - return 1; - - return 0; -} - static u32 udp6_portaddr_hash(const struct net *net, const struct in6_addr *addr6, unsigned int port) diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index a2c8747d2936c305753224e7a786d67087b2cb1a..6b54ff3ff4cb8e7af49e7ee315cc78cfca148007 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -718,7 +719,7 @@ static struct proto l2tp_ip6_prot = { .sendmsg = l2tp_ip6_sendmsg, .recvmsg = l2tp_ip6_recvmsg, .backlog_rcv = l2tp_ip6_backlog_recv, - .hash = inet_hash, + .hash = inet6_hash, .unhash = inet_unhash, .obj_size = sizeof(struct l2tp_ip6_sock), #ifdef CONFIG_COMPAT