diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h index 00cbb4384c795b470777bd501cb65c914df50f82..9e34c877a77093ff1257e18941949a12391363d6 100644 --- a/include/net/inet6_hashtables.h +++ b/include/net/inet6_hashtables.h @@ -96,14 +96,15 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo, const __be16 sport, const __be16 dport) { - struct sock *sk; + struct sock *sk = skb_steal_sock(skb); - if (unlikely(sk = skb_steal_sock(skb))) + if (sk) return sk; - else return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, - &ipv6_hdr(skb)->saddr, sport, - &ipv6_hdr(skb)->daddr, ntohs(dport), - inet6_iif(skb)); + + return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, + &ipv6_hdr(skb)->saddr, sport, + &ipv6_hdr(skb)->daddr, ntohs(dport), + inet6_iif(skb)); } extern struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, diff --git a/include/net/protocol.h b/include/net/protocol.h index 057f2d3155673cf302171b2c57dc0366505c524e..929528c73fe8454a1dd34da7efbbc6d1fd8677fd 100644 --- a/include/net/protocol.h +++ b/include/net/protocol.h @@ -52,6 +52,8 @@ struct net_protocol { #if IS_ENABLED(CONFIG_IPV6) struct inet6_protocol { + void (*early_demux)(struct sk_buff *skb); + int (*handler)(struct sk_buff *skb); void (*err_handler)(struct sk_buff *skb, diff --git a/net/ipv4/ip_input.c b/net/ipv4/ip_input.c index bda8cac2ae9184dea0193962a924dcf58e0921b0..981ff1eef28cc9a1b24b0b50f58caa32fbcdd0d1 100644 --- a/net/ipv4/ip_input.c +++ b/net/ipv4/ip_input.c @@ -314,6 +314,7 @@ static inline bool ip_rcv_options(struct sk_buff *skb) } int sysctl_ip_early_demux __read_mostly = 1; +EXPORT_SYMBOL(sysctl_ip_early_demux); static int ip_rcv_finish(struct sk_buff *skb) { diff --git a/net/ipv6/ip6_input.c b/net/ipv6/ip6_input.c index 5ab923e51af3d4232afe9ea56af0933b7dcb8b9c..47975e363fcdec47ad3c12f5097688e839772167 100644 --- a/net/ipv6/ip6_input.c +++ b/net/ipv6/ip6_input.c @@ -47,9 +47,18 @@ -inline int ip6_rcv_finish( struct sk_buff *skb) +int ip6_rcv_finish(struct sk_buff *skb) { - if (skb_dst(skb) == NULL) + if (sysctl_ip_early_demux && !skb_dst(skb)) { + const struct inet6_protocol *ipprot; + + rcu_read_lock(); + ipprot = rcu_dereference(inet6_protos[ipv6_hdr(skb)->nexthdr]); + if (ipprot && ipprot->early_demux) + ipprot->early_demux(skb); + rcu_read_unlock(); + } + if (!skb_dst(skb)) ip6_route_input(skb); return dst_input(skb); diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index f49476e2d8845092d8d0e30dcf6ec9a7db7a6106..221224e72507cf462021c8ff330e355f067e7efb 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -1674,6 +1674,43 @@ static int tcp_v6_rcv(struct sk_buff *skb) goto discard_it; } +static void tcp_v6_early_demux(struct sk_buff *skb) +{ + const struct ipv6hdr *hdr; + const struct tcphdr *th; + struct sock *sk; + + if (skb->pkt_type != PACKET_HOST) + return; + + if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct tcphdr))) + return; + + hdr = ipv6_hdr(skb); + th = tcp_hdr(skb); + + if (th->doff < sizeof(struct tcphdr) / 4) + return; + + sk = __inet6_lookup_established(dev_net(skb->dev), &tcp_hashinfo, + &hdr->saddr, th->source, + &hdr->daddr, ntohs(th->dest), + inet6_iif(skb)); + if (sk) { + skb->sk = sk; + skb->destructor = sock_edemux; + if (sk->sk_state != TCP_TIME_WAIT) { + struct dst_entry *dst = sk->sk_rx_dst; + struct inet_sock *icsk = inet_sk(sk); + if (dst) + dst = dst_check(dst, 0); + if (dst && + icsk->rx_dst_ifindex == inet6_iif(skb)) + skb_dst_set_noref(skb, dst); + } + } +} + static struct timewait_sock_ops tcp6_timewait_sock_ops = { .twsk_obj_size = sizeof(struct tcp6_timewait_sock), .twsk_unique = tcp_twsk_unique, @@ -1984,6 +2021,7 @@ struct proto tcpv6_prot = { }; static const struct inet6_protocol tcpv6_protocol = { + .early_demux = tcp_v6_early_demux, .handler = tcp_v6_rcv, .err_handler = tcp_v6_err, .gso_send_check = tcp_v6_gso_send_check,