diff --git a/include/net/tls.h b/include/net/tls.h index a8b37226a28795aa2e961791397d5304de3ee751..9f4117ae22973e30696aa42623aef44d8a3a6c57 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -129,6 +129,11 @@ struct tls_rec { u8 aead_req_ctx[]; }; +struct tls_msg { + struct strp_msg rxm; + u8 control; +}; + struct tx_work { struct delayed_work work; struct sock *sk; @@ -333,6 +338,11 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, int flags, long *timeo); +static inline struct tls_msg *tls_msg(struct sk_buff *skb) +{ + return (struct tls_msg *)strp_msg(skb); +} + static inline bool tls_is_pending_closed_record(struct tls_context *ctx) { return test_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 71be8acfbc9b382c53d07b0f677636b5f34ef24c..1cc830582fa8af2e0ff978005eba5c5a609ae51b 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1530,22 +1530,38 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, } /* This function traverses the rx_list in tls receive context to copies the - * decrypted data records into the buffer provided by caller zero copy is not + * decrypted records into the buffer provided by caller zero copy is not * true. Further, the records are removed from the rx_list if it is not a peek * case and the record has been consumed completely. */ static int process_rx_list(struct tls_sw_context_rx *ctx, struct msghdr *msg, + u8 *control, + bool *cmsg, size_t skip, size_t len, bool zc, bool is_peek) { struct sk_buff *skb = skb_peek(&ctx->rx_list); + u8 ctrl = *control; + u8 msgc = *cmsg; + struct tls_msg *tlm; ssize_t copied = 0; + /* Set the record type in 'control' if caller didn't pass it */ + if (!ctrl && skb) { + tlm = tls_msg(skb); + ctrl = tlm->control; + } + while (skip && skb) { struct strp_msg *rxm = strp_msg(skb); + tlm = tls_msg(skb); + + /* Cannot process a record of different type */ + if (ctrl != tlm->control) + return 0; if (skip < rxm->full_len) break; @@ -1559,6 +1575,27 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, struct strp_msg *rxm = strp_msg(skb); int chunk = min_t(unsigned int, rxm->full_len - skip, len); + tlm = tls_msg(skb); + + /* Cannot process a record of different type */ + if (ctrl != tlm->control) + return 0; + + /* Set record type if not already done. For a non-data record, + * do not proceed if record type could not be copied. + */ + if (!msgc) { + int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, + sizeof(ctrl), &ctrl); + msgc = true; + if (ctrl != TLS_RECORD_TYPE_DATA) { + if (cerr || msg->msg_flags & MSG_CTRUNC) + return -EIO; + + *cmsg = msgc; + } + } + if (!zc || (rxm->full_len - skip) > len) { int err = skb_copy_datagram_msg(skb, rxm->offset + skip, msg, chunk); @@ -1597,6 +1634,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, skb = next_skb; } + *control = ctrl; return copied; } @@ -1614,6 +1652,7 @@ int tls_sw_recvmsg(struct sock *sk, unsigned char control = 0; ssize_t decrypted = 0; struct strp_msg *rxm; + struct tls_msg *tlm; struct sk_buff *skb; ssize_t copied = 0; bool cmsg = false; @@ -1632,7 +1671,8 @@ int tls_sw_recvmsg(struct sock *sk, lock_sock(sk); /* Process pending decrypted records. It must be non-zero-copy */ - err = process_rx_list(ctx, msg, 0, len, false, is_peek); + err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false, + is_peek); if (err < 0) { tls_err_abort(sk, err); goto end; @@ -1668,6 +1708,12 @@ int tls_sw_recvmsg(struct sock *sk, } } goto recv_end; + } else { + tlm = tls_msg(skb); + if (prot->version == TLS_1_3_VERSION) + tlm->control = 0; + else + tlm->control = ctx->control; } rxm = strp_msg(skb); @@ -1694,22 +1740,34 @@ int tls_sw_recvmsg(struct sock *sk, if (err == -EINPROGRESS) num_async++; + else if (prot->version == TLS_1_3_VERSION) + tlm->control = ctx->control; + + /* If the type of records being processed is not known yet, + * set it to record type just dequeued. If it is already known, + * but does not match the record type just dequeued, go to end. + * We always get record type here since for tls1.2, record type + * is known just after record is dequeued from stream parser. + * For tls1.3, we disable async. + */ + + if (!control) + control = tlm->control; + else if (control != tlm->control) + goto recv_end; if (!cmsg) { int cerr; cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, - sizeof(ctx->control), &ctx->control); + sizeof(control), &control); cmsg = true; - control = ctx->control; - if (ctx->control != TLS_RECORD_TYPE_DATA) { + if (control != TLS_RECORD_TYPE_DATA) { if (cerr || msg->msg_flags & MSG_CTRUNC) { err = -EIO; goto recv_end; } } - } else if (control != ctx->control) { - goto recv_end; } if (async) @@ -1784,18 +1842,16 @@ int tls_sw_recvmsg(struct sock *sk, /* Drain records from the rx_list & copy if required */ if (is_peek || is_kvec) - err = process_rx_list(ctx, msg, copied, + err = process_rx_list(ctx, msg, &control, &cmsg, copied, decrypted, false, is_peek); else - err = process_rx_list(ctx, msg, 0, + err = process_rx_list(ctx, msg, &control, &cmsg, 0, decrypted, true, is_peek); if (err < 0) { tls_err_abort(sk, err); copied = 0; goto end; } - - WARN_ON(decrypted != err); } copied += decrypted;