提交 fbcb4859 编写于 作者: W Wang Yufen 提交者: Zheng Zengkai

tcp_comp: implement recvmsg for tcp compression

hulk inclusion
category: feature
bugzilla: https://gitee.com/openeuler/kernel/issues/I4PNEK
CVE: NA

-------------------------------------------------

This patch implement software level compression for
receiving tcp messages. The compressed TCP payload will be
decompressed after receive.
Signed-off-by: NWang Yufen <wangyufen@huawei.com>
Signed-off-by: NYang Yingliang <yangyingliang@huawei.com>
Signed-off-by: NLu Wei <luwei32@huawei.com>
Reviewed-by: NWei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: NZheng Zengkai <zhengzengkai@huawei.com>
上级 e9ce37bb
...@@ -745,7 +745,8 @@ config TCP_MD5SIG ...@@ -745,7 +745,8 @@ config TCP_MD5SIG
config TCP_COMP config TCP_COMP
bool "TCP: Transport Layer Compression support" bool "TCP: Transport Layer Compression support"
depends on ZSTD_COMPRESS=y depends on CRYPTO_ZSTD=y
select STREAM_PARSER
help help
Enable kernel payload compression support for TCP protocol. This allows Enable kernel payload compression support for TCP protocol. This allows
payload compression handling of the TCP protocol to be done in-kernel. payload compression handling of the TCP protocol to be done in-kernel.
......
...@@ -9,8 +9,11 @@ ...@@ -9,8 +9,11 @@
#include <linux/zstd.h> #include <linux/zstd.h>
#define TCP_COMP_MAX_PADDING 64 #define TCP_COMP_MAX_PADDING 64
#define TCP_COMP_SCRATCH_SIZE 65400 #define TCP_COMP_SCRATCH_SIZE 65535
#define TCP_COMP_MAX_CSIZE (TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING) #define TCP_COMP_MAX_CSIZE (TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING)
#define TCP_COMP_ALLOC_ORDER get_order(65536)
#define TCP_COMP_MAX_WINDOWLOG 17
#define TCP_COMP_MAX_INPUT (1 << TCP_COMP_MAX_WINDOWLOG)
#define TCP_COMP_SEND_PENDING 1 #define TCP_COMP_SEND_PENDING 1
#define ZSTD_COMP_DEFAULT_LEVEL 1 #define ZSTD_COMP_DEFAULT_LEVEL 1
...@@ -31,6 +34,20 @@ struct tcp_comp_context_tx { ...@@ -31,6 +34,20 @@ struct tcp_comp_context_tx {
bool in_tcp_sendpages; bool in_tcp_sendpages;
}; };
struct tcp_comp_context_rx {
ZSTD_DStream *dstream;
void *dworkspace;
void *plaintext_data;
void *compressed_data;
void *remaining_data;
size_t data_offset;
struct strparser strp;
void (*saved_data_ready)(struct sock *sk);
struct sk_buff *pkt;
bool decompressed;
};
struct tcp_comp_context { struct tcp_comp_context {
struct rcu_head rcu; struct rcu_head rcu;
...@@ -38,6 +55,7 @@ struct tcp_comp_context { ...@@ -38,6 +55,7 @@ struct tcp_comp_context {
void (*sk_write_space)(struct sock *sk); void (*sk_write_space)(struct sock *sk);
struct tcp_comp_context_tx tx; struct tcp_comp_context_tx tx;
struct tcp_comp_context_rx rx;
unsigned long flags; unsigned long flags;
}; };
...@@ -426,12 +444,344 @@ static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -426,12 +444,344 @@ static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
return copied ? copied : err; return copied ? copied : err;
} }
static struct sk_buff *comp_wait_data(struct sock *sk, int flags,
long timeo, int *err)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct sk_buff *skb;
DEFINE_WAIT_FUNC(wait, woken_wake_function);
while (!(skb = ctx->rx.pkt)) {
if (sk->sk_err) {
*err = sock_error(sk);
return NULL;
}
if (!skb_queue_empty(&sk->sk_receive_queue)) {
__strp_unpause(&ctx->rx.strp);
if (ctx->rx.pkt)
return ctx->rx.pkt;
}
if (sk->sk_shutdown & RCV_SHUTDOWN)
return NULL;
if (sock_flag(sk, SOCK_DONE))
return NULL;
if ((flags & MSG_DONTWAIT) || !timeo) {
*err = -EAGAIN;
return NULL;
}
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo, ctx->rx.pkt != skb, &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
/* Handle signals */
if (signal_pending(current)) {
*err = sock_intr_errno(timeo);
return NULL;
}
}
return skb;
}
static bool comp_advance_skb(struct sock *sk, struct sk_buff *skb,
unsigned int len)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct strp_msg *rxm = strp_msg(skb);
if (len < rxm->full_len) {
rxm->offset += len;
rxm->full_len -= len;
return false;
}
/* Finished with message */
ctx->rx.pkt = NULL;
kfree_skb(skb);
__strp_unpause(&ctx->rx.strp);
return true;
}
static int tcp_comp_rx_context_init(struct tcp_comp_context *ctx)
{
int dsize;
dsize = ZSTD_DStreamWorkspaceBound(TCP_COMP_MAX_INPUT);
if (dsize <= 0)
return -EINVAL;
ctx->rx.dworkspace = kmalloc(dsize, GFP_KERNEL);
if (!ctx->rx.dworkspace)
return -ENOMEM;
ctx->rx.dstream = ZSTD_initDStream(TCP_COMP_MAX_INPUT,
ctx->rx.dworkspace, dsize);
if (!ctx->rx.dstream)
goto err_dstream;
ctx->rx.plaintext_data = kvmalloc(TCP_COMP_MAX_CSIZE * 32, GFP_KERNEL);
if (!ctx->rx.plaintext_data)
goto err_dstream;
ctx->rx.compressed_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
if (!ctx->rx.compressed_data)
goto err_compressed;
ctx->rx.remaining_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
if (!ctx->rx.remaining_data)
goto err_remaining;
ctx->rx.data_offset = 0;
return 0;
err_remaining:
kvfree(ctx->rx.compressed_data);
ctx->rx.compressed_data = NULL;
err_compressed:
kvfree(ctx->rx.plaintext_data);
ctx->rx.plaintext_data = NULL;
err_dstream:
kfree(ctx->rx.dworkspace);
ctx->rx.dworkspace = NULL;
return -ENOMEM;
}
static void *tcp_comp_get_rx_stream(struct sock *sk)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
if (!ctx->rx.plaintext_data)
tcp_comp_rx_context_init(ctx);
return ctx->rx.plaintext_data;
}
static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
const int plen = skb->len;
struct strp_msg *rxm;
ZSTD_outBuffer outbuf;
ZSTD_inBuffer inbuf;
int len;
void *to;
to = tcp_comp_get_rx_stream(sk);
if (!to)
return -ENOSPC;
if (skb_linearize_cow(skb))
return -ENOMEM;
if (plen + ctx->rx.data_offset > TCP_COMP_MAX_CSIZE)
return -ENOMEM;
if (ctx->rx.data_offset)
memcpy(ctx->rx.compressed_data, ctx->rx.remaining_data,
ctx->rx.data_offset);
memcpy((char *)ctx->rx.compressed_data + ctx->rx.data_offset,
skb->data, plen);
inbuf.src = ctx->rx.compressed_data;
inbuf.pos = 0;
inbuf.size = plen + ctx->rx.data_offset;
ctx->rx.data_offset = 0;
outbuf.dst = ctx->rx.plaintext_data;
outbuf.pos = 0;
outbuf.size = TCP_COMP_MAX_CSIZE * 32;
while (1) {
size_t ret;
to = outbuf.dst;
ret = ZSTD_decompressStream(ctx->rx.dstream, &outbuf, &inbuf);
if (ZSTD_isError(ret))
return -EIO;
len = outbuf.pos - plen;
if (len > skb_tailroom(skb))
len = skb_tailroom(skb);
__skb_put(skb, len);
rxm = strp_msg(skb);
rxm->full_len += len;
len += plen;
skb_copy_to_linear_data(skb, to, len);
while ((to += len, outbuf.pos -= len) > 0) {
struct page *pages;
skb_frag_t *frag;
if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
return -EMSGSIZE;
frag = skb_shinfo(skb)->frags +
skb_shinfo(skb)->nr_frags;
pages = alloc_pages(__GFP_NOWARN | GFP_KERNEL | __GFP_COMP,
TCP_COMP_ALLOC_ORDER);
if (!pages)
return -ENOMEM;
__skb_frag_set_page(frag, pages);
len = PAGE_SIZE << TCP_COMP_ALLOC_ORDER;
if (outbuf.pos < len)
len = outbuf.pos;
frag->bv_offset = 0;
skb_frag_size_set(frag, len);
memcpy(skb_frag_address(frag), to, len);
skb->truesize += len;
skb->data_len += len;
skb->len += len;
rxm->full_len += len;
skb_shinfo(skb)->nr_frags++;
}
if (ret == 0)
break;
if (inbuf.pos >= plen || !inbuf.pos) {
if (inbuf.pos < inbuf.size) {
memcpy((char *)ctx->rx.remaining_data,
(char *)inbuf.src + inbuf.pos,
inbuf.size - inbuf.pos);
ctx->rx.data_offset = inbuf.size - inbuf.pos;
}
break;
}
}
return 0;
}
static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len) int nonblock, int flags, int *addr_len)
{ {
struct tcp_comp_context *ctx = comp_get_ctx(sk); struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct strp_msg *rxm;
struct sk_buff *skb;
ssize_t copied = 0;
int target, err = 0;
long timeo;
flags |= nonblock;
if (unlikely(flags & MSG_ERRQUEUE))
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
lock_sock(sk);
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
timeo = sock_rcvtimeo(sk, flags & MSG_WAITALL);
do {
int chunk = 0;
skb = comp_wait_data(sk, flags, timeo, &err);
if (!skb)
goto recv_end;
if (!ctx->rx.decompressed) {
err = tcp_comp_decompress(sk, skb);
if (err < 0) {
if (err != -ENOSPC)
tcp_comp_err_abort(sk, EBADMSG);
goto recv_end;
}
ctx->rx.decompressed = true;
}
rxm = strp_msg(skb);
chunk = min_t(unsigned int, rxm->full_len, len);
err = skb_copy_datagram_msg(skb, rxm->offset, msg,
chunk);
if (err < 0)
goto recv_end;
copied += chunk;
len -= chunk;
if (likely(!(flags & MSG_PEEK)))
comp_advance_skb(sk, skb, chunk);
else
break;
if (copied >= target && !ctx->rx.pkt)
break;
} while (len > 0);
recv_end:
release_sock(sk);
return copied ? : err;
}
bool comp_stream_read(const struct sock *sk)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
if (ctx->rx.pkt)
return true;
return false;
}
static void comp_data_ready(struct sock *sk)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
strp_data_ready(&ctx->rx.strp);
}
static void comp_queue(struct strparser *strp, struct sk_buff *skb)
{
struct tcp_comp_context *ctx = comp_get_ctx(strp->sk);
ctx->rx.decompressed = false;
ctx->rx.pkt = skb;
strp_pause(strp);
ctx->rx.saved_data_ready(strp->sk);
}
static int comp_read_size(struct strparser *strp, struct sk_buff *skb)
{
struct strp_msg *rxm = strp_msg(skb);
if (rxm->offset > skb->len)
return 0;
return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len); return skb->len;
}
void comp_setup_strp(struct sock *sk, struct tcp_comp_context *ctx)
{
struct strp_callbacks cb;
memset(&cb, 0, sizeof(cb));
cb.rcv_msg = comp_queue;
cb.parse_msg = comp_read_size;
strp_init(&ctx->rx.strp, sk, &cb);
write_lock_bh(&sk->sk_callback_lock);
ctx->rx.saved_data_ready = sk->sk_data_ready;
sk->sk_data_ready = comp_data_ready;
write_unlock_bh(&sk->sk_callback_lock);
strp_check_rcv(&ctx->rx.strp);
} }
static void tcp_comp_write_space(struct sock *sk) static void tcp_comp_write_space(struct sock *sk)
...@@ -483,6 +833,7 @@ void tcp_init_compression(struct sock *sk) ...@@ -483,6 +833,7 @@ void tcp_init_compression(struct sock *sk)
rcu_assign_pointer(icsk->icsk_ulp_data, ctx); rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
sock_set_flag(sk, SOCK_COMP); sock_set_flag(sk, SOCK_COMP);
comp_setup_strp(sk, ctx);
} }
static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx) static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
...@@ -497,6 +848,21 @@ static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx) ...@@ -497,6 +848,21 @@ static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
ctx->tx.compressed_data = NULL; ctx->tx.compressed_data = NULL;
} }
static void tcp_comp_context_rx_free(struct tcp_comp_context *ctx)
{
kfree(ctx->rx.dworkspace);
ctx->rx.dworkspace = NULL;
kvfree(ctx->rx.plaintext_data);
ctx->rx.plaintext_data = NULL;
kvfree(ctx->rx.compressed_data);
ctx->rx.compressed_data = NULL;
kvfree(ctx->rx.remaining_data);
ctx->rx.remaining_data = NULL;
}
static void tcp_comp_context_free(struct rcu_head *head) static void tcp_comp_context_free(struct rcu_head *head)
{ {
struct tcp_comp_context *ctx; struct tcp_comp_context *ctx;
...@@ -504,6 +870,7 @@ static void tcp_comp_context_free(struct rcu_head *head) ...@@ -504,6 +870,7 @@ static void tcp_comp_context_free(struct rcu_head *head)
ctx = container_of(head, struct tcp_comp_context, rcu); ctx = container_of(head, struct tcp_comp_context, rcu);
tcp_comp_context_tx_free(ctx); tcp_comp_context_tx_free(ctx);
tcp_comp_context_rx_free(ctx);
kfree(ctx); kfree(ctx);
} }
...@@ -515,6 +882,11 @@ void tcp_cleanup_compression(struct sock *sk) ...@@ -515,6 +882,11 @@ void tcp_cleanup_compression(struct sock *sk)
if (!ctx || !sock_flag(sk, SOCK_COMP)) if (!ctx || !sock_flag(sk, SOCK_COMP))
return; return;
if (ctx->rx.pkt) {
kfree_skb(ctx->rx.pkt);
ctx->rx.pkt = NULL;
}
rcu_assign_pointer(icsk->icsk_ulp_data, NULL); rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
call_rcu(&ctx->rcu, tcp_comp_context_free); call_rcu(&ctx->rcu, tcp_comp_context_free);
} }
...@@ -524,6 +896,7 @@ int tcp_comp_init(void) ...@@ -524,6 +896,7 @@ int tcp_comp_init(void)
tcp_prot_override = tcp_prot; tcp_prot_override = tcp_prot;
tcp_prot_override.sendmsg = tcp_comp_sendmsg; tcp_prot_override.sendmsg = tcp_comp_sendmsg;
tcp_prot_override.recvmsg = tcp_comp_recvmsg; tcp_prot_override.recvmsg = tcp_comp_recvmsg;
tcp_prot_override.stream_memory_read = comp_stream_read;
return 0; return 0;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册