提交 053ced9a 编写于 作者: W Wei Yongjun 提交者: Yang Yingliang

tcp_comp: implement sendmsg for tcp compression

hulk inclusion
category: feature
bugzilla: NA
CVE: NA

-------------------------------------------------

This patch implements software-level compression for
sending TCP messages. All of the TCP payload will be
compressed before xmit.
Signed-off-by: Wei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: Wang Yufen <wangyufen@huawei.com>
Reviewed-by: Yue Haibing <yuehaibing@huawei.com>
Signed-off-by: Yang Yingliang <yangyingliang@huawei.com>
上级 e9b9f606
...@@ -757,7 +757,7 @@ config TCP_MD5SIG ...@@ -757,7 +757,7 @@ config TCP_MD5SIG
config TCP_COMP
	bool "TCP: Transport Layer Compression support"
	depends on !SMC && ZSTD_COMPRESS=y
	---help---
	  Enable kernel payload compression support for TCP protocol. This allows
	  payload compression handling of the TCP protocol to be done in-kernel.
......
...@@ -6,6 +6,14 @@ ...@@ -6,6 +6,14 @@
*/ */
#include <net/tcp.h> #include <net/tcp.h>
#include <linux/zstd.h>
#define TCP_COMP_MAX_PADDING 64
#define TCP_COMP_SCRATCH_SIZE 65400
#define TCP_COMP_MAX_CSIZE (TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING)
#define TCP_COMP_SEND_PENDING 1
#define ZSTD_COMP_DEFAULT_LEVEL 1
static unsigned long tcp_compression_ports[65536 / 8]; static unsigned long tcp_compression_ports[65536 / 8];
...@@ -14,11 +22,42 @@ int sysctl_tcp_compression_local __read_mostly; ...@@ -14,11 +22,42 @@ int sysctl_tcp_compression_local __read_mostly;
static struct proto tcp_prot_override; static struct proto tcp_prot_override;
/* Per-socket transmit-side state for TCP payload compression. */
struct tcp_comp_context_tx {
	ZSTD_CStream *cstream;		/* zstd streaming compressor; lives inside cworkspace */
	void *cworkspace;		/* kmalloc'd workspace backing cstream */
	void *plaintext_data;		/* staging buffer for user data awaiting compression */
	void *compressed_data;		/* linear buffer receiving zstd output */

	struct scatterlist sg_data[MAX_SKB_FRAGS];	/* pages holding compressed output for xmit */
	unsigned int sg_size;		/* total bytes currently described by sg_data */
	int sg_num;			/* number of valid entries in sg_data */

	struct scatterlist *partially_send;	/* unsent tail of sg chain after a short send */
	bool in_tcp_sendpages;		/* true while do_tcp_sendpages() is on the stack */
};
/*
 * ULP context hung off icsk->icsk_ulp_data for a compression-enabled socket.
 * (Diff rendering had fused the old and new declarations; this is the
 * new-side struct.)
 */
struct tcp_comp_context {
	struct rcu_head rcu;		/* deferred free via call_rcu() */

	struct proto *sk_proto;		/* original proto ops we chain to */
	void (*sk_write_space)(struct sock *sk);	/* saved write_space callback */
	struct tcp_comp_context_tx tx;	/* transmit-side compression state */

	unsigned long flags;		/* bit flags, e.g. TCP_COMP_SEND_PENDING */
};
/* True if compressed data is queued (possibly partially sent) and still owed to TCP. */
static bool tcp_comp_is_write_pending(struct tcp_comp_context *ctx)
{
	return test_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
}
/* Record a fatal error on the socket and notify any error listeners. */
static void tcp_comp_err_abort(struct sock *sk, int err)
{
	sk->sk_err = err;
	sk->sk_error_report(sk);
}
static bool tcp_comp_enabled(__be32 saddr, __be32 daddr, int port) static bool tcp_comp_enabled(__be32 saddr, __be32 daddr, int port)
{ {
if (!sysctl_tcp_compression_local && if (!sysctl_tcp_compression_local &&
...@@ -55,11 +94,359 @@ static struct tcp_comp_context *comp_get_ctx(const struct sock *sk) ...@@ -55,11 +94,359 @@ static struct tcp_comp_context *comp_get_ctx(const struct sock *sk)
return (__force void *)icsk->icsk_ulp_data; return (__force void *)icsk->icsk_ulp_data;
} }
static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) static int tcp_comp_tx_context_init(struct tcp_comp_context *ctx)
{
ZSTD_parameters params;
int csize;
params = ZSTD_getParams(ZSTD_COMP_DEFAULT_LEVEL, PAGE_SIZE, 0);
csize = ZSTD_CStreamWorkspaceBound(params.cParams);
if (csize <= 0)
return -EINVAL;
ctx->tx.cworkspace = kmalloc(csize, GFP_KERNEL);
if (!ctx->tx.cworkspace)
return -ENOMEM;
ctx->tx.cstream = ZSTD_initCStream(params, 0, ctx->tx.cworkspace,
csize);
if (!ctx->tx.cstream)
goto err_cstream;
ctx->tx.plaintext_data = kvmalloc(TCP_COMP_SCRATCH_SIZE, GFP_KERNEL);
if (!ctx->tx.plaintext_data)
goto err_cstream;
ctx->tx.compressed_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
if (!ctx->tx.compressed_data)
goto err_compressed;
return 0;
err_compressed:
kvfree(ctx->tx.plaintext_data);
ctx->tx.plaintext_data = NULL;
err_cstream:
kfree(ctx->tx.cworkspace);
ctx->tx.cworkspace = NULL;
return -ENOMEM;
}
/*
 * Return the plaintext staging buffer, lazily initializing the whole tx
 * compression context on first use.  NULL means initialization failed.
 */
static void *tcp_comp_get_tx_stream(struct sock *sk)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);

	if (ctx->tx.plaintext_data)
		return ctx->tx.plaintext_data;

	tcp_comp_tx_context_init(ctx);

	return ctx->tx.plaintext_data;
}
/*
 * Reserve scatterlist pages in tx.sg_data to hold @len bytes of compressed
 * output.  On -ENOSPC the table itself is exhausted, so clamp sg_num to the
 * table size; other errors are returned to the caller unchanged.
 */
static int alloc_compressed_sg(struct sock *sk, int len)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	int rc = 0;

	rc = sk_alloc_sg(sk, len, ctx->tx.sg_data, 0,
			 &ctx->tx.sg_num, &ctx->tx.sg_size, 0);
	if (rc == -ENOSPC)
		ctx->tx.sg_num = ARRAY_SIZE(ctx->tx.sg_data);

	return rc;
}
/*
 * Copy @copy bytes of user data from the msghdr iterator into the plaintext
 * staging buffer.  Returns the number of bytes copied on success, -ENOSPC if
 * the staging buffer could not be obtained, or -EFAULT on a short copy.
 */
static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, int copy)
{
	void *dest;
	int rc;

	dest = tcp_comp_get_tx_stream(sk);
	if (!dest)
		return -ENOSPC;

	/* Honor NETIF_F_NOCACHE_COPY to avoid polluting the CPU cache. */
	if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
		rc = copy_from_iter_nocache(dest, copy, from);
	else
		rc = copy_from_iter(dest, copy, from);

	if (rc != copy)
		rc = -EFAULT;

	return rc;
}
/*
 * Scatter @bytes bytes of the linear compressed buffer across the
 * pre-allocated tx.sg_data pages.  Returns 0 when everything fit, otherwise
 * the number of bytes that did not fit.
 */
static int memcopy_to_sg(struct sock *sk, int bytes)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct scatterlist *sg = ctx->tx.sg_data;
	char *from, *to;
	int copy;

	from = ctx->tx.compressed_data;
	while (bytes && sg) {
		to = sg_virt(sg);
		copy = min_t(int, sg->length, bytes);
		memcpy(to, from, copy);
		bytes -= copy;
		from += copy;
		sg = sg_next(sg);
	}

	return bytes;
}
/*
 * Shrink the compressed scatterlist down to @target_size bytes, releasing the
 * pages (and uncharging the socket memory) of any fully-trimmed tail entries
 * and re-marking the list end.
 *
 * (Lines of the removed old tcp_comp_sendmsg() were fused into this body by
 * the diff rendering; restored to the new-side function.)
 */
static void trim_sg(struct sock *sk, int target_size)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct scatterlist *sg = ctx->tx.sg_data;
	int trim = ctx->tx.sg_size - target_size;
	int i = ctx->tx.sg_num - 1;

	if (trim <= 0) {
		/* Growing is not possible here; a negative trim is a bug. */
		WARN_ON_ONCE(trim < 0);
		return;
	}

	ctx->tx.sg_size = target_size;
	while (trim >= sg[i].length) {
		trim -= sg[i].length;
		sk_mem_uncharge(sk, sg[i].length);
		put_page(sg_page(&sg[i]));
		i--;

		if (i < 0)
			goto out;
	}

	/* Partially trim the last surviving entry. */
	sg[i].length -= trim;
	sk_mem_uncharge(sk, trim);

out:
	ctx->tx.sg_num = i + 1;
	sg_mark_end(ctx->tx.sg_data + ctx->tx.sg_num - 1);
}
/*
 * Run @bytes bytes of staged plaintext through the zstd stream and place the
 * flushed output into tx.sg_data, trimming the list to the compressed size.
 * Returns 0 on success or -EIO on any compression/copy failure.
 */
static int tcp_comp_compress_to_sg(struct sock *sk, int bytes)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	ZSTD_outBuffer outbuf;
	ZSTD_inBuffer inbuf;
	size_t ret;

	inbuf.src = ctx->tx.plaintext_data;
	outbuf.dst = ctx->tx.compressed_data;
	inbuf.size = bytes;
	outbuf.size = TCP_COMP_MAX_CSIZE;
	inbuf.pos = 0;
	outbuf.pos = 0;

	ret = ZSTD_compressStream(ctx->tx.cstream, &outbuf, &inbuf);
	if (ZSTD_isError(ret))
		return -EIO;

	/* Flush so the receiver can decode this chunk without more input. */
	ret = ZSTD_flushStream(ctx->tx.cstream, &outbuf);
	if (ZSTD_isError(ret))
		return -EIO;

	/* The compressor must have consumed the whole chunk. */
	if (inbuf.pos != inbuf.size)
		return -EIO;

	/* Non-zero return means the sg list had too little room. */
	if (memcopy_to_sg(sk, outbuf.pos))
		return -EIO;

	trim_sg(sk, outbuf.pos);

	return 0;
}
/*
 * Transmit the compressed scatterlist via do_tcp_sendpages().
 *
 * Fully-sent entries have their pages released and socket memory uncharged.
 * If TCP accepts only part of an entry or errors out, the unsent chain is
 * parked in tx.partially_send so a later call (sendmsg or write_space) can
 * resume.  Returns 0 once the whole chain is handed to TCP, otherwise the
 * do_tcp_sendpages() error (e.g. -EAGAIN).
 */
static int tcp_comp_push_sg(struct sock *sk, struct scatterlist *sg, int flags)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	int ret, offset;
	struct page *p;
	size_t size;

	/* Guard: tcp_comp_write_space() must not recurse into us. */
	ctx->tx.in_tcp_sendpages = true;
	while (sg) {
		offset = sg->offset;
		size = sg->length;
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret != size) {
			if (ret > 0) {
				/* Partial progress: advance within this entry. */
				sk_mem_uncharge(sk, ret);
				sg->offset += ret;
				sg->length -= ret;
				size -= ret;
				offset += ret;
				goto retry;
			}
			/* Hard failure: park the unsent tail for later. */
			ctx->tx.partially_send = (void *)sg;
			ctx->tx.in_tcp_sendpages = false;
			return ret;
		}

		sk_mem_uncharge(sk, ret);
		put_page(p);
		sg = sg_next(sg);
	}

	/* Everything went out: clear pending state and reset the list. */
	clear_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
	ctx->tx.in_tcp_sendpages = false;
	ctx->tx.sg_size = 0;
	ctx->tx.sg_num = 0;

	return 0;
}
/*
 * Compress @bytes bytes of staged plaintext and hand the result to TCP.
 * Marks the send as pending before pushing so a partial transmit can be
 * resumed later.  Returns 0 on success or a negative errno.
 */
static int tcp_comp_push(struct sock *sk, int bytes, int flags)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	int err;

	err = tcp_comp_compress_to_sg(sk, bytes);
	if (err < 0) {
		pr_debug("%s: failed to compress sg\n", __func__);
		return err;
	}

	set_bit(TCP_COMP_SEND_PENDING, &ctx->flags);

	err = tcp_comp_push_sg(sk, ctx->tx.sg_data, flags);
	if (err)
		pr_debug("%s: failed to tcp_comp_push_sg\n", __func__);

	return err;
}
/*
 * Sleep until any other pending writer on this socket has drained.
 * Returns 0 on success, -EAGAIN when the timeout expires, or the
 * signal-derived errno if interrupted.
 */
static int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			ret = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			ret = sock_intr_errno(*timeo);
			break;
		}

		/* Woken with no writer pending -> we may proceed. */
		if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);

	return ret;
}
/*
 * Resume transmission of a compressed chain that was parked by an earlier
 * short send.  No-op (returns 0) when nothing is parked.
 */
static int tcp_comp_push_pending_sg(struct sock *sk, int flags)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	struct scatterlist *sg = ctx->tx.partially_send;

	if (!sg)
		return 0;

	ctx->tx.partially_send = NULL;

	return tcp_comp_push_sg(sk, sg, flags);
}
/*
 * Entry work for sendmsg(): wait for any other in-flight writer, then retry
 * transmission of compressed data left over from a previous call.
 */
static int tcp_comp_complete_pending_work(struct sock *sk, int flags,
					  long *timeo)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	int ret = 0;

	if (unlikely(sk->sk_write_pending))
		ret = wait_on_pending_writer(sk, timeo);

	if (!ret && tcp_comp_is_write_pending(ctx))
		ret = tcp_comp_push_pending_sg(sk, flags);

	return ret;
}
/*
 * sendmsg() replacement for compression-enabled TCP sockets.
 *
 * User data is consumed in chunks of at most TCP_COMP_SCRATCH_SIZE bytes;
 * each chunk is staged, compressed and pushed to TCP.  Returns the number of
 * plaintext bytes accepted, or a negative errno if nothing was copied.
 *
 * (The trailing line carried a duplicated closing brace from the diff
 * rendering; restored to the new-side function.)
 */
static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);
	int copied = 0, err = 0;
	size_t try_to_copy;
	int required_size;
	long timeo;

	lock_sock(sk);

	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);

	/* Flush any compressed data still queued from a previous call. */
	err = tcp_comp_complete_pending_work(sk, msg->msg_flags, &timeo);
	if (err)
		goto out_err;

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		try_to_copy = msg_data_left(msg);
		if (try_to_copy > TCP_COMP_SCRATCH_SIZE)
			try_to_copy = TCP_COMP_SCRATCH_SIZE;

		/* Leave headroom for worst-case zstd frame overhead. */
		required_size = try_to_copy + TCP_COMP_MAX_PADDING;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_compressed:
		err = alloc_compressed_sg(sk, required_size);
		if (err) {
			/*
			 * NOTE(review): -ENOSPC (sg table exhausted) aborts
			 * instead of sending a smaller chunk — confirm this
			 * is intended.
			 */
			if (err != -ENOSPC)
				goto wait_for_memory;
			goto out_err;
		}

		err = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
		if (err < 0)
			goto out_err;

		copied += try_to_copy;

		err = tcp_comp_push(sk, try_to_copy, msg->msg_flags);
		if (err < 0) {
			if (err == -ENOMEM)
				goto wait_for_memory;
			if (err != -EAGAIN)
				tcp_comp_err_abort(sk, EBADMSG);
			goto out_err;
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto out_err;

		/* Reuse the already-allocated sg list if it is big enough. */
		if (ctx->tx.sg_size < required_size)
			goto alloc_compressed;
	}

out_err:
	err = sk_stream_error(sk, msg->msg_flags, err);

	release_sock(sk);

	return copied ? copied : err;
}
static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
...@@ -70,6 +457,30 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -70,6 +457,30 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len); return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len);
} }
/*
 * sk_write_space hook: when sndbuf space frees up, try to flush any parked
 * compressed data before waking userspace writers.
 */
static void tcp_comp_write_space(struct sock *sk)
{
	struct tcp_comp_context *ctx = comp_get_ctx(sk);

	/* Re-entered from do_tcp_sendpages(); just chain to the saved hook. */
	if (ctx->tx.in_tcp_sendpages) {
		ctx->sk_write_space(sk);
		return;
	}

	if (!sk->sk_write_pending && tcp_comp_is_write_pending(ctx)) {
		gfp_t sk_allocation = sk->sk_allocation;
		int rc;

		/* May run in atomic context: allocations must not sleep. */
		sk->sk_allocation = GFP_ATOMIC;
		rc = tcp_comp_push_pending_sg(sk, MSG_DONTWAIT | MSG_NOSIGNAL);
		sk->sk_allocation = sk_allocation;

		/* Still blocked: don't wake writers yet. */
		if (rc < 0)
			return;
	}

	ctx->sk_write_space(sk);
}
void tcp_init_compression(struct sock *sk) void tcp_init_compression(struct sock *sk)
{ {
struct inet_connection_sock *icsk = inet_csk(sk); struct inet_connection_sock *icsk = inet_csk(sk);
...@@ -83,20 +494,46 @@ void tcp_init_compression(struct sock *sk) ...@@ -83,20 +494,46 @@ void tcp_init_compression(struct sock *sk)
if (!ctx) if (!ctx)
return; return;
sg_init_table(ctx->tx.sg_data, ARRAY_SIZE(ctx->tx.sg_data));
ctx->sk_write_space = sk->sk_write_space;
ctx->sk_proto = sk->sk_prot; ctx->sk_proto = sk->sk_prot;
WRITE_ONCE(sk->sk_prot, &tcp_prot_override); WRITE_ONCE(sk->sk_prot, &tcp_prot_override);
sk->sk_write_space = tcp_comp_write_space;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx); rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
sock_set_flag(sk, SOCK_COMP); sock_set_flag(sk, SOCK_COMP);
} }
/*
 * Release every page in a compressed sg chain and uncharge the socket memory
 * it was accounted against.
 */
static void free_sg(struct sock *sk, struct scatterlist *sg)
{
	for (; sg; sg = sg_next(sg)) {
		sk_mem_uncharge(sk, sg->length);
		put_page(sg_page(sg));
	}
}
/*
 * Release the transmit-side compression buffers.  The zstd cstream lives
 * inside cworkspace, so freeing the workspace releases it as well.
 */
static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
{
	kfree(ctx->tx.cworkspace);
	ctx->tx.cworkspace = NULL;

	kvfree(ctx->tx.plaintext_data);
	ctx->tx.plaintext_data = NULL;

	kvfree(ctx->tx.compressed_data);
	ctx->tx.compressed_data = NULL;
}
static void tcp_comp_context_free(struct rcu_head *head) static void tcp_comp_context_free(struct rcu_head *head)
{ {
struct tcp_comp_context *ctx; struct tcp_comp_context *ctx;
ctx = container_of(head, struct tcp_comp_context, rcu); ctx = container_of(head, struct tcp_comp_context, rcu);
tcp_comp_context_tx_free(ctx);
kfree(ctx); kfree(ctx);
} }
...@@ -108,6 +545,11 @@ void tcp_cleanup_compression(struct sock *sk) ...@@ -108,6 +545,11 @@ void tcp_cleanup_compression(struct sock *sk)
if (!ctx || !sock_flag(sk, SOCK_COMP)) if (!ctx || !sock_flag(sk, SOCK_COMP))
return; return;
if (ctx->tx.partially_send) {
free_sg(sk, ctx->tx.partially_send);
ctx->tx.partially_send = NULL;
}
rcu_assign_pointer(icsk->icsk_ulp_data, NULL); rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
call_rcu(&ctx->rcu, tcp_comp_context_free); call_rcu(&ctx->rcu, tcp_comp_context_free);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册