提交 2eaa8575 编写于 作者: D David S. Miller

Merge branch 'net-tls-fix-scatter-gather-list-issues'

Jakub Kicinski says:

====================
net: tls: fix scatter-gather list issues

This series kicked of by a syzbot report fixes three issues around
scatter gather handling in the TLS code. First patch fixes a use-
-after-free situation which may occur if record was freed on error.
This could have already happened in BPF paths, and patch 2 now makes
the same condition occur in non-BPF code.

Patch 2 fixes the problem spotted by syzbot. If encryption failed
we have to clean the end markings from scatter gather list. As
suggested by John the patch frees the record entirely and caller
may retry copying data from user space buffer again.

Third patch fixes a bug in the TLS 1.3 code spotted while working
on patch 2. TLS 1.3 may effectively overflow the SG list which
leads to the BUG() in sg_page() being triggered.

Patch 4 adds a test case which triggers this bug reliably.

Next two patches are small cleanups of dead code and code which
makes dangerous assumptions.

Last but not least two minor improvements to the sockmap tests.

Tested:
 - bpf/test_sockmap
 - net/tls
 - syzbot repro (which used error injection, hence no direct
   selftest is added to preserve it).
====================
Signed-off-by: NDavid S. Miller <davem@davemloft.net>
......@@ -14,6 +14,7 @@
#include <net/strparser.h>
#define MAX_MSG_FRAGS MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
enum __sk_action {
__SK_DROP = 0,
......@@ -29,13 +30,15 @@ struct sk_msg_sg {
u32 size;
u32 copybreak;
unsigned long copy;
/* The extra element is used for chaining the front and sections when
* the list becomes partitioned (e.g. end < start). The crypto APIs
* require the chaining.
/* The extra two elements:
* 1) used for chaining the front and sections when the list becomes
* partitioned (e.g. end < start). The crypto APIs require the
* chaining;
* 2) to chain tailer SG entries after the message.
*/
struct scatterlist data[MAX_MSG_FRAGS + 1];
struct scatterlist data[MAX_MSG_FRAGS + 2];
};
static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS);
static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
......@@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
#define sk_msg_iter_var_prev(var) \
do { \
if (var == 0) \
var = MAX_MSG_FRAGS - 1; \
var = NR_MSG_FRAG_IDS - 1; \
else \
var--; \
} while (0)
......@@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end)
#define sk_msg_iter_var_next(var) \
do { \
var++; \
if (var == MAX_MSG_FRAGS) \
if (var == NR_MSG_FRAG_IDS) \
var = 0; \
} while (0)
......@@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
static inline void sk_msg_init(struct sk_msg *msg)
{
BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
memset(msg, 0, sizeof(*msg));
sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
......@@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
static inline bool sk_msg_full(const struct sk_msg *msg)
{
return (msg->sg.end == msg->sg.start) && msg->sg.size;
return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}
static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
if (sk_msg_full(msg))
return MAX_MSG_FRAGS;
return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}
......
......@@ -100,7 +100,6 @@ struct tls_rec {
struct list_head list;
int tx_ready;
int tx_flags;
int inplace_crypto;
struct sk_msg msg_plaintext;
struct sk_msg msg_encrypted;
......@@ -377,7 +376,7 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx,
int flags);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags);
bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
......
......@@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
WARN_ON_ONCE(last_sge == first_sge);
shift = last_sge > first_sge ?
last_sge - first_sge - 1 :
MAX_SKB_FRAGS - first_sge + last_sge - 1;
NR_MSG_FRAG_IDS - first_sge + last_sge - 1;
if (!shift)
goto out;
......@@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
do {
u32 move_from;
if (i + shift >= MAX_MSG_FRAGS)
move_from = i + shift - MAX_MSG_FRAGS;
if (i + shift >= NR_MSG_FRAG_IDS)
move_from = i + shift - NR_MSG_FRAG_IDS;
else
move_from = i + shift;
if (move_from == msg->sg.end)
......@@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
} while (1);
msg->sg.end = msg->sg.end - shift > msg->sg.end ?
msg->sg.end - shift + MAX_MSG_FRAGS :
msg->sg.end - shift + NR_MSG_FRAG_IDS :
msg->sg.end - shift;
out:
msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
......
......@@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
copied = skb->len;
msg->sg.start = 0;
msg->sg.size = copied;
msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
msg->sg.end = num_sge;
msg->skb = skb;
sk_psock_queue_msg(psock, msg);
......
......@@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
struct sk_msg *msg, int *copied, int flags)
{
bool cork = false, enospc = msg->sg.start == msg->sg.end;
bool cork = false, enospc = sk_msg_full(msg);
struct sock *sk_redir;
u32 tosend, delta = 0;
int ret;
......
......@@ -209,24 +209,15 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
return tls_push_sg(sk, ctx, sg, offset, flags);
}
bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
struct scatterlist *sg;
sg = ctx->partially_sent_record;
if (!sg)
return false;
while (1) {
for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
put_page(sg_page(sg));
sk_mem_uncharge(sk, sg->length);
if (sg_is_last(sg))
break;
sg++;
}
ctx->partially_sent_record = NULL;
return true;
}
static void tls_write_space(struct sock *sk)
......
......@@ -710,8 +710,7 @@ static int tls_push_record(struct sock *sk, int flags,
}
i = msg_pl->sg.start;
sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ?
&msg_en->sg.data[i] : &msg_pl->sg.data[i]);
sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
i = msg_en->sg.end;
sk_msg_iter_var_prev(i);
......@@ -771,8 +770,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
policy = !(flags & MSG_SENDPAGE_NOPOLICY);
psock = sk_psock_get(sk);
if (!psock || !policy)
return tls_push_record(sk, flags, record_type);
if (!psock || !policy) {
err = tls_push_record(sk, flags, record_type);
if (err) {
*copied -= sk_msg_free(sk, msg);
tls_free_open_rec(sk);
}
return err;
}
more_data:
enospc = sk_msg_full(msg);
if (psock->eval == __SK_NONE) {
......@@ -970,8 +975,6 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
if (ret)
goto fallback_to_reg_send;
rec->inplace_crypto = 0;
num_zc++;
copied += try_to_copy;
......@@ -984,7 +987,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
num_async++;
else if (ret == -ENOMEM)
goto wait_for_memory;
else if (ret == -ENOSPC)
else if (ctx->open_rec && ret == -ENOSPC)
goto rollback_iter;
else if (ret != -EAGAIN)
goto send_end;
......@@ -1053,11 +1056,12 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
ret = sk_stream_wait_memory(sk, &timeo);
if (ret) {
trim_sgl:
tls_trim_both_msgs(sk, orig_size);
if (ctx->open_rec)
tls_trim_both_msgs(sk, orig_size);
goto send_end;
}
if (msg_en->sg.size < required_size)
if (ctx->open_rec && msg_en->sg.size < required_size)
goto alloc_encrypted;
}
......@@ -1169,7 +1173,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
tls_ctx->pending_open_record_frags = true;
if (full_record || eor || sk_msg_full(msg_pl)) {
rec->inplace_crypto = 0;
ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
record_type, &copied, flags);
if (ret) {
......@@ -1190,11 +1193,13 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
wait_for_memory:
ret = sk_stream_wait_memory(sk, &timeo);
if (ret) {
tls_trim_both_msgs(sk, msg_pl->sg.size);
if (ctx->open_rec)
tls_trim_both_msgs(sk, msg_pl->sg.size);
goto sendpage_end;
}
goto alloc_payload;
if (ctx->open_rec)
goto alloc_payload;
}
if (num_async) {
......@@ -2084,7 +2089,8 @@ void tls_sw_release_resources_tx(struct sock *sk)
/* Free up un-sent records in tx_list. First, free
* the partially sent record if any at head of tx_list.
*/
if (tls_free_partial_record(sk, tls_ctx)) {
if (tls_ctx->partially_sent_record) {
tls_free_partial_record(sk, tls_ctx);
rec = list_first_entry(&ctx->tx_list,
struct tls_rec, list);
list_del(&rec->list);
......
......@@ -240,14 +240,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT);
err = bind(s1, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0) {
perror("bind s1 failed()\n");
perror("bind s1 failed()");
return errno;
}
addr.sin_port = htons(S2_PORT);
err = bind(s2, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0) {
perror("bind s2 failed()\n");
perror("bind s2 failed()");
return errno;
}
......@@ -255,14 +255,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT);
err = listen(s1, 32);
if (err < 0) {
perror("listen s1 failed()\n");
perror("listen s1 failed()");
return errno;
}
addr.sin_port = htons(S2_PORT);
err = listen(s2, 32);
if (err < 0) {
perror("listen s1 failed()\n");
perror("listen s1 failed()");
return errno;
}
......@@ -270,14 +270,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT);
err = connect(c1, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0 && errno != EINPROGRESS) {
perror("connect c1 failed()\n");
perror("connect c1 failed()");
return errno;
}
addr.sin_port = htons(S2_PORT);
err = connect(c2, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0 && errno != EINPROGRESS) {
perror("connect c2 failed()\n");
perror("connect c2 failed()");
return errno;
} else if (err < 0) {
err = 0;
......@@ -286,13 +286,13 @@ static int sockmap_init_sockets(int verbose)
/* Accept Connecrtions */
p1 = accept(s1, NULL, NULL);
if (p1 < 0) {
perror("accept s1 failed()\n");
perror("accept s1 failed()");
return errno;
}
p2 = accept(s2, NULL, NULL);
if (p2 < 0) {
perror("accept s1 failed()\n");
perror("accept s1 failed()");
return errno;
}
......@@ -332,6 +332,10 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
int i, fp;
file = fopen(".sendpage_tst.tmp", "w+");
if (!file) {
perror("create file for sendpage");
return 1;
}
for (i = 0; i < iov_length * cnt; i++, k++)
fwrite(&k, sizeof(char), 1, file);
fflush(file);
......@@ -339,12 +343,17 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
fclose(file);
fp = open(".sendpage_tst.tmp", O_RDONLY);
if (fp < 0) {
perror("reopen file for sendpage");
return 1;
}
clock_gettime(CLOCK_MONOTONIC, &s->start);
for (i = 0; i < cnt; i++) {
int sent = sendfile(fd, fp, NULL, iov_length);
if (!drop && sent < 0) {
perror("send loop error:");
perror("send loop error");
close(fp);
return sent;
} else if (drop && sent >= 0) {
......@@ -463,7 +472,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
int sent = sendmsg(fd, &msg, flags);
if (!drop && sent < 0) {
perror("send loop error:");
perror("send loop error");
goto out_errno;
} else if (drop && sent >= 0) {
printf("send loop error expected: %i\n", sent);
......@@ -499,7 +508,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
total_bytes -= txmsg_pop_total;
err = clock_gettime(CLOCK_MONOTONIC, &s->start);
if (err < 0)
perror("recv start time: ");
perror("recv start time");
while (s->bytes_recvd < total_bytes) {
if (txmsg_cork) {
timeout.tv_sec = 0;
......@@ -543,7 +552,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
if (recv < 0) {
if (errno != EWOULDBLOCK) {
clock_gettime(CLOCK_MONOTONIC, &s->end);
perror("recv failed()\n");
perror("recv failed()");
goto out_errno;
}
}
......@@ -557,7 +566,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
errno = msg_verify_data(&msg, recv, chunk_sz);
if (errno) {
perror("data verify msg failed\n");
perror("data verify msg failed");
goto out_errno;
}
if (recvp) {
......@@ -565,7 +574,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
recvp,
chunk_sz);
if (errno) {
perror("data verify msg_peek failed\n");
perror("data verify msg_peek failed");
goto out_errno;
}
}
......@@ -654,7 +663,7 @@ static int sendmsg_test(struct sockmap_options *opt)
err = 0;
exit(err ? 1 : 0);
} else if (rxpid == -1) {
perror("msg_loop_rx: ");
perror("msg_loop_rx");
return errno;
}
......@@ -681,7 +690,7 @@ static int sendmsg_test(struct sockmap_options *opt)
s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
exit(err ? 1 : 0);
} else if (txpid == -1) {
perror("msg_loop_tx: ");
perror("msg_loop_tx");
return errno;
}
......@@ -715,7 +724,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
/* Ping/Pong data from client to server */
sc = send(c1, buf, sizeof(buf), 0);
if (sc < 0) {
perror("send failed()\n");
perror("send failed()");
return sc;
}
......@@ -748,7 +757,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
rc = recv(i, buf, sizeof(buf), 0);
if (rc < 0) {
if (errno != EWOULDBLOCK) {
perror("recv failed()\n");
perror("recv failed()");
return rc;
}
}
......@@ -760,7 +769,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
sc = send(i, buf, rc, 0);
if (sc < 0) {
perror("send failed()\n");
perror("send failed()");
return sc;
}
}
......
......@@ -45,7 +45,7 @@ static int get_stats(int fd, __u16 count, __u32 raddr)
printf("\nXDP RTT data:\n");
if (bpf_map_lookup_elem(fd, &raddr, &pinginfo)) {
perror("bpf_map_lookup elem: ");
perror("bpf_map_lookup elem");
return 1;
}
......
......@@ -268,6 +268,38 @@ TEST_F(tls, sendmsg_single)
EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
}
#define MAX_FRAGS 64
#define SEND_LEN 13
TEST_F(tls, sendmsg_fragmented)
{
char const *test_str = "test_sendmsg";
char buf[SEND_LEN * MAX_FRAGS];
struct iovec vec[MAX_FRAGS];
struct msghdr msg;
int i, frags;
for (frags = 1; frags <= MAX_FRAGS; frags++) {
for (i = 0; i < frags; i++) {
vec[i].iov_base = (char *)test_str;
vec[i].iov_len = SEND_LEN;
}
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_iov = vec;
msg.msg_iovlen = frags;
EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
SEND_LEN * frags);
for (i = 0; i < frags; i++)
EXPECT_EQ(memcmp(buf + SEND_LEN * i,
test_str, SEND_LEN), 0);
}
}
#undef MAX_FRAGS
#undef SEND_LEN
TEST_F(tls, sendmsg_large)
{
void *mem = malloc(16384);
......@@ -694,6 +726,34 @@ TEST_F(tls, recv_lowat)
EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
}
TEST_F(tls, recv_rcvbuf)
{
char send_mem[4096];
char recv_mem[4096];
int rcv_buf = 1024;
memset(send_mem, 0x1c, sizeof(send_mem));
EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVBUF,
&rcv_buf, sizeof(rcv_buf)), 0);
EXPECT_EQ(send(self->fd, send_mem, 512, 0), 512);
memset(recv_mem, 0, sizeof(recv_mem));
EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), 512);
EXPECT_EQ(memcmp(send_mem, recv_mem, 512), 0);
if (self->notls)
return;
EXPECT_EQ(send(self->fd, send_mem, 4096, 0), 4096);
memset(recv_mem, 0, sizeof(recv_mem));
EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1);
EXPECT_EQ(errno, EMSGSIZE);
EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1);
EXPECT_EQ(errno, EMSGSIZE);
}
TEST_F(tls, bidir)
{
char const *test_str = "test_read";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册