diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 22c554d3a9abf0147f03a82f3a662ffb00859fca..0ce3f61658b78b330fa634f25f74633012c6c786 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -745,7 +745,8 @@ config TCP_MD5SIG config TCP_COMP bool "TCP: Transport Layer Compression support" - depends on ZSTD_COMPRESS=y + depends on CRYPTO_ZSTD=y + select STREAM_PARSER help Enable kernel payload compression support for TCP protocol. This allows payload compression handling of the TCP protocol to be done in-kernel. diff --git a/net/ipv4/tcp_comp.c b/net/ipv4/tcp_comp.c index e85803da392480755ac42e5fed88d1752e355859..1daa6d7ad5e1387b3cbe1d878ba3fa57184bfeaf 100644 --- a/net/ipv4/tcp_comp.c +++ b/net/ipv4/tcp_comp.c @@ -9,8 +9,11 @@ #include #define TCP_COMP_MAX_PADDING 64 -#define TCP_COMP_SCRATCH_SIZE 65400 +#define TCP_COMP_SCRATCH_SIZE 65535 #define TCP_COMP_MAX_CSIZE (TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING) +#define TCP_COMP_ALLOC_ORDER get_order(65536) +#define TCP_COMP_MAX_WINDOWLOG 17 +#define TCP_COMP_MAX_INPUT (1 << TCP_COMP_MAX_WINDOWLOG) #define TCP_COMP_SEND_PENDING 1 #define ZSTD_COMP_DEFAULT_LEVEL 1 @@ -31,6 +34,20 @@ struct tcp_comp_context_tx { bool in_tcp_sendpages; }; +struct tcp_comp_context_rx { + ZSTD_DStream *dstream; + void *dworkspace; + void *plaintext_data; + void *compressed_data; + void *remaining_data; + + size_t data_offset; + struct strparser strp; + void (*saved_data_ready)(struct sock *sk); + struct sk_buff *pkt; + bool decompressed; +}; + struct tcp_comp_context { struct rcu_head rcu; @@ -38,6 +55,7 @@ struct tcp_comp_context { void (*sk_write_space)(struct sock *sk); struct tcp_comp_context_tx tx; + struct tcp_comp_context_rx rx; unsigned long flags; }; @@ -426,12 +444,344 @@ static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) return copied ? copied : err; } +static struct sk_buff *comp_wait_data(struct sock *sk, int flags, + long timeo, int *err) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + struct sk_buff *skb; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + while (!(skb = ctx->rx.pkt)) { + if (sk->sk_err) { + *err = sock_error(sk); + return NULL; + } + + if (!skb_queue_empty(&sk->sk_receive_queue)) { + __strp_unpause(&ctx->rx.strp); + if (ctx->rx.pkt) + return ctx->rx.pkt; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return NULL; + + if (sock_flag(sk, SOCK_DONE)) + return NULL; + + if ((flags & MSG_DONTWAIT) || !timeo) { + *err = -EAGAIN; + return NULL; + } + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + sk_wait_event(sk, &timeo, ctx->rx.pkt != skb, &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + + /* Handle signals */ + if (signal_pending(current)) { + *err = sock_intr_errno(timeo); + return NULL; + } + } + + return skb; +} + +static bool comp_advance_skb(struct sock *sk, struct sk_buff *skb, + unsigned int len) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + struct strp_msg *rxm = strp_msg(skb); + + if (len < rxm->full_len) { + rxm->offset += len; + rxm->full_len -= len; + return false; + } + + /* Finished with message */ + ctx->rx.pkt = NULL; + kfree_skb(skb); + __strp_unpause(&ctx->rx.strp); + + return true; +} + +static int tcp_comp_rx_context_init(struct tcp_comp_context *ctx) +{ + int dsize; + + dsize = ZSTD_DStreamWorkspaceBound(TCP_COMP_MAX_INPUT); + if (dsize <= 0) + return -EINVAL; + + ctx->rx.dworkspace = kmalloc(dsize, GFP_KERNEL); + if (!ctx->rx.dworkspace) + return -ENOMEM; + + ctx->rx.dstream = ZSTD_initDStream(TCP_COMP_MAX_INPUT, + ctx->rx.dworkspace, dsize); + if (!ctx->rx.dstream) + goto err_dstream; + + ctx->rx.plaintext_data = kvmalloc(TCP_COMP_MAX_CSIZE * 32, GFP_KERNEL); + if (!ctx->rx.plaintext_data) + goto err_dstream; + + ctx->rx.compressed_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL); + if (!ctx->rx.compressed_data) + goto err_compressed; + + ctx->rx.remaining_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL); + if (!ctx->rx.remaining_data) + goto err_remaining; + + ctx->rx.data_offset = 0; + + return 0; + +err_remaining: + kvfree(ctx->rx.compressed_data); + ctx->rx.compressed_data = NULL; +err_compressed: + kvfree(ctx->rx.plaintext_data); + ctx->rx.plaintext_data = NULL; +err_dstream: + kfree(ctx->rx.dworkspace); + ctx->rx.dworkspace = NULL; + + return -ENOMEM; +} + +static void *tcp_comp_get_rx_stream(struct sock *sk) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + + if (!ctx->rx.plaintext_data) + tcp_comp_rx_context_init(ctx); + + return ctx->rx.plaintext_data; +} + +static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + const int plen = skb->len; + struct strp_msg *rxm; + ZSTD_outBuffer outbuf; + ZSTD_inBuffer inbuf; + int len; + void *to; + + to = tcp_comp_get_rx_stream(sk); + if (!to) + return -ENOSPC; + + if (skb_linearize_cow(skb)) + return -ENOMEM; + + if (plen + ctx->rx.data_offset > TCP_COMP_MAX_CSIZE) + return -ENOMEM; + + if (ctx->rx.data_offset) + memcpy(ctx->rx.compressed_data, ctx->rx.remaining_data, + ctx->rx.data_offset); + + memcpy((char *)ctx->rx.compressed_data + ctx->rx.data_offset, + skb->data, plen); + + inbuf.src = ctx->rx.compressed_data; + inbuf.pos = 0; + inbuf.size = plen + ctx->rx.data_offset; + ctx->rx.data_offset = 0; + + outbuf.dst = ctx->rx.plaintext_data; + outbuf.pos = 0; + outbuf.size = TCP_COMP_MAX_CSIZE * 32; + + while (1) { + size_t ret; + + to = outbuf.dst; + + ret = ZSTD_decompressStream(ctx->rx.dstream, &outbuf, &inbuf); + if (ZSTD_isError(ret)) + return -EIO; + + len = outbuf.pos - plen; + if (len > skb_tailroom(skb)) + len = skb_tailroom(skb); + + __skb_put(skb, len); + rxm = strp_msg(skb); + rxm->full_len += len; + + len += plen; + skb_copy_to_linear_data(skb, to, len); + + while ((to += len, outbuf.pos -= len) > 0) { + struct page *pages; + skb_frag_t *frag; + + if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS)) + return -EMSGSIZE; + + frag = skb_shinfo(skb)->frags + + skb_shinfo(skb)->nr_frags; + pages = alloc_pages(__GFP_NOWARN | GFP_KERNEL | __GFP_COMP, + TCP_COMP_ALLOC_ORDER); + + if (!pages) + return -ENOMEM; + + __skb_frag_set_page(frag, pages); + len = PAGE_SIZE << TCP_COMP_ALLOC_ORDER; + if (outbuf.pos < len) + len = outbuf.pos; + + frag->bv_offset = 0; + skb_frag_size_set(frag, len); + memcpy(skb_frag_address(frag), to, len); + + skb->truesize += len; + skb->data_len += len; + skb->len += len; + rxm->full_len += len; + skb_shinfo(skb)->nr_frags++; + } + + if (ret == 0) + break; + + if (inbuf.pos >= plen || !inbuf.pos) { + if (inbuf.pos < inbuf.size) { + memcpy((char *)ctx->rx.remaining_data, + (char *)inbuf.src + inbuf.pos, + inbuf.size - inbuf.pos); + ctx->rx.data_offset = inbuf.size - inbuf.pos; + } + break; + } + } + return 0; +} + static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, int flags, int *addr_len) { struct tcp_comp_context *ctx = comp_get_ctx(sk); + struct strp_msg *rxm; + struct sk_buff *skb; + ssize_t copied = 0; + int target, err = 0; + long timeo; + + flags |= nonblock; + + if (unlikely(flags & MSG_ERRQUEUE)) + return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + + lock_sock(sk); + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + timeo = sock_rcvtimeo(sk, flags & MSG_WAITALL); + + do { + int chunk = 0; + + skb = comp_wait_data(sk, flags, timeo, &err); + if (!skb) + goto recv_end; + + if (!ctx->rx.decompressed) { + err = tcp_comp_decompress(sk, skb); + if (err < 0) { + if (err != -ENOSPC) + tcp_comp_err_abort(sk, EBADMSG); + goto recv_end; + } + ctx->rx.decompressed = true; + } + rxm = strp_msg(skb); + + chunk = min_t(unsigned int, rxm->full_len, len); + + err = skb_copy_datagram_msg(skb, rxm->offset, msg, + chunk); + if (err < 0) + goto recv_end; + + copied += chunk; + len -= chunk; + if (likely(!(flags & MSG_PEEK))) + comp_advance_skb(sk, skb, chunk); + else + break; + + if (copied >= target && !ctx->rx.pkt) + break; + } while (len > 0); + +recv_end: + release_sock(sk); + return copied ? : err; +} + +bool comp_stream_read(const struct sock *sk) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + + if (ctx->rx.pkt) + return true; + + return false; +} + +static void comp_data_ready(struct sock *sk) +{ + struct tcp_comp_context *ctx = comp_get_ctx(sk); + + strp_data_ready(&ctx->rx.strp); +} + +static void comp_queue(struct strparser *strp, struct sk_buff *skb) +{ + struct tcp_comp_context *ctx = comp_get_ctx(strp->sk); + + ctx->rx.decompressed = false; + ctx->rx.pkt = skb; + strp_pause(strp); + ctx->rx.saved_data_ready(strp->sk); +} + +static int comp_read_size(struct strparser *strp, struct sk_buff *skb) +{ + struct strp_msg *rxm = strp_msg(skb); + + if (rxm->offset > skb->len) + return 0; - return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len); + return skb->len; +} + +void comp_setup_strp(struct sock *sk, struct tcp_comp_context *ctx) +{ + struct strp_callbacks cb; + + memset(&cb, 0, sizeof(cb)); + cb.rcv_msg = comp_queue; + cb.parse_msg = comp_read_size; + strp_init(&ctx->rx.strp, sk, &cb); + + write_lock_bh(&sk->sk_callback_lock); + ctx->rx.saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = comp_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + + strp_check_rcv(&ctx->rx.strp); } static void tcp_comp_write_space(struct sock *sk) @@ -483,6 +833,7 @@ void tcp_init_compression(struct sock *sk) rcu_assign_pointer(icsk->icsk_ulp_data, ctx); sock_set_flag(sk, SOCK_COMP); + comp_setup_strp(sk, ctx); } static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx) @@ -497,6 +848,21 @@ static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx) ctx->tx.compressed_data = NULL; } +static void tcp_comp_context_rx_free(struct tcp_comp_context *ctx) +{ + kfree(ctx->rx.dworkspace); + ctx->rx.dworkspace = NULL; + + kvfree(ctx->rx.plaintext_data); + ctx->rx.plaintext_data = NULL; + + kvfree(ctx->rx.compressed_data); + ctx->rx.compressed_data = NULL; + + kvfree(ctx->rx.remaining_data); + ctx->rx.remaining_data = NULL; +} + static void tcp_comp_context_free(struct rcu_head *head) { struct tcp_comp_context *ctx; @@ -504,6 +870,7 @@ static void tcp_comp_context_free(struct rcu_head *head) ctx = container_of(head, struct tcp_comp_context, rcu); tcp_comp_context_tx_free(ctx); + tcp_comp_context_rx_free(ctx); kfree(ctx); } @@ -515,6 +882,11 @@ void tcp_cleanup_compression(struct sock *sk) if (!ctx || !sock_flag(sk, SOCK_COMP)) return; + if (ctx->rx.pkt) { + kfree_skb(ctx->rx.pkt); + ctx->rx.pkt = NULL; + } + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); call_rcu(&ctx->rcu, tcp_comp_context_free); } @@ -524,6 +896,7 @@ int tcp_comp_init(void) tcp_prot_override = tcp_prot; tcp_prot_override.sendmsg = tcp_comp_sendmsg; tcp_prot_override.recvmsg = tcp_comp_recvmsg; + tcp_prot_override.stream_memory_read = comp_stream_read; return 0; }