Commit 053ced9a authored by Wei Yongjun, committed by Yang Yingliang

tcp_comp: implement sendmsg for tcp compression

hulk inclusion
category: feature
bugzilla: NA
CVE: NA

-------------------------------------------------

This patch implements software-level compression for
sending TCP messages: all of the TCP payload is
compressed before xmit.
Signed-off-by: Wei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: Wang Yufen <wangyufen@huawei.com>
Reviewed-by: Yue Haibing <yuehaibing@huawei.com>
Signed-off-by: Yang Yingliang <yangyingliang@huawei.com>
Parent e9b9f606
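[Editor's note] The diff below relies on the kernel's workspace-based zstd streaming API. As a reading aid, here is a minimal sketch of the setup/compress/flush sequence that tcp_comp_tx_context_init() and tcp_comp_compress_to_sg() perform; it is illustrative only — the function name and buffer handling are placeholders, not part of the patch:

#include <linux/zstd.h>
#include <linux/slab.h>

/* Compress src[0..len) into dst (capacity dst_cap). Returns the number
 * of compressed bytes produced, or a negative errno. Uses exactly the
 * zstd calls the patch uses; the rest is placeholder scaffolding. */
static int zstd_oneshot_stream(const void *src, size_t len,
			       void *dst, size_t dst_cap)
{
	ZSTD_parameters params = ZSTD_getParams(1, PAGE_SIZE, 0);
	size_t wsize = ZSTD_CStreamWorkspaceBound(params.cParams);
	ZSTD_inBuffer in = { .src = src, .size = len, .pos = 0 };
	ZSTD_outBuffer out = { .dst = dst, .size = dst_cap, .pos = 0 };
	ZSTD_CStream *cs;
	void *wksp;
	int ret = -EIO;

	/* The kernel zstd API works in a caller-owned workspace; it does
	 * no hidden allocations of its own. */
	wksp = kmalloc(wsize, GFP_KERNEL);
	if (!wksp)
		return -ENOMEM;

	cs = ZSTD_initCStream(params, 0, wksp, wsize);
	if (cs &&
	    !ZSTD_isError(ZSTD_compressStream(cs, &out, &in)) &&
	    !ZSTD_isError(ZSTD_flushStream(cs, &out)) &&
	    in.pos == in.size)	/* all input consumed? */
		ret = out.pos;

	kfree(wksp);
	return ret;
}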
@@ -757,7 +757,7 @@ config TCP_MD5SIG
config TCP_COMP
bool "TCP: Transport Layer Compression support"
depends on !SMC && ZSTD_COMPRESS=y
---help---
Enable kernel payload compression support for TCP protocol. This allows
payload compression handling of the TCP protocol to be done in-kernel.
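[Editor's note] Given this new dependency, a build enabling the feature needs roughly the following .config fragment; the symbol names are exactly those in the depends expression above (a sketch, not part of the patch):

CONFIG_ZSTD_COMPRESS=y
# CONFIG_SMC is not set
CONFIG_TCP_COMP=y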
@@ -6,6 +6,14 @@
*/
#include <net/tcp.h>
#include <linux/zstd.h>
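/*
 * [Editor's note] Plaintext is staged in a 65400-byte scratch buffer and
 * compressed into a buffer with 64 extra bytes of headroom, so
 * TCP_COMP_MAX_CSIZE = 65400 + 64 = 65464. The headroom is not a hard
 * worst-case bound: if the compressed output overflows it, the
 * inbuf.pos != inbuf.size check in tcp_comp_compress_to_sg() below fails
 * the send with -EIO.
 */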
#define TCP_COMP_MAX_PADDING 64
#define TCP_COMP_SCRATCH_SIZE 65400
#define TCP_COMP_MAX_CSIZE (TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING)
#define TCP_COMP_SEND_PENDING 1
#define ZSTD_COMP_DEFAULT_LEVEL 1
static unsigned long tcp_compression_ports[65536 / 8];
@@ -14,11 +22,42 @@ int sysctl_tcp_compression_local __read_mostly;
static struct proto tcp_prot_override;
struct tcp_comp_context_tx {
ZSTD_CStream *cstream;
void *cworkspace;
void *plaintext_data;
void *compressed_data;
struct scatterlist sg_data[MAX_SKB_FRAGS];
unsigned int sg_size;
int sg_num;
struct scatterlist *partially_send;
bool in_tcp_sendpages;
};
struct tcp_comp_context {
struct rcu_head rcu;
struct proto *sk_proto;
void (*sk_write_space)(struct sock *sk);
struct tcp_comp_context_tx tx;
unsigned long flags;
};
static bool tcp_comp_is_write_pending(struct tcp_comp_context *ctx)
{
return test_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
}
static void tcp_comp_err_abort(struct sock *sk, int err)
{
sk->sk_err = err;
sk->sk_error_report(sk);
}
static bool tcp_comp_enabled(__be32 saddr, __be32 daddr, int port)
{
if (!sysctl_tcp_compression_local &&
@@ -55,11 +94,359 @@ static struct tcp_comp_context *comp_get_ctx(const struct sock *sk)
return (__force void *)icsk->icsk_ulp_data;
}
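/*
 * [Editor's note] The TX state (zstd workspace plus plaintext/compressed
 * scratch buffers) is allocated lazily, on the first compressed send,
 * via tcp_comp_get_tx_stream() below.
 */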
static int tcp_comp_tx_context_init(struct tcp_comp_context *ctx)
{
ZSTD_parameters params;
int csize;
params = ZSTD_getParams(ZSTD_COMP_DEFAULT_LEVEL, PAGE_SIZE, 0);
csize = ZSTD_CStreamWorkspaceBound(params.cParams);
if (csize <= 0)
return -EINVAL;
ctx->tx.cworkspace = kmalloc(csize, GFP_KERNEL);
if (!ctx->tx.cworkspace)
return -ENOMEM;
ctx->tx.cstream = ZSTD_initCStream(params, 0, ctx->tx.cworkspace,
csize);
if (!ctx->tx.cstream)
goto err_cstream;
ctx->tx.plaintext_data = kvmalloc(TCP_COMP_SCRATCH_SIZE, GFP_KERNEL);
if (!ctx->tx.plaintext_data)
goto err_cstream;
ctx->tx.compressed_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
if (!ctx->tx.compressed_data)
goto err_compressed;
return 0;
err_compressed:
kvfree(ctx->tx.plaintext_data);
ctx->tx.plaintext_data = NULL;
err_cstream:
kfree(ctx->tx.cworkspace);
ctx->tx.cworkspace = NULL;
return -ENOMEM;
}
static void *tcp_comp_get_tx_stream(struct sock *sk)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
if (!ctx->tx.plaintext_data)
tcp_comp_tx_context_init(ctx);
return ctx->tx.plaintext_data;
}
static int alloc_compressed_sg(struct sock *sk, int len)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
int rc = 0;
rc = sk_alloc_sg(sk, len, ctx->tx.sg_data, 0,
&ctx->tx.sg_num, &ctx->tx.sg_size, 0);
if (rc == -ENOSPC)
ctx->tx.sg_num = ARRAY_SIZE(ctx->tx.sg_data);
return rc;
}
static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, int copy)
{
void *dest;
int rc;
dest = tcp_comp_get_tx_stream(sk);
if (!dest)
return -ENOSPC;
if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
rc = copy_from_iter_nocache(dest, copy, from);
else
rc = copy_from_iter(dest, copy, from);
if (rc != copy)
rc = -EFAULT;
return rc;
}
static int memcopy_to_sg(struct sock *sk, int bytes)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct scatterlist *sg = ctx->tx.sg_data;
char *from, *to;
int copy;
from = ctx->tx.compressed_data;
while (bytes && sg) {
to = sg_virt(sg);
copy = min_t(int, sg->length, bytes);
memcpy(to, from, copy);
bytes -= copy;
from += copy;
sg = sg_next(sg);
}
return bytes;
}
static void trim_sg(struct sock *sk, int target_size)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct scatterlist *sg = ctx->tx.sg_data;
int trim = ctx->tx.sg_size - target_size;
int i = ctx->tx.sg_num - 1;
if (trim <= 0) {
WARN_ON_ONCE(trim < 0);
return;
}
ctx->tx.sg_size = target_size;
while (trim >= sg[i].length) {
trim -= sg[i].length;
sk_mem_uncharge(sk, sg[i].length);
put_page(sg_page(&sg[i]));
i--;
if (i < 0)
goto out;
}
sg[i].length -= trim;
sk_mem_uncharge(sk, trim);
out:
ctx->tx.sg_num = i + 1;
sg_mark_end(ctx->tx.sg_data + ctx->tx.sg_num - 1);
}
static int tcp_comp_compress_to_sg(struct sock *sk, int bytes)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
ZSTD_outBuffer outbuf;
ZSTD_inBuffer inbuf;
size_t ret;
inbuf.src = ctx->tx.plaintext_data;
outbuf.dst = ctx->tx.compressed_data;
inbuf.size = bytes;
outbuf.size = TCP_COMP_MAX_CSIZE;
inbuf.pos = 0;
outbuf.pos = 0;
ret = ZSTD_compressStream(ctx->tx.cstream, &outbuf, &inbuf);
if (ZSTD_isError(ret))
return -EIO;
ret = ZSTD_flushStream(ctx->tx.cstream, &outbuf);
if (ZSTD_isError(ret))
return -EIO;
if (inbuf.pos != inbuf.size)
return -EIO;
if (memcopy_to_sg(sk, outbuf.pos))
return -EIO;
trim_sg(sk, outbuf.pos);
return 0;
}
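/*
 * [Editor's note] Hand each compressed sg entry to do_tcp_sendpages().
 * On a short write, remember the current position in tx.partially_send
 * so the write-space callback or the next sendmsg can resume from there.
 */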
static int tcp_comp_push_sg(struct sock *sk, struct scatterlist *sg, int flags)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
int ret, offset;
struct page *p;
size_t size;
ctx->tx.in_tcp_sendpages = true;
while (sg) {
offset = sg->offset;
size = sg->length;
p = sg_page(sg);
retry:
ret = do_tcp_sendpages(sk, p, offset, size, flags);
if (ret != size) {
if (ret > 0) {
sk_mem_uncharge(sk, ret);
sg->offset += ret;
sg->length -= ret;
size -= ret;
offset += ret;
goto retry;
}
ctx->tx.partially_send = (void *)sg;
ctx->tx.in_tcp_sendpages = false;
return ret;
}
sk_mem_uncharge(sk, ret);
put_page(p);
sg = sg_next(sg);
}
clear_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
ctx->tx.in_tcp_sendpages = false;
ctx->tx.sg_size = 0;
ctx->tx.sg_num = 0;
return 0;
}
static int tcp_comp_push(struct sock *sk, int bytes, int flags)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
int ret;
ret = tcp_comp_compress_to_sg(sk, bytes);
if (ret < 0) {
pr_debug("%s: failed to compress sg\n", __func__);
return ret;
}
set_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
ret = tcp_comp_push_sg(sk, ctx->tx.sg_data, flags);
if (ret) {
pr_debug("%s: failed to tcp_comp_push_sg\n", __func__);
return ret;
}
return 0;
}
static int wait_on_pending_writer(struct sock *sk, long *timeo)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
add_wait_queue(sk_sleep(sk), &wait);
while (1) {
if (!*timeo) {
ret = -EAGAIN;
break;
}
if (signal_pending(current)) {
ret = sock_intr_errno(*timeo);
break;
}
if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
break;
}
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
static int tcp_comp_push_pending_sg(struct sock *sk, int flags)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct scatterlist *sg;
if (!ctx->tx.partially_send)
return 0;
sg = ctx->tx.partially_send;
ctx->tx.partially_send = NULL;
return tcp_comp_push_sg(sk, sg, flags);
}
static int tcp_comp_complete_pending_work(struct sock *sk, int flags,
long *timeo)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
int ret = 0;
if (unlikely(sk->sk_write_pending))
ret = wait_on_pending_writer(sk, timeo);
if (!ret && tcp_comp_is_write_pending(ctx))
ret = tcp_comp_push_pending_sg(sk, flags);
return ret;
}
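/*
 * [Editor's note] The sendmsg replacement: copy up to
 * TCP_COMP_SCRATCH_SIZE bytes of user data into the plaintext scratch
 * buffer, compress it into freshly allocated sg pages, and push the
 * result out; wait for sndbuf/memory where needed.
 */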
static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
int copied = 0, err = 0;
size_t try_to_copy;
int required_size;
long timeo;
lock_sock(sk);
timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
err = tcp_comp_complete_pending_work(sk, msg->msg_flags, &timeo);
if (err)
goto out_err;
while (msg_data_left(msg)) {
if (sk->sk_err) {
err = -sk->sk_err;
goto out_err;
}
try_to_copy = msg_data_left(msg);
if (try_to_copy > TCP_COMP_SCRATCH_SIZE)
try_to_copy = TCP_COMP_SCRATCH_SIZE;
required_size = try_to_copy + TCP_COMP_MAX_PADDING;
if (!sk_stream_memory_free(sk))
goto wait_for_sndbuf;
alloc_compressed:
err = alloc_compressed_sg(sk, required_size);
if (err) {
if (err != -ENOSPC)
goto wait_for_memory;
goto out_err;
}
err = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
if (err < 0)
goto out_err;
copied += try_to_copy;
err = tcp_comp_push(sk, try_to_copy, msg->msg_flags);
if (err < 0) {
if (err == -ENOMEM)
goto wait_for_memory;
if (err != -EAGAIN)
tcp_comp_err_abort(sk, EBADMSG);
goto out_err;
}
continue;
wait_for_sndbuf:
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
err = sk_stream_wait_memory(sk, &timeo);
if (err)
goto out_err;
if (ctx->tx.sg_size < required_size)
goto alloc_compressed;
}
out_err:
err = sk_stream_error(sk, msg->msg_flags, err);
release_sock(sk);
return copied ? copied : err;
}
static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
@@ -70,6 +457,30 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len);
}
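/*
 * [Editor's note] Installed as sk->sk_write_space: if an earlier push
 * left data partially sent, retry it first (with GFP_ATOMIC, since this
 * callback may run in atomic context) before invoking the original
 * write-space callback; in_tcp_sendpages avoids re-entering a push that
 * is already in progress.
 */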
static void tcp_comp_write_space(struct sock *sk)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
if (ctx->tx.in_tcp_sendpages) {
ctx->sk_write_space(sk);
return;
}
if (!sk->sk_write_pending && tcp_comp_is_write_pending(ctx)) {
gfp_t sk_allocation = sk->sk_allocation;
int rc;
sk->sk_allocation = GFP_ATOMIC;
rc = tcp_comp_push_pending_sg(sk, MSG_DONTWAIT | MSG_NOSIGNAL);
sk->sk_allocation = sk_allocation;
if (rc < 0)
return;
}
ctx->sk_write_space(sk);
}
void tcp_init_compression(struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
@@ -83,20 +494,46 @@ void tcp_init_compression(struct sock *sk)
if (!ctx)
return;
sg_init_table(ctx->tx.sg_data, ARRAY_SIZE(ctx->tx.sg_data));
ctx->sk_write_space = sk->sk_write_space;
ctx->sk_proto = sk->sk_prot;
WRITE_ONCE(sk->sk_prot, &tcp_prot_override);
sk->sk_write_space = tcp_comp_write_space;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
sock_set_flag(sk, SOCK_COMP);
}
static void free_sg(struct sock *sk, struct scatterlist *sg)
{
while (sg) {
sk_mem_uncharge(sk, sg->length);
put_page(sg_page(sg));
sg = sg_next(sg);
}
}
static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
{
kfree(ctx->tx.cworkspace);
ctx->tx.cworkspace = NULL;
kvfree(ctx->tx.plaintext_data);
ctx->tx.plaintext_data = NULL;
kvfree(ctx->tx.compressed_data);
ctx->tx.compressed_data = NULL;
}
static void tcp_comp_context_free(struct rcu_head *head)
{
struct tcp_comp_context *ctx;
ctx = container_of(head, struct tcp_comp_context, rcu);
tcp_comp_context_tx_free(ctx);
kfree(ctx);
}
@@ -108,6 +545,11 @@ void tcp_cleanup_compression(struct sock *sk)
if (!ctx || !sock_flag(sk, SOCK_COMP))
return;
if (ctx->tx.partially_send) {
free_sg(sk, ctx->tx.partially_send);
ctx->tx.partially_send = NULL;
}
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
call_rcu(&ctx->rcu, tcp_comp_context_free);
}