提交 c5f3ee69 编写于 作者: W Wang Yufen 提交者: Zheng Zengkai

tcp_comp: Add dpkt to save decompressed skb

hulk inclusion
category: bugfix
bugzilla: https://gitee.com/openeuler/kernel/issues/I48H9Z?from=project-issue
CVE: NA

-------------------------------------------------

In order to separate the compressed data and decompressed data, this patch adds
dpkt to tcp_comp_context_rx, dpkt is used to save decompressed skb.
Signed-off-by: NWang Yufen <wangyufen@huawei.com>
Signed-off-by: NYang Yingliang <yangyingliang@huawei.com>
Signed-off-by: NLu Wei <luwei32@huawei.com>
Reviewed-by: NWei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: NZheng Zengkai <zhengzengkai@huawei.com>
上级 d876fdbb
...@@ -45,7 +45,7 @@ struct tcp_comp_context_rx { ...@@ -45,7 +45,7 @@ struct tcp_comp_context_rx {
struct strparser strp; struct strparser strp;
void (*saved_data_ready)(struct sock *sk); void (*saved_data_ready)(struct sock *sk);
struct sk_buff *pkt; struct sk_buff *pkt;
bool decompressed; struct sk_buff *dpkt;
}; };
struct tcp_comp_context { struct tcp_comp_context {
...@@ -510,6 +510,24 @@ static bool comp_advance_skb(struct sock *sk, struct sk_buff *skb, ...@@ -510,6 +510,24 @@ static bool comp_advance_skb(struct sock *sk, struct sk_buff *skb,
return true; return true;
} }
static bool comp_advance_dskb(struct sock *sk, struct sk_buff *skb,
unsigned int len)
{
struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct strp_msg *rxm = strp_msg(skb);
if (len < rxm->full_len) {
rxm->offset += len;
rxm->full_len -= len;
return false;
}
/* Finished with message */
ctx->rx.dpkt = NULL;
kfree_skb(skb);
return true;
}
static int tcp_comp_rx_context_init(struct tcp_comp_context *ctx) static int tcp_comp_rx_context_init(struct tcp_comp_context *ctx)
{ {
int dsize; int dsize;
...@@ -566,13 +584,14 @@ static void *tcp_comp_get_rx_stream(struct sock *sk) ...@@ -566,13 +584,14 @@ static void *tcp_comp_get_rx_stream(struct sock *sk)
return ctx->rx.plaintext_data; return ctx->rx.plaintext_data;
} }
static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb) static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb, int flags)
{ {
struct tcp_comp_context *ctx = comp_get_ctx(sk); struct tcp_comp_context *ctx = comp_get_ctx(sk);
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
const int plen = skb->len; const int plen = skb->len;
ZSTD_outBuffer outbuf; ZSTD_outBuffer outbuf;
ZSTD_inBuffer inbuf; ZSTD_inBuffer inbuf;
struct sk_buff *nskb;
int len; int len;
void *to; void *to;
...@@ -586,6 +605,10 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb) ...@@ -586,6 +605,10 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
if (plen + ctx->rx.data_offset > TCP_COMP_MAX_CSIZE) if (plen + ctx->rx.data_offset > TCP_COMP_MAX_CSIZE)
return -ENOMEM; return -ENOMEM;
nskb = skb_copy(skb, GFP_KERNEL);
if (!nskb)
return -ENOMEM;
if (ctx->rx.data_offset) if (ctx->rx.data_offset)
memcpy(ctx->rx.compressed_data, ctx->rx.remaining_data, memcpy(ctx->rx.compressed_data, ctx->rx.remaining_data,
ctx->rx.data_offset); ctx->rx.data_offset);
...@@ -607,34 +630,38 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb) ...@@ -607,34 +630,38 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
to = outbuf.dst; to = outbuf.dst;
ret = ZSTD_decompressStream(ctx->rx.dstream, &outbuf, &inbuf); ret = ZSTD_decompressStream(ctx->rx.dstream, &outbuf, &inbuf);
if (ZSTD_isError(ret)) if (ZSTD_isError(ret)) {
kfree_skb(nskb);
return -EIO; return -EIO;
}
len = outbuf.pos - plen; len = outbuf.pos - plen;
if (len > skb_tailroom(skb)) if (len > skb_tailroom(nskb))
len = skb_tailroom(skb); len = skb_tailroom(nskb);
__skb_put(skb, len); __skb_put(nskb, len);
rxm->full_len += (len + rxm->offset);
rxm->offset = 0;
len += plen; len += plen;
skb_copy_to_linear_data(skb, to, len); skb_copy_to_linear_data(nskb, to, len);
while ((to += len, outbuf.pos -= len) > 0) { while ((to += len, outbuf.pos -= len) > 0) {
struct page *pages; struct page *pages;
skb_frag_t *frag; skb_frag_t *frag;
if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS)) if (WARN_ON(skb_shinfo(nskb)->nr_frags >= MAX_SKB_FRAGS)) {
kfree_skb(nskb);
return -EMSGSIZE; return -EMSGSIZE;
}
frag = skb_shinfo(skb)->frags + frag = skb_shinfo(nskb)->frags +
skb_shinfo(skb)->nr_frags; skb_shinfo(nskb)->nr_frags;
pages = alloc_pages(__GFP_NOWARN | GFP_KERNEL | __GFP_COMP, pages = alloc_pages(__GFP_NOWARN | GFP_KERNEL | __GFP_COMP,
TCP_COMP_ALLOC_ORDER); TCP_COMP_ALLOC_ORDER);
if (!pages) if (!pages) {
kfree_skb(nskb);
return -ENOMEM; return -ENOMEM;
}
__skb_frag_set_page(frag, pages); __skb_frag_set_page(frag, pages);
len = PAGE_SIZE << TCP_COMP_ALLOC_ORDER; len = PAGE_SIZE << TCP_COMP_ALLOC_ORDER;
...@@ -645,11 +672,10 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb) ...@@ -645,11 +672,10 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
skb_frag_size_set(frag, len); skb_frag_size_set(frag, len);
memcpy(skb_frag_address(frag), to, len); memcpy(skb_frag_address(frag), to, len);
skb->truesize += len; nskb->truesize += len;
skb->data_len += len; nskb->data_len += len;
skb->len += len; nskb->len += len;
rxm->full_len += len; skb_shinfo(nskb)->nr_frags++;
skb_shinfo(skb)->nr_frags++;
} }
if (ret == 0) if (ret == 0)
...@@ -665,6 +691,13 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb) ...@@ -665,6 +691,13 @@ static int tcp_comp_decompress(struct sock *sk, struct sk_buff *skb)
break; break;
} }
} }
ctx->rx.dpkt = nskb;
rxm = strp_msg(nskb);
rxm->full_len = nskb->len;
rxm->offset = 0;
comp_advance_skb(sk, skb, plen - rxm->offset);
return 0; return 0;
} }
...@@ -691,21 +724,19 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -691,21 +724,19 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
do { do {
int chunk = 0; int chunk = 0;
skb = comp_wait_data(sk, flags, timeo, &err); if (!ctx->rx.dpkt) {
if (!skb) skb = comp_wait_data(sk, flags, timeo, &err);
goto recv_end; if (!skb)
goto recv_end;
if (!ctx->rx.decompressed) { err = tcp_comp_decompress(sk, skb, flags);
err = tcp_comp_decompress(sk, skb);
if (err < 0) { if (err < 0) {
goto recv_end; goto recv_end;
} }
ctx->rx.decompressed = true;
} }
skb = ctx->rx.dpkt;
rxm = strp_msg(skb); rxm = strp_msg(skb);
chunk = min_t(unsigned int, rxm->full_len, len); chunk = min_t(unsigned int, rxm->full_len, len);
err = skb_copy_datagram_msg(skb, rxm->offset, msg, err = skb_copy_datagram_msg(skb, rxm->offset, msg,
chunk); chunk);
if (err < 0) if (err < 0)
...@@ -714,11 +745,11 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -714,11 +745,11 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
copied += chunk; copied += chunk;
len -= chunk; len -= chunk;
if (likely(!(flags & MSG_PEEK))) if (likely(!(flags & MSG_PEEK)))
comp_advance_skb(sk, skb, chunk); comp_advance_dskb(sk, skb, chunk);
else else
break; break;
if (copied >= target && !ctx->rx.pkt) if (copied >= target && !ctx->rx.dpkt)
break; break;
} while (len > 0); } while (len > 0);
...@@ -734,7 +765,7 @@ bool comp_stream_read(const struct sock *sk) ...@@ -734,7 +765,7 @@ bool comp_stream_read(const struct sock *sk)
if (!ctx) if (!ctx)
return false; return false;
if (ctx->rx.pkt) if (ctx->rx.pkt || ctx->rx.dpkt)
return true; return true;
return false; return false;
...@@ -751,7 +782,6 @@ static void comp_queue(struct strparser *strp, struct sk_buff *skb) ...@@ -751,7 +782,6 @@ static void comp_queue(struct strparser *strp, struct sk_buff *skb)
{ {
struct tcp_comp_context *ctx = comp_get_ctx(strp->sk); struct tcp_comp_context *ctx = comp_get_ctx(strp->sk);
ctx->rx.decompressed = false;
ctx->rx.pkt = skb; ctx->rx.pkt = skb;
strp_pause(strp); strp_pause(strp);
ctx->rx.saved_data_ready(strp->sk); ctx->rx.saved_data_ready(strp->sk);
...@@ -887,6 +917,10 @@ void tcp_cleanup_compression(struct sock *sk) ...@@ -887,6 +917,10 @@ void tcp_cleanup_compression(struct sock *sk)
kfree_skb(ctx->rx.pkt); kfree_skb(ctx->rx.pkt);
ctx->rx.pkt = NULL; ctx->rx.pkt = NULL;
} }
if (ctx->rx.dpkt) {
kfree_skb(ctx->rx.dpkt);
ctx->rx.dpkt = NULL;
}
strp_stop(&ctx->rx.strp); strp_stop(&ctx->rx.strp);
rcu_assign_pointer(icsk->icsk_ulp_data, NULL); rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册