diff --git a/include/net/tcp.h b/include/net/tcp.h index cee41cca6d2c55707ec08053aba9134f417ca321..9475e3da53eb857e2937383c1cdbbc5624e1624c 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -2391,6 +2391,7 @@ bool tcp_synack_comp_enabled(const struct sock *sk, const struct inet_request_sock *ireq); void tcp_init_compression(struct sock *sk); void tcp_cleanup_compression(struct sock *sk); +int tcp_comp_init(void); #else static inline bool tcp_syn_comp_enabled(const struct tcp_sock *tp) { @@ -2410,6 +2411,11 @@ static inline void tcp_init_compression(struct sock *sk) static inline void tcp_cleanup_compression(struct sock *sk) { } + +static inline int tcp_comp_init(void) +{ + return 0; +} #endif #endif /* _TCP_H */ diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 4d7bb05df4e4eb3527b8fd94edebc4346fb1b6e0..1d6efb6c6542bbb58a0ec06eb492b81c376e3898 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -4248,5 +4248,6 @@ void __init tcp_init(void) tcp_metrics_init(); BUG_ON(tcp_register_congestion_control(&tcp_reno) != 0); tcp_tasklet_init(); + tcp_comp_init(); mptcp_init(); } diff --git a/net/ipv4/tcp_comp.c b/net/ipv4/tcp_comp.c index afcd50c89e7809a7501e601c6c4c5facfa9755a3..fb90be4d9e9ccadb814901108613dff7eb5eaa15 100644 --- a/net/ipv4/tcp_comp.c +++ b/net/ipv4/tcp_comp.c @@ -12,6 +12,13 @@ static unsigned long tcp_compression_ports[65536 / 8]; unsigned long *sysctl_tcp_compression_ports = tcp_compression_ports; int sysctl_tcp_compression_local __read_mostly; +static struct proto tcp_prot_override; + +struct tcp_comp_context { + struct proto *sk_proto; + struct rcu_head rcu; +}; + static bool tcp_comp_enabled(__be32 saddr, __be32 daddr, int port) { if (!sysctl_tcp_compression_local && @@ -41,16 +48,75 @@ bool tcp_synack_comp_enabled(const struct sock *sk, ntohs(inet->inet_sport)); } +static struct tcp_comp_context *comp_get_ctx(const struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + return (__force void *)icsk->icsk_ulp_data; +} + +static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + + return ctx->sk_proto->sendmsg(sk, msg, size); +} + +static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int nonblock, int flags, int *addr_len) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + + return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len); +} + void tcp_init_compression(struct sock *sk) { + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_comp_context *ctx = NULL; struct tcp_sock *tp = tcp_sk(sk); if (!tp->rx_opt.comp_ok) return; + ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); + if (!ctx) + return; + + ctx->sk_proto = sk->sk_prot; + WRITE_ONCE(sk->sk_prot, &tcp_prot_override); + + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + sock_set_flag(sk, SOCK_COMP); } +static void tcp_comp_context_free(struct rcu_head *head) +{ + struct tcp_comp_context *ctx; + + ctx = container_of(head, struct tcp_comp_context, rcu); + + kfree(ctx); +} + void tcp_cleanup_compression(struct sock *sk) { + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_comp_context *ctx = comp_get_ctx(sk); + + if (!ctx || !sock_flag(sk, SOCK_COMP)) + return; + + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + call_rcu(&ctx->rcu, tcp_comp_context_free); +} + +int tcp_comp_init(void) +{ + tcp_prot_override = tcp_prot; + tcp_prot_override.sendmsg = tcp_comp_sendmsg; + tcp_prot_override.recvmsg = tcp_comp_recvmsg; + + return 0; }