diff --git a/include/net/tls.h b/include/net/tls.h index 9f3c4ea9ad6f1b04ccfa2fe0657c356e9cf54123..3aa73e2d8823581ba6cf6ad0d69ce6eb9184f71e 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -41,7 +41,7 @@ #include #include #include - +#include #include @@ -93,24 +93,47 @@ enum { TLS_NUM_CONFIG, }; -struct tls_sw_context_tx { - struct crypto_aead *aead_send; - struct crypto_wait async_wait; - - char aad_space[TLS_AAD_SPACE_SIZE]; - - unsigned int sg_plaintext_size; - int sg_plaintext_num_elem; +/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages + * allocated or mapped for each TLS record. After encryption, the records are + * stores in a linked list. + */ +struct tls_rec { + struct list_head list; + int tx_flags; struct scatterlist sg_plaintext_data[MAX_SKB_FRAGS]; - - unsigned int sg_encrypted_size; - int sg_encrypted_num_elem; struct scatterlist sg_encrypted_data[MAX_SKB_FRAGS]; /* AAD | sg_plaintext_data | sg_tag */ struct scatterlist sg_aead_in[2]; /* AAD | sg_encrypted_data (data contain overhead for hdr&iv&tag) */ struct scatterlist sg_aead_out[2]; + + unsigned int sg_plaintext_size; + unsigned int sg_encrypted_size; + int sg_plaintext_num_elem; + int sg_encrypted_num_elem; + + char aad_space[TLS_AAD_SPACE_SIZE]; + struct aead_request aead_req; + u8 aead_req_ctx[]; +}; + +struct tx_work { + struct delayed_work work; + struct sock *sk; +}; + +struct tls_sw_context_tx { + struct crypto_aead *aead_send; + struct crypto_wait async_wait; + struct tx_work tx_work; + struct tls_rec *open_rec; + struct list_head tx_ready_list; + atomic_t encrypt_pending; + int async_notify; + +#define BIT_TX_SCHEDULED 0 + unsigned long tx_bitmask; }; struct tls_sw_context_rx { @@ -197,6 +220,8 @@ struct tls_context { struct scatterlist *partially_sent_record; u16 partially_sent_offset; + u64 tx_seq_number; /* Next TLS seqnum to be transmitted */ + unsigned long flags; bool in_tcp_sendpages; @@ -261,6 +286,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page, void tls_device_sk_destruct(struct sock *sk); void tls_device_init(void); void tls_device_cleanup(void); +int tls_tx_records(struct sock *sk, int flags); struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context, u32 seq, u64 *p_record_sn); @@ -279,6 +305,9 @@ void tls_sk_destruct(struct sock *sk, struct tls_context *ctx); int tls_push_sg(struct sock *sk, struct tls_context *ctx, struct scatterlist *sg, u16 first_offset, int flags); +int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, + int flags); + int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, int flags, long *timeo); @@ -312,6 +341,23 @@ static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) return tls_ctx->pending_open_record_frags; } +static inline bool is_tx_ready(struct tls_context *tls_ctx, + struct tls_sw_context_tx *ctx) +{ + struct tls_rec *rec; + u64 seq; + + rec = list_first_entry(&ctx->tx_ready_list, struct tls_rec, list); + if (!rec) + return false; + + seq = be64_to_cpup((const __be64 *)&rec->aad_space); + if (seq == tls_ctx->tx_seq_number) + return true; + else + return false; +} + struct sk_buff * tls_validate_xmit_skb(struct sock *sk, struct net_device *dev, struct sk_buff *skb); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 523622dc74f8b969113b0435b39f5d0f3d070304..06094de7a3d92305a1b2af247018af1719dec075 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -141,7 +141,6 @@ int tls_push_sg(struct sock *sk, size = sg->length; } - clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); ctx->in_tcp_sendpages = false; ctx->sk_write_space(sk); @@ -193,15 +192,12 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, return rc; } -int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, - int flags, long *timeo) +int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, + int flags) { struct scatterlist *sg; u16 offset; - if (!tls_is_partially_sent_record(ctx)) - return ctx->push_pending_record(sk, flags); - sg = ctx->partially_sent_record; offset = ctx->partially_sent_offset; @@ -209,9 +205,23 @@ int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, return tls_push_sg(sk, ctx, sg, offset, flags); } +int tls_push_pending_closed_record(struct sock *sk, + struct tls_context *tls_ctx, + int flags, long *timeo) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + if (tls_is_partially_sent_record(tls_ctx) || + !list_empty(&ctx->tx_ready_list)) + return tls_tx_records(sk, flags); + else + return tls_ctx->push_pending_record(sk, flags); +} + static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); /* If in_tcp_sendpages call lower protocol write space handler * to ensure we wake up any waiting operations there. For example @@ -222,20 +232,11 @@ static void tls_write_space(struct sock *sk) return; } - if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { - gfp_t sk_allocation = sk->sk_allocation; - int rc; - long timeo = 0; - - sk->sk_allocation = GFP_ATOMIC; - rc = tls_push_pending_closed_record(sk, ctx, - MSG_DONTWAIT | - MSG_NOSIGNAL, - &timeo); - sk->sk_allocation = sk_allocation; - - if (rc < 0) - return; + /* Schedule the transmission if tx list is ready */ + if (is_tx_ready(ctx, tx_ctx) && !sk->sk_write_pending) { + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); } ctx->sk_write_space(sk); @@ -270,19 +271,6 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) tls_handle_open_record(sk, 0); - if (ctx->partially_sent_record) { - struct scatterlist *sg = ctx->partially_sent_record; - - while (1) { - put_page(sg_page(sg)); - sk_mem_uncharge(sk, sg->length); - - if (sg_is_last(sg)) - break; - sg++; - } - } - /* We need these for tls_sw_fallback handling of other packets */ if (ctx->tx_conf == TLS_SW) { kfree(ctx->tx.rec_seq); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 5ff51bac8b469903b70ec9304ef06cf60d265ed0..bcb24c498b8417638d35fc8b1e0e72cb5eb7ab26 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -246,18 +246,19 @@ static void trim_both_sgl(struct sock *sk, int target_size) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; - trim_sg(sk, ctx->sg_plaintext_data, - &ctx->sg_plaintext_num_elem, - &ctx->sg_plaintext_size, + trim_sg(sk, rec->sg_plaintext_data, + &rec->sg_plaintext_num_elem, + &rec->sg_plaintext_size, target_size); if (target_size > 0) target_size += tls_ctx->tx.overhead_size; - trim_sg(sk, ctx->sg_encrypted_data, - &ctx->sg_encrypted_num_elem, - &ctx->sg_encrypted_size, + trim_sg(sk, rec->sg_encrypted_data, + &rec->sg_encrypted_num_elem, + &rec->sg_encrypted_size, target_size); } @@ -265,15 +266,16 @@ static int alloc_encrypted_sg(struct sock *sk, int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; int rc = 0; rc = sk_alloc_sg(sk, len, - ctx->sg_encrypted_data, 0, - &ctx->sg_encrypted_num_elem, - &ctx->sg_encrypted_size, 0); + rec->sg_encrypted_data, 0, + &rec->sg_encrypted_num_elem, + &rec->sg_encrypted_size, 0); if (rc == -ENOSPC) - ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data); + rec->sg_encrypted_num_elem = ARRAY_SIZE(rec->sg_encrypted_data); return rc; } @@ -282,14 +284,15 @@ static int alloc_plaintext_sg(struct sock *sk, int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; int rc = 0; - rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, - &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, + rc = sk_alloc_sg(sk, len, rec->sg_plaintext_data, 0, + &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size, tls_ctx->pending_open_record_frags); if (rc == -ENOSPC) - ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data); + rec->sg_plaintext_num_elem = ARRAY_SIZE(rec->sg_plaintext_data); return rc; } @@ -311,37 +314,192 @@ static void tls_free_both_sg(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; - free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, - &ctx->sg_encrypted_size); + /* Return if there is no open record */ + if (!rec) + return; + + free_sg(sk, rec->sg_encrypted_data, + &rec->sg_encrypted_num_elem, + &rec->sg_encrypted_size); + + free_sg(sk, rec->sg_plaintext_data, + &rec->sg_plaintext_num_elem, + &rec->sg_plaintext_size); +} + +static bool append_tx_ready_list(struct tls_context *tls_ctx, + struct tls_sw_context_tx *ctx, + struct tls_rec *enc_rec) +{ + u64 new_seq = be64_to_cpup((const __be64 *)&enc_rec->aad_space); + struct list_head *pos; + + /* Need to insert encrypted record in tx_ready_list sorted + * as per sequence number. Traverse linked list from tail. + */ + list_for_each_prev(pos, &ctx->tx_ready_list) { + struct tls_rec *rec = (struct tls_rec *)pos; + u64 seq = be64_to_cpup((const __be64 *)&rec->aad_space); + + if (new_seq > seq) + break; + } + + list_add((struct list_head *)&enc_rec->list, pos); + + return is_tx_ready(tls_ctx, ctx); +} + +int tls_tx_records(struct sock *sk, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec, *tmp; + int tx_flags, rc = 0; + + if (tls_is_partially_sent_record(tls_ctx)) { + rec = list_first_entry(&ctx->tx_ready_list, + struct tls_rec, list); + + if (flags == -1) + tx_flags = rec->tx_flags; + else + tx_flags = flags; + + rc = tls_push_partial_record(sk, tls_ctx, tx_flags); + if (rc) + goto tx_err; + + /* Full record has been transmitted. + * Remove the head of tx_ready_list + */ + tls_ctx->tx_seq_number++; + list_del(&rec->list); + kfree(rec); + } + + /* Tx all ready records which have expected sequence number */ + list_for_each_entry_safe(rec, tmp, &ctx->tx_ready_list, list) { + u64 seq = be64_to_cpup((const __be64 *)&rec->aad_space); + + if (seq == tls_ctx->tx_seq_number) { + if (flags == -1) + tx_flags = rec->tx_flags; + else + tx_flags = flags; + + rc = tls_push_sg(sk, tls_ctx, + &rec->sg_encrypted_data[0], + 0, tx_flags); + if (rc) + goto tx_err; + + tls_ctx->tx_seq_number++; + list_del(&rec->list); + kfree(rec); + } else { + break; + } + } + +tx_err: + if (rc < 0 && rc != -EAGAIN) + tls_err_abort(sk, EBADMSG); + + return rc; +} + +static void tls_encrypt_done(struct crypto_async_request *req, int err) +{ + struct aead_request *aead_req = (struct aead_request *)req; + struct sock *sk = req->data; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec; + bool ready = false; + int pending; + + rec = container_of(aead_req, struct tls_rec, aead_req); + + rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; + rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; + + free_sg(sk, rec->sg_plaintext_data, + &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size); + + /* Free the record if error is previously set on socket */ + if (err || sk->sk_err) { + free_sg(sk, rec->sg_encrypted_data, + &rec->sg_encrypted_num_elem, &rec->sg_encrypted_size); - free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, - &ctx->sg_plaintext_size); + kfree(rec); + rec = NULL; + + /* If err is already set on socket, return the same code */ + if (sk->sk_err) { + ctx->async_wait.err = sk->sk_err; + } else { + ctx->async_wait.err = err; + tls_err_abort(sk, err); + } + } + + /* Append the record in tx queue */ + if (rec) + ready = append_tx_ready_list(tls_ctx, ctx, rec); + + pending = atomic_dec_return(&ctx->encrypt_pending); + + if (!pending && READ_ONCE(ctx->async_notify)) + complete(&ctx->async_wait.completion); + + if (!ready) + return; + + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) + schedule_delayed_work(&ctx->tx_work.work, 1); } -static int tls_do_encryption(struct tls_context *tls_ctx, +static int tls_do_encryption(struct sock *sk, + struct tls_context *tls_ctx, struct tls_sw_context_tx *ctx, struct aead_request *aead_req, size_t data_len) { + struct tls_rec *rec = ctx->open_rec; int rc; - ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size; - ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; + rec->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size; + rec->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; aead_request_set_tfm(aead_req, ctx->aead_send); aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); - aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out, + aead_request_set_crypt(aead_req, rec->sg_aead_in, + rec->sg_aead_out, data_len, tls_ctx->tx.iv); aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - crypto_req_done, &ctx->async_wait); + tls_encrypt_done, sk); + + atomic_inc(&ctx->encrypt_pending); - rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait); + rc = crypto_aead_encrypt(aead_req); + if (!rc || rc != -EINPROGRESS) { + atomic_dec(&ctx->encrypt_pending); + rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; + rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; + } - ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; - ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; + /* Case of encryption failure */ + if (rc && rc != -EINPROGRESS) + return rc; + /* Unhook the record from context if encryption is not failure */ + ctx->open_rec = NULL; + tls_advance_record_sn(sk, &tls_ctx->tx); return rc; } @@ -350,53 +508,49 @@ static int tls_push_record(struct sock *sk, int flags, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; struct aead_request *req; int rc; - req = aead_request_alloc(ctx->aead_send, sk->sk_allocation); - if (!req) - return -ENOMEM; + if (!rec) + return 0; - sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); - sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); + rec->tx_flags = flags; + req = &rec->aead_req; - tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size, + sg_mark_end(rec->sg_plaintext_data + rec->sg_plaintext_num_elem - 1); + sg_mark_end(rec->sg_encrypted_data + rec->sg_encrypted_num_elem - 1); + + tls_make_aad(rec->aad_space, rec->sg_plaintext_size, tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, record_type); tls_fill_prepend(tls_ctx, - page_address(sg_page(&ctx->sg_encrypted_data[0])) + - ctx->sg_encrypted_data[0].offset, - ctx->sg_plaintext_size, record_type); + page_address(sg_page(&rec->sg_encrypted_data[0])) + + rec->sg_encrypted_data[0].offset, + rec->sg_plaintext_size, record_type); tls_ctx->pending_open_record_frags = 0; - set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags); - - rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size); - if (rc < 0) { - /* If we are called from write_space and - * we fail, we need to set this SOCK_NOSPACE - * to trigger another write_space in the future. - */ - set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); - goto out_req; - } - free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, - &ctx->sg_plaintext_size); + rc = tls_do_encryption(sk, tls_ctx, ctx, req, rec->sg_plaintext_size); + if (rc == -EINPROGRESS) + return -EINPROGRESS; - ctx->sg_encrypted_num_elem = 0; - ctx->sg_encrypted_size = 0; + free_sg(sk, rec->sg_plaintext_data, &rec->sg_plaintext_num_elem, + &rec->sg_plaintext_size); - /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */ - rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags); - if (rc < 0 && rc != -EAGAIN) + if (rc < 0) { tls_err_abort(sk, EBADMSG); + return rc; + } - tls_advance_record_sn(sk, &tls_ctx->tx); -out_req: - aead_request_free(req); - return rc; + /* Put the record in tx_ready_list and start tx if permitted. + * This happens only when encryption is not asynchronous. + */ + if (append_tx_ready_list(tls_ctx, ctx, rec)) + return tls_tx_records(sk, flags); + + return 0; } static int tls_sw_push_pending_record(struct sock *sk, int flags) @@ -473,11 +627,12 @@ static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct scatterlist *sg = ctx->sg_plaintext_data; + struct tls_rec *rec = ctx->open_rec; + struct scatterlist *sg = rec->sg_plaintext_data; int copy, i, rc = 0; for (i = tls_ctx->pending_open_record_frags; - i < ctx->sg_plaintext_num_elem; ++i) { + i < rec->sg_plaintext_num_elem; ++i) { copy = sg[i].length; if (copy_from_iter( page_address(sg_page(&sg[i])) + sg[i].offset, @@ -497,34 +652,85 @@ static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, return rc; } -int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +struct tls_rec *get_rec(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - int ret; - int required_size; + struct tls_rec *rec; + int mem_size; + + /* Return if we already have an open record */ + if (ctx->open_rec) + return ctx->open_rec; + + mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); + + rec = kzalloc(mem_size, sk->sk_allocation); + if (!rec) + return NULL; + + sg_init_table(&rec->sg_plaintext_data[0], + ARRAY_SIZE(rec->sg_plaintext_data)); + sg_init_table(&rec->sg_encrypted_data[0], + ARRAY_SIZE(rec->sg_encrypted_data)); + + sg_init_table(rec->sg_aead_in, 2); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, + sizeof(rec->aad_space)); + sg_unmark_end(&rec->sg_aead_in[1]); + sg_chain(rec->sg_aead_in, 2, rec->sg_plaintext_data); + + sg_init_table(rec->sg_aead_out, 2); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, + sizeof(rec->aad_space)); + sg_unmark_end(&rec->sg_aead_out[1]); + sg_chain(rec->sg_aead_out, 2, rec->sg_encrypted_data); + + ctx->open_rec = rec; + + return rec; +} + +int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send); + bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; + unsigned char record_type = TLS_RECORD_TYPE_DATA; + bool is_kvec = msg->msg_iter.type & ITER_KVEC; bool eor = !(msg->msg_flags & MSG_MORE); size_t try_to_copy, copied = 0; - unsigned char record_type = TLS_RECORD_TYPE_DATA; - int record_room; + struct tls_rec *rec; + int required_size; + int num_async = 0; bool full_record; + int record_room; + int num_zc = 0; int orig_size; - bool is_kvec = msg->msg_iter.type & ITER_KVEC; + int ret; if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -ENOTSUPP; lock_sock(sk); - ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo); - if (ret) - goto send_end; + /* Wait till there is any pending write on socket */ + if (unlikely(sk->sk_write_pending)) { + ret = wait_on_pending_writer(sk, &timeo); + if (unlikely(ret)) + goto send_end; + } if (unlikely(msg->msg_controllen)) { ret = tls_proccess_cmsg(sk, msg, &record_type); - if (ret) - goto send_end; + if (ret) { + if (ret == -EINPROGRESS) + num_async++; + else if (ret != -EAGAIN) + goto send_end; + } } while (msg_data_left(msg)) { @@ -533,20 +739,27 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) goto send_end; } - orig_size = ctx->sg_plaintext_size; + rec = get_rec(sk); + if (!rec) { + ret = -ENOMEM; + goto send_end; + } + + orig_size = rec->sg_plaintext_size; full_record = false; try_to_copy = msg_data_left(msg); - record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; + record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size; if (try_to_copy >= record_room) { try_to_copy = record_room; full_record = true; } - required_size = ctx->sg_plaintext_size + try_to_copy + + required_size = rec->sg_plaintext_size + try_to_copy + tls_ctx->tx.overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; + alloc_encrypted: ret = alloc_encrypted_sg(sk, required_size); if (ret) { @@ -557,33 +770,39 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) * actually allocated. The difference is due * to max sg elements limit */ - try_to_copy -= required_size - ctx->sg_encrypted_size; + try_to_copy -= required_size - rec->sg_encrypted_size; full_record = true; } - if (!is_kvec && (full_record || eor)) { + + if (!is_kvec && (full_record || eor) && !async_capable) { ret = zerocopy_from_iter(sk, &msg->msg_iter, - try_to_copy, &ctx->sg_plaintext_num_elem, - &ctx->sg_plaintext_size, - ctx->sg_plaintext_data, - ARRAY_SIZE(ctx->sg_plaintext_data), + try_to_copy, &rec->sg_plaintext_num_elem, + &rec->sg_plaintext_size, + rec->sg_plaintext_data, + ARRAY_SIZE(rec->sg_plaintext_data), true); if (ret) goto fallback_to_reg_send; + num_zc++; copied += try_to_copy; ret = tls_push_record(sk, msg->msg_flags, record_type); - if (ret) - goto send_end; + if (ret) { + if (ret == -EINPROGRESS) + num_async++; + else if (ret != -EAGAIN) + goto send_end; + } continue; fallback_to_reg_send: - trim_sg(sk, ctx->sg_plaintext_data, - &ctx->sg_plaintext_num_elem, - &ctx->sg_plaintext_size, + trim_sg(sk, rec->sg_plaintext_data, + &rec->sg_plaintext_num_elem, + &rec->sg_plaintext_size, orig_size); } - required_size = ctx->sg_plaintext_size + try_to_copy; + required_size = rec->sg_plaintext_size + try_to_copy; alloc_plaintext: ret = alloc_plaintext_sg(sk, required_size); if (ret) { @@ -594,13 +813,13 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) * actually allocated. The difference is due * to max sg elements limit */ - try_to_copy -= required_size - ctx->sg_plaintext_size; + try_to_copy -= required_size - rec->sg_plaintext_size; full_record = true; - trim_sg(sk, ctx->sg_encrypted_data, - &ctx->sg_encrypted_num_elem, - &ctx->sg_encrypted_size, - ctx->sg_plaintext_size + + trim_sg(sk, rec->sg_encrypted_data, + &rec->sg_encrypted_num_elem, + &rec->sg_encrypted_size, + rec->sg_plaintext_size + tls_ctx->tx.overhead_size); } @@ -610,13 +829,12 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) copied += try_to_copy; if (full_record || eor) { -push_record: ret = tls_push_record(sk, msg->msg_flags, record_type); if (ret) { - if (ret == -ENOMEM) - goto wait_for_memory; - - goto send_end; + if (ret == -EINPROGRESS) + num_async++; + else if (ret != -EAGAIN) + goto send_end; } } @@ -632,15 +850,37 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) goto send_end; } - if (tls_is_pending_closed_record(tls_ctx)) - goto push_record; - - if (ctx->sg_encrypted_size < required_size) + if (rec->sg_encrypted_size < required_size) goto alloc_encrypted; goto alloc_plaintext; } + if (!num_async) { + goto send_end; + } else if (num_zc) { + /* Wait for pending encryptions to get completed */ + smp_store_mb(ctx->async_notify, true); + + if (atomic_read(&ctx->encrypt_pending)) + crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + else + reinit_completion(&ctx->async_wait.completion); + + WRITE_ONCE(ctx->async_notify, false); + + if (ctx->async_wait.err) { + ret = ctx->async_wait.err; + copied = 0; + } + } + + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, msg->msg_flags); + } + send_end: ret = sk_stream_error(sk, msg->msg_flags, ret); @@ -651,16 +891,18 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { + long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - int ret; - long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); - bool eor; - size_t orig_size = size; unsigned char record_type = TLS_RECORD_TYPE_DATA; + size_t orig_size = size; struct scatterlist *sg; + struct tls_rec *rec; + int num_async = 0; bool full_record; int record_room; + bool eor; + int ret; if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST)) @@ -673,9 +915,12 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); - ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); - if (ret) - goto sendpage_end; + /* Wait till there is any pending write on socket */ + if (unlikely(sk->sk_write_pending)) { + ret = wait_on_pending_writer(sk, &timeo); + if (unlikely(ret)) + goto sendpage_end; + } /* Call the sk_stream functions to manage the sndbuf mem. */ while (size > 0) { @@ -686,14 +931,20 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, goto sendpage_end; } + rec = get_rec(sk); + if (!rec) { + ret = -ENOMEM; + goto sendpage_end; + } + full_record = false; - record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; + record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size; copy = size; if (copy >= record_room) { copy = record_room; full_record = true; } - required_size = ctx->sg_plaintext_size + copy + + required_size = rec->sg_plaintext_size + copy + tls_ctx->tx.overhead_size; if (!sk_stream_memory_free(sk)) @@ -708,33 +959,32 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, * actually allocated. The difference is due * to max sg elements limit */ - copy -= required_size - ctx->sg_plaintext_size; + copy -= required_size - rec->sg_plaintext_size; full_record = true; } get_page(page); - sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem; + sg = rec->sg_plaintext_data + rec->sg_plaintext_num_elem; sg_set_page(sg, page, copy, offset); sg_unmark_end(sg); - ctx->sg_plaintext_num_elem++; + rec->sg_plaintext_num_elem++; sk_mem_charge(sk, copy); offset += copy; size -= copy; - ctx->sg_plaintext_size += copy; - tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem; + rec->sg_plaintext_size += copy; + tls_ctx->pending_open_record_frags = rec->sg_plaintext_num_elem; if (full_record || eor || - ctx->sg_plaintext_num_elem == - ARRAY_SIZE(ctx->sg_plaintext_data)) { -push_record: + rec->sg_plaintext_num_elem == + ARRAY_SIZE(rec->sg_plaintext_data)) { ret = tls_push_record(sk, flags, record_type); if (ret) { - if (ret == -ENOMEM) - goto wait_for_memory; - - goto sendpage_end; + if (ret == -EINPROGRESS) + num_async++; + else if (ret != -EAGAIN) + goto sendpage_end; } } continue; @@ -743,16 +993,20 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { - trim_both_sgl(sk, ctx->sg_plaintext_size); + trim_both_sgl(sk, rec->sg_plaintext_size); goto sendpage_end; } - if (tls_is_pending_closed_record(tls_ctx)) - goto push_record; - goto alloc_payload; } + if (num_async) { + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, flags); + } + } sendpage_end: if (orig_size > size) ret = orig_size - size; @@ -1300,6 +1554,49 @@ void tls_sw_free_resources_tx(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec, *tmp; + + /* Wait for any pending async encryptions to complete */ + smp_store_mb(ctx->async_notify, true); + if (atomic_read(&ctx->encrypt_pending)) + crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + + cancel_delayed_work_sync(&ctx->tx_work.work); + + /* Tx whatever records we can transmit and abandon the rest */ + tls_tx_records(sk, -1); + + /* Free up un-sent records in tx_ready_list. First, free + * the partially sent record if any at head of tx_list. + */ + if (tls_ctx->partially_sent_record) { + struct scatterlist *sg = tls_ctx->partially_sent_record; + + while (1) { + put_page(sg_page(sg)); + sk_mem_uncharge(sk, sg->length); + + if (sg_is_last(sg)) + break; + sg++; + } + + tls_ctx->partially_sent_record = NULL; + + rec = list_first_entry(&ctx->tx_ready_list, + struct tls_rec, list); + list_del(&rec->list); + kfree(rec); + } + + list_for_each_entry_safe(rec, tmp, &ctx->tx_ready_list, list) { + free_sg(sk, rec->sg_encrypted_data, + &rec->sg_encrypted_num_elem, + &rec->sg_encrypted_size); + + list_del(&rec->list); + kfree(rec); + } crypto_free_aead(ctx->aead_send); tls_free_both_sg(sk); @@ -1336,6 +1633,24 @@ void tls_sw_free_resources_rx(struct sock *sk) kfree(ctx); } +/* The work handler to transmitt the encrypted records in tx_ready_list */ +static void tx_work_handler(struct work_struct *work) +{ + struct delayed_work *delayed_work = to_delayed_work(work); + struct tx_work *tx_work = container_of(delayed_work, + struct tx_work, work); + struct sock *sk = tx_work->sk; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) + return; + + lock_sock(sk); + tls_tx_records(sk, -1); + release_sock(sk); +} + int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { struct tls_crypto_info *crypto_info; @@ -1385,6 +1700,9 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) crypto_info = &ctx->crypto_send.info; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; + INIT_LIST_HEAD(&sw_ctx_tx->tx_ready_list); + INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); + sw_ctx_tx->tx_work.sk = sk; } else { crypto_init_wait(&sw_ctx_rx->async_wait); crypto_info = &ctx->crypto_recv.info; @@ -1435,26 +1753,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto free_iv; } - if (sw_ctx_tx) { - sg_init_table(sw_ctx_tx->sg_encrypted_data, - ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data)); - sg_init_table(sw_ctx_tx->sg_plaintext_data, - ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data)); - - sg_init_table(sw_ctx_tx->sg_aead_in, 2); - sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space, - sizeof(sw_ctx_tx->aad_space)); - sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]); - sg_chain(sw_ctx_tx->sg_aead_in, 2, - sw_ctx_tx->sg_plaintext_data); - sg_init_table(sw_ctx_tx->sg_aead_out, 2); - sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space, - sizeof(sw_ctx_tx->aad_space)); - sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]); - sg_chain(sw_ctx_tx->sg_aead_out, 2, - sw_ctx_tx->sg_encrypted_data); - } - if (!*aead) { *aead = crypto_alloc_aead("gcm(aes)", 0, 0); if (IS_ERR(*aead)) { @@ -1491,6 +1789,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; strp_check_rcv(&sw_ctx_rx->strp); + } else { + ctx->tx_seq_number = be64_to_cpup((const __be64 *)rec_seq); } goto out;