/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/bug.h>
#include <linux/sched/signal.h>
#include <linux/module.h>
#include <linux/splice.h>
#include <crypto/aead.h>

#include <net/strparser.h>
#include <net/tls.h>

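/* Per-call decryption arguments: 'zc' asks for zero-copy decryption into
 * the caller's buffer and 'async' for asynchronous crypto; the decrypt
 * path may clear either flag if the requested mode cannot be used.
 */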
struct tls_decrypt_arg {
	bool zc;
	bool async;
};

noinline void tls_err_abort(struct sock *sk, int err)
{
	WARN_ON_ONCE(err >= 0);
	/* sk->sk_err should contain a positive error code. */
	sk->sk_err = -err;
	sk_error_report(sk);
}

static int __skb_nsg(struct sk_buff *skb, int offset, int len,
                     unsigned int recursion_level)
{
        int start = skb_headlen(skb);
        int i, chunk = start - offset;
        struct sk_buff *frag_iter;
        int elt = 0;

        if (unlikely(recursion_level >= 24))
                return -EMSGSIZE;

        if (chunk > 0) {
                if (chunk > len)
                        chunk = len;
                elt++;
                len -= chunk;
                if (len == 0)
                        return elt;
                offset += chunk;
        }

        for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
                int end;

                WARN_ON(start > offset + len);

                end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
                chunk = end - offset;
                if (chunk > 0) {
                        if (chunk > len)
                                chunk = len;
                        elt++;
                        len -= chunk;
                        if (len == 0)
                                return elt;
                        offset += chunk;
                }
                start = end;
        }

        if (unlikely(skb_has_frag_list(skb))) {
                skb_walk_frags(skb, frag_iter) {
                        int end, ret;

                        WARN_ON(start > offset + len);

                        end = start + frag_iter->len;
                        chunk = end - offset;
                        if (chunk > 0) {
                                if (chunk > len)
                                        chunk = len;
                                ret = __skb_nsg(frag_iter, offset - start, chunk,
                                                recursion_level + 1);
                                if (unlikely(ret < 0))
                                        return ret;
                                elt += ret;
                                len -= chunk;
                                if (len == 0)
                                        return elt;
                                offset += chunk;
                        }
                        start = end;
                }
        }
        BUG_ON(len);
        return elt;
}

/* Return the number of scatterlist elements required to completely map the
 * skb, or -EMSGSIZE if the recursion depth is exceeded.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
        return __skb_nsg(skb, offset, len, 0);
}

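/* Find the length of the TLS 1.3 zero padding. Scan backwards from the
 * last byte before the authentication tag until the first non-zero byte,
 * which is the real content type; store that type in the skb's tls_msg
 * and return the number of padding bytes, or a negative error if no
 * content type byte is found within the record.
 */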
static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	int sub = 0;

	/* Determine zero-padding length */
	if (prot->version == TLS_1_3_VERSION) {
		int offset = rxm->full_len - TLS_TAG_SIZE - 1;
		char content_type = 0;
		int err;

		while (content_type == 0) {
			if (offset < prot->prepend_size)
				return -EBADMSG;
			err = skb_copy_bits(skb, rxm->offset + offset,
					    &content_type, 1);
			if (err)
				return err;
			if (content_type)
				break;
			sub++;
			offset--;
		}
		tlm->control = content_type;
	}
	return sub;
}

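/* Completion callback for asynchronous decryption. Runs in the crypto
 * backend's completion context: it propagates errors to the socket,
 * adjusts the strparser offsets on success (TLS 1.2 only, async is not
 * used for TLS 1.3), frees the zero-copy destination pages if the record
 * was not decrypted in place, and wakes the waiter once the last pending
 * decryption has completed.
 */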
static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct scatterlist *sgout = aead_req->dst;
	struct scatterlist *sgin = aead_req->src;
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct tls_prot_info *prot;
	struct scatterlist *sg;
	struct sk_buff *skb;
	unsigned int pages;

	skb = (struct sk_buff *)req->data;
	tls_ctx = tls_get_ctx(skb->sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	prot = &tls_ctx->prot_info;

	/* Propagate if there was an err */
	if (err) {
		if (err == -EBADMSG)
			TLS_INC_STATS(sock_net(skb->sk),
				      LINUX_MIB_TLSDECRYPTERROR);
		ctx->async_wait.err = err;
		tls_err_abort(skb->sk, err);
	} else {
		struct strp_msg *rxm = strp_msg(skb);

		/* No TLS 1.3 support with async crypto */
		WARN_ON(prot->tail_size);

		rxm->offset += prot->prepend_size;
		rxm->full_len -= prot->overhead_size;
	}

	/* After using skb->sk to propagate sk through crypto async callback
	 * we need to NULL it again.
	 */
	skb->sk = NULL;

	/* Free the destination pages if skb was not decrypted inplace */
	if (sgout != sgin) {
		/* Skip the first S/G entry as it points to AAD */
		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
			if (!sg)
				break;
			put_page(sg_page(sg));
		}
	}

	kfree(aead_req);

	spin_lock_bh(&ctx->decrypt_compl_lock);
	if (!atomic_dec_return(&ctx->decrypt_pending))
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->decrypt_compl_lock);
}

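/* Build and submit one AEAD decryption request. In async mode the
 * request completes later in tls_decrypt_done() and 0 is returned
 * immediately; otherwise the call waits for the crypto layer and
 * returns its result, with darg->async cleared.
 */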
static int tls_do_decryption(struct sock *sk,
			     struct sk_buff *skb,
			     struct scatterlist *sgin,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
			     struct aead_request *aead_req,
			     struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	int ret;

	aead_request_set_tfm(aead_req, ctx->aead_recv);
	aead_request_set_ad(aead_req, prot->aad_size);
	aead_request_set_crypt(aead_req, sgin, sgout,
			       data_len + prot->tag_size,
			       (u8 *)iv_recv);

	if (darg->async) {
		/* Using skb->sk to push sk through to crypto async callback
		 * handler. This allows propagating errors up to the socket
		 * if needed. It _must_ be cleared in the async handler
		 * before consume_skb is called. We _know_ skb->sk is NULL
		 * because it is a clone from strparser.
		 */
		skb->sk = sk;
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  tls_decrypt_done, skb);
		atomic_inc(&ctx->decrypt_pending);
	} else {
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  crypto_req_done, &ctx->async_wait);
	}

	ret = crypto_aead_decrypt(aead_req);
	if (ret == -EINPROGRESS) {
		if (darg->async)
			return 0;

		ret = crypto_wait_req(ret, &ctx->async_wait);
	}
	darg->async = false;

	if (ret == -EBADMSG)
		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);

	return ret;
}

static void tls_trim_both_msgs(struct sock *sk, int target_size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;

	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
	if (target_size > 0)
		target_size += prot->overhead_size;
	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
}

static int tls_alloc_encrypted_msg(struct sock *sk, int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_en = &rec->msg_encrypted;

	return sk_msg_alloc(sk, msg_en, len, 0);
}

static int tls_clone_plaintext_msg(struct sock *sk, int required)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_pl = &rec->msg_plaintext;
	struct sk_msg *msg_en = &rec->msg_encrypted;
	int skip, len;

	/* We add page references worth len bytes from encrypted sg
	 * at the end of plaintext sg. It is guaranteed that msg_en
	 * has enough required room (ensured by caller).
	 */
	len = required - msg_pl->sg.size;

	/* Skip initial bytes in msg_en's data to be able to use
	 * same offset of both plain and encrypted data.
	 */
	skip = prot->prepend_size + msg_pl->sg.size;

	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
}

static struct tls_rec *tls_get_rec(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct sk_msg *msg_pl, *msg_en;
	struct tls_rec *rec;
	int mem_size;

	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);

	rec = kzalloc(mem_size, sk->sk_allocation);
	if (!rec)
		return NULL;

	msg_pl = &rec->msg_plaintext;
	msg_en = &rec->msg_encrypted;

	sk_msg_init(msg_pl);
	sk_msg_init(msg_en);

	sg_init_table(rec->sg_aead_in, 2);
	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
	sg_unmark_end(&rec->sg_aead_in[1]);

	sg_init_table(rec->sg_aead_out, 2);
	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
	sg_unmark_end(&rec->sg_aead_out[1]);

	return rec;
}

static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
{
	sk_msg_free(sk, &rec->msg_encrypted);
	sk_msg_free(sk, &rec->msg_plaintext);
	kfree(rec);
}

static void tls_free_open_rec(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;

	if (rec) {
		tls_free_rec(sk, rec);
		ctx->open_rec = NULL;
	}
}

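/* Push encrypted records from ctx->tx_list out to the transport, first
 * finishing a partially transmitted record if there is one. Records are
 * sent strictly in order; the walk stops at the first record whose
 * encryption has not completed yet (!tx_ready).
 */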
int tls_tx_records(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	struct sk_msg *msg_en;
	int tx_flags, rc = 0;

	if (tls_is_partially_sent_record(tls_ctx)) {
		rec = list_first_entry(&ctx->tx_list,
				       struct tls_rec, list);

		if (flags == -1)
			tx_flags = rec->tx_flags;
		else
			tx_flags = flags;

		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
		if (rc)
			goto tx_err;

		/* Full record has been transmitted.
		 * Remove the head of tx_list
		 */
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	/* Tx all ready records */
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		if (READ_ONCE(rec->tx_ready)) {
			if (flags == -1)
				tx_flags = rec->tx_flags;
			else
				tx_flags = flags;

			msg_en = &rec->msg_encrypted;
			rc = tls_push_sg(sk, tls_ctx,
					 &msg_en->sg.data[msg_en->sg.curr],
					 0, tx_flags);
			if (rc)
				goto tx_err;

			list_del(&rec->list);
			sk_msg_free(sk, &rec->msg_plaintext);
			kfree(rec);
		} else {
			break;
		}
	}

tx_err:
	if (rc < 0 && rc != -EAGAIN)
		tls_err_abort(sk, -EBADMSG);

	return rc;
}

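/* Completion callback for asynchronous encryption. Marks the record as
 * ready for transmission and, if it is at the head of tx_list, schedules
 * tx_work to push it out, which keeps records flowing in order.
 */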
static void tls_encrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct sock *sk = req->data;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct scatterlist *sge;
	struct sk_msg *msg_en;
	struct tls_rec *rec;
	bool ready = false;
	int pending;

	rec = container_of(aead_req, struct tls_rec, aead_req);
	msg_en = &rec->msg_encrypted;

	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
	sge->offset -= prot->prepend_size;
	sge->length += prot->prepend_size;

	/* Check if error is previously set on socket */
	if (err || sk->sk_err) {
		rec = NULL;

		/* If err is already set on socket, return the same code */
		if (sk->sk_err) {
			ctx->async_wait.err = -sk->sk_err;
		} else {
			ctx->async_wait.err = err;
			tls_err_abort(sk, err);
		}
	}

	if (rec) {
		struct tls_rec *first_rec;

		/* Mark the record as ready for transmission */
		smp_store_mb(rec->tx_ready, true);

		/* If received record is at head of tx_list, schedule tx */
		first_rec = list_first_entry(&ctx->tx_list,
					     struct tls_rec, list);
		if (rec == first_rec)
			ready = true;
	}

	spin_lock_bh(&ctx->encrypt_compl_lock);
	pending = atomic_dec_return(&ctx->encrypt_pending);

	if (!pending && ctx->async_notify)
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->encrypt_compl_lock);

	if (!ready)
		return;

	/* Schedule the transmission */
	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		schedule_delayed_work(&ctx->tx_work.work, 1);
}

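/* Prepare the AEAD request for the open record: build the per-record
 * nonce from the salt/IV and the record sequence number, point the
 * request at the plaintext and ciphertext scatterlists, and queue the
 * record on tx_list before submitting, since the request may complete
 * asynchronously in tls_encrypt_done().
 */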
static int tls_do_encryption(struct sock *sk,
			     struct tls_context *tls_ctx,
			     struct tls_sw_context_tx *ctx,
			     struct aead_request *aead_req,
			     size_t data_len, u32 start)
{
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_en = &rec->msg_encrypted;
	struct scatterlist *sge = sk_msg_elem(msg_en, start);
	int rc, iv_offset = 0;

	/* For CCM based ciphers, first byte of IV is a constant */
	switch (prot->cipher_type) {
	case TLS_CIPHER_AES_CCM_128:
		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
		iv_offset = 1;
		break;
	case TLS_CIPHER_SM4_CCM:
		rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
		iv_offset = 1;
		break;
	}

	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
	       prot->iv_size + prot->salt_size);

	xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);

	sge->offset += prot->prepend_size;
	sge->length -= prot->prepend_size;

	msg_en->sg.curr = start;

	aead_request_set_tfm(aead_req, ctx->aead_send);
	aead_request_set_ad(aead_req, prot->aad_size);
	aead_request_set_crypt(aead_req, rec->sg_aead_in,
			       rec->sg_aead_out,
			       data_len, rec->iv_data);

	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				  tls_encrypt_done, sk);

	/* Add the record in tx_list */
	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
	atomic_inc(&ctx->encrypt_pending);

	rc = crypto_aead_encrypt(aead_req);
	if (!rc || rc != -EINPROGRESS) {
		atomic_dec(&ctx->encrypt_pending);
		sge->offset -= prot->prepend_size;
		sge->length += prot->prepend_size;
	}

	if (!rc) {
		WRITE_ONCE(rec->tx_ready, true);
	} else if (rc != -EINPROGRESS) {
		list_del(&rec->list);
		return rc;
	}

	/* Unhook the record from context if encryption did not fail */
	ctx->open_rec = NULL;
	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
	return rc;
}

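/* Split the open record at split_point, e.g. when a BPF verdict covers
 * only part of it: the first split_point bytes stay in 'from' and the
 * remainder is moved into a newly allocated record returned via 'to'.
 */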
static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
				 struct tls_rec **to, struct sk_msg *msg_opl,
				 struct sk_msg *msg_oen, u32 split_point,
				 u32 tx_overhead_size, u32 *orig_end)
{
	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
	struct scatterlist *sge, *osge, *nsge;
	u32 orig_size = msg_opl->sg.size;
	struct scatterlist tmp = { };
	struct sk_msg *msg_npl;
	struct tls_rec *new;
	int ret;

	new = tls_get_rec(sk);
	if (!new)
		return -ENOMEM;
	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
			   tx_overhead_size, 0);
	if (ret < 0) {
		tls_free_rec(sk, new);
		return ret;
	}

	*orig_end = msg_opl->sg.end;
	i = msg_opl->sg.start;
	sge = sk_msg_elem(msg_opl, i);
	while (apply && sge->length) {
		if (sge->length > apply) {
			u32 len = sge->length - apply;

			get_page(sg_page(sge));
			sg_set_page(&tmp, sg_page(sge), len,
				    sge->offset + apply);
			sge->length = apply;
			bytes += apply;
			apply = 0;
		} else {
			apply -= sge->length;
			bytes += sge->length;
		}

		sk_msg_iter_var_next(i);
		if (i == msg_opl->sg.end)
			break;
		sge = sk_msg_elem(msg_opl, i);
	}

	msg_opl->sg.end = i;
	msg_opl->sg.curr = i;
	msg_opl->sg.copybreak = 0;
	msg_opl->apply_bytes = 0;
	msg_opl->sg.size = bytes;

	msg_npl = &new->msg_plaintext;
	msg_npl->apply_bytes = apply;
	msg_npl->sg.size = orig_size - bytes;

	j = msg_npl->sg.start;
	nsge = sk_msg_elem(msg_npl, j);
	if (tmp.length) {
		memcpy(nsge, &tmp, sizeof(*nsge));
		sk_msg_iter_var_next(j);
		nsge = sk_msg_elem(msg_npl, j);
	}

	osge = sk_msg_elem(msg_opl, i);
	while (osge->length) {
		memcpy(nsge, osge, sizeof(*nsge));
		sg_unmark_end(nsge);
		sk_msg_iter_var_next(i);
		sk_msg_iter_var_next(j);
		if (i == *orig_end)
			break;
		osge = sk_msg_elem(msg_opl, i);
		nsge = sk_msg_elem(msg_npl, j);
	}

	msg_npl->sg.end = j;
	msg_npl->sg.curr = j;
	msg_npl->sg.copybreak = 0;

	*to = new;
	return 0;
}

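/* Reverse tls_split_open_record(): move the plaintext of 'from' back
 * into 'to', coalescing the fragment at the boundary when possible, take
 * over 'from's encrypted buffer and free the temporary record.
 */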
static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
				  struct tls_rec *from, u32 orig_end)
{
	struct sk_msg *msg_npl = &from->msg_plaintext;
	struct sk_msg *msg_opl = &to->msg_plaintext;
	struct scatterlist *osge, *nsge;
	u32 i, j;

	i = msg_opl->sg.end;
	sk_msg_iter_var_prev(i);
	j = msg_npl->sg.start;

	osge = sk_msg_elem(msg_opl, i);
	nsge = sk_msg_elem(msg_npl, j);

	if (sg_page(osge) == sg_page(nsge) &&
	    osge->offset + osge->length == nsge->offset) {
		osge->length += nsge->length;
		put_page(sg_page(nsge));
	}

	msg_opl->sg.end = orig_end;
	msg_opl->sg.curr = orig_end;
	msg_opl->sg.copybreak = 0;
	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
	msg_opl->sg.size += msg_npl->sg.size;

	sk_msg_free(sk, &to->msg_encrypted);
	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);

	kfree(from);
}

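/* Close the currently open record and submit it for encryption: chain
 * the AAD, plaintext and ciphertext scatterlists, append the TLS 1.3
 * inner content type when needed, write the record header, and hand the
 * record to tls_do_encryption(). If the BPF apply_bytes boundary or the
 * size of the encrypted buffer falls inside the record, it is split and
 * the remainder is kept open.
 */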
static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
	u32 i, split_point, orig_end;
	struct sk_msg *msg_pl, *msg_en;
	struct aead_request *req;
	bool split;
	int rc;

	if (!rec)
		return 0;

	msg_pl = &rec->msg_plaintext;
	msg_en = &rec->msg_encrypted;

	split_point = msg_pl->apply_bytes;
	split = split_point && split_point < msg_pl->sg.size;
	if (unlikely((!split &&
		      msg_pl->sg.size +
		      prot->overhead_size > msg_en->sg.size) ||
		     (split &&
		      split_point +
		      prot->overhead_size > msg_en->sg.size))) {
		split = true;
		split_point = msg_en->sg.size;
	}
	if (split) {
		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
					   split_point, prot->overhead_size,
					   &orig_end);
		if (rc < 0)
			return rc;
		/* This can happen if above tls_split_open_record allocates
		 * a single large encryption buffer instead of two smaller
		 * ones. In this case adjust pointers and continue without
		 * split.
		 */
		if (!msg_pl->sg.size) {
			tls_merge_open_record(sk, rec, tmp, orig_end);
			msg_pl = &rec->msg_plaintext;
			msg_en = &rec->msg_encrypted;
			split = false;
		}
		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
			    prot->overhead_size);
	}

	rec->tx_flags = flags;
	req = &rec->aead_req;

	i = msg_pl->sg.end;
	sk_msg_iter_var_prev(i);

	rec->content_type = record_type;
	if (prot->version == TLS_1_3_VERSION) {
		/* Add content type to end of message.  No padding added */
		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
		sg_mark_end(&rec->sg_content_type);
		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
			 &rec->sg_content_type);
	} else {
		sg_mark_end(sk_msg_elem(msg_pl, i));
	}

	if (msg_pl->sg.end < msg_pl->sg.start) {
		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
			 msg_pl->sg.data);
	}

	i = msg_pl->sg.start;
	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);

	i = msg_en->sg.end;
	sk_msg_iter_var_prev(i);
	sg_mark_end(sk_msg_elem(msg_en, i));

	i = msg_en->sg.start;
	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);

	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
		     tls_ctx->tx.rec_seq, record_type, prot);

	tls_fill_prepend(tls_ctx,
			 page_address(sg_page(&msg_en->sg.data[i])) +
			 msg_en->sg.data[i].offset,
			 msg_pl->sg.size + prot->tail_size,
			 record_type);

	tls_ctx->pending_open_record_frags = false;

	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
			       msg_pl->sg.size + prot->tail_size, i);
	if (rc < 0) {
		if (rc != -EINPROGRESS) {
			tls_err_abort(sk, -EBADMSG);
			if (split) {
				tls_ctx->pending_open_record_frags = true;
				tls_merge_open_record(sk, rec, tmp, orig_end);
			}
		}
		ctx->async_capable = 1;
		return rc;
	} else if (split) {
		msg_pl = &tmp->msg_plaintext;
		msg_en = &tmp->msg_encrypted;
		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
		tls_ctx->pending_open_record_frags = true;
		ctx->open_rec = tmp;
	}

	return tls_tx_records(sk, flags);
}

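/* Apply the psock TX verdict to the open record: __SK_PASS encrypts and
 * transmits it on this socket, __SK_REDIRECT hands the plaintext to
 * another socket, and __SK_DROP frees it. Without an attached psock or
 * policy the record is simply pushed out.
 */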
static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
			       bool full_record, u8 record_type,
			       ssize_t *copied, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct sk_msg msg_redir = { };
	struct sk_psock *psock;
	struct sock *sk_redir;
	struct tls_rec *rec;
	bool enospc, policy;
	int err = 0, send;
	u32 delta = 0;

	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
	psock = sk_psock_get(sk);
	if (!psock || !policy) {
		err = tls_push_record(sk, flags, record_type);
		if (err && sk->sk_err == EBADMSG) {
			*copied -= sk_msg_free(sk, msg);
			tls_free_open_rec(sk);
			err = -sk->sk_err;
		}
		if (psock)
			sk_psock_put(sk, psock);
		return err;
	}
more_data:
	enospc = sk_msg_full(msg);
	if (psock->eval == __SK_NONE) {
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}
	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
	    !enospc && !full_record) {
		err = -ENOSPC;
		goto out_err;
	}
	msg->cork_bytes = 0;
	send = msg->sg.size;
	if (msg->apply_bytes && msg->apply_bytes < send)
		send = msg->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		err = tls_push_record(sk, flags, record_type);
		if (err && sk->sk_err == EBADMSG) {
			*copied -= sk_msg_free(sk, msg);
			tls_free_open_rec(sk);
			err = -sk->sk_err;
			goto out_err;
		}
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		memcpy(&msg_redir, msg, sizeof(*msg));
		if (msg->apply_bytes < send)
			msg->apply_bytes = 0;
		else
			msg->apply_bytes -= send;
		sk_msg_return_zero(sk, msg, send);
		msg->sg.size -= send;
		release_sock(sk);
		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
		lock_sock(sk);
		if (err < 0) {
			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
			msg->sg.size = 0;
		}
		if (msg->sg.size == 0)
			tls_free_open_rec(sk);
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, send);
		if (msg->apply_bytes < send)
			msg->apply_bytes = 0;
		else
			msg->apply_bytes -= send;
		if (msg->sg.size == 0)
			tls_free_open_rec(sk);
		*copied -= (send + delta);
		err = -EACCES;
	}

	if (likely(!err)) {
		bool reset_eval = !ctx->open_rec;

		rec = ctx->open_rec;
		if (rec) {
			msg = &rec->msg_plaintext;
			if (!msg->apply_bytes)
				reset_eval = true;
		}
		if (reset_eval) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (rec)
			goto more_data;
	}
 out_err:
	sk_psock_put(sk, psock);
	return err;
}

static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_pl;
	size_t copied;

	if (!rec)
		return 0;

	msg_pl = &rec->msg_plaintext;
	copied = msg_pl->sg.size;
	if (!copied)
		return 0;

	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
				   &copied, flags);
}

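/* sendmsg() for a TLS_SW transmit socket. User data is copied (or, when
 * a record can be closed immediately, mapped zero-copy) into the open
 * record; each time a record fills up or the message ends it is run
 * through the BPF verdict, encrypted and transmitted. Zero-copy sends
 * must wait for pending asynchronous encryptions before returning.
 */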
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	bool async_capable = ctx->async_capable;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t try_to_copy;
	ssize_t copied = 0;
	struct sk_msg *msg_pl, *msg_en;
	struct tls_rec *rec;
	int required_size;
	int num_async = 0;
	bool full_record;
	int record_room;
	int num_zc = 0;
	int orig_size;
	int ret = 0;
	int pending;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
			       MSG_CMSG_COMPAT))
		return -EOPNOTSUPP;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (unlikely(msg->msg_controllen)) {
		ret = tls_proccess_cmsg(sk, msg, &record_type);
		if (ret) {
			if (ret == -EINPROGRESS)
				num_async++;
			else if (ret != -EAGAIN)
				goto send_end;
		}
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		if (ctx->open_rec)
			rec = ctx->open_rec;
		else
			rec = ctx->open_rec = tls_get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto send_end;
		}

		msg_pl = &rec->msg_plaintext;
		msg_en = &rec->msg_encrypted;

		orig_size = msg_pl->sg.size;
		full_record = false;
		try_to_copy = msg_data_left(msg);
		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
		if (try_to_copy >= record_room) {
			try_to_copy = record_room;
			full_record = true;
		}

		required_size = msg_pl->sg.size + try_to_copy +
				prot->overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_encrypted:
		ret = tls_alloc_encrypted_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - msg_en->sg.size;
			full_record = true;
		}

		if (!is_kvec && (full_record || eor) && !async_capable) {
			u32 first = msg_pl->sg.end;

			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
							msg_pl, try_to_copy);
			if (ret)
				goto fallback_to_reg_send;

			num_zc++;
			copied += try_to_copy;

			sk_msg_sg_copy_set(msg_pl, first);
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied,
						  msg->msg_flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ctx->open_rec && ret == -ENOSPC)
					goto rollback_iter;
				else if (ret != -EAGAIN)
					goto send_end;
			}
			continue;
rollback_iter:
			copied -= try_to_copy;
			sk_msg_sg_copy_clear(msg_pl, first);
			iov_iter_revert(&msg->msg_iter,
					msg_pl->sg.size - orig_size);
fallback_to_reg_send:
			sk_msg_trim(sk, msg_pl, orig_size);
		}

		required_size = msg_pl->sg.size + try_to_copy;

		ret = tls_clone_plaintext_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto send_end;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - msg_pl->sg.size;
			full_record = true;
			sk_msg_trim(sk, msg_en,
				    msg_pl->sg.size + prot->overhead_size);
		}

		if (try_to_copy) {
			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
						       msg_pl, try_to_copy);
			if (ret < 0)
				goto trim_sgl;
		}

		/* Open records defined only if successfully copied, otherwise
		 * we would trim the sg but not reset the open record frags.
		 */
		tls_ctx->pending_open_record_frags = true;
		copied += try_to_copy;
		if (full_record || eor) {
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied,
						  msg->msg_flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ret != -EAGAIN) {
					if (ret == -ENOSPC)
						ret = 0;
					goto send_end;
				}
			}
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			if (ctx->open_rec)
				tls_trim_both_msgs(sk, orig_size);
			goto send_end;
		}

		if (ctx->open_rec && msg_en->sg.size < required_size)
			goto alloc_encrypted;
	}

	if (!num_async) {
		goto send_end;
	} else if (num_zc) {
		/* Wait for pending encryptions to get completed */
		spin_lock_bh(&ctx->encrypt_compl_lock);
		ctx->async_notify = true;

		pending = atomic_read(&ctx->encrypt_pending);
		spin_unlock_bh(&ctx->encrypt_compl_lock);
		if (pending)
			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		else
			reinit_completion(&ctx->async_wait.completion);

		/* There can be no concurrent accesses, since we have no
		 * pending encrypt operations
		 */
		WRITE_ONCE(ctx->async_notify, false);

		if (ctx->async_wait.err) {
			ret = ctx->async_wait.err;
			copied = 0;
		}
	}

	/* Transmit if any encryptions have completed */
	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
		cancel_delayed_work(&ctx->tx_work.work);
		tls_tx_records(sk, msg->msg_flags);
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return copied > 0 ? copied : ret;
}

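/* Common implementation of sendpage(): append the caller's page to the
 * open record's plaintext scatterlist without copying, and push the
 * record out whenever it is full, the last page has arrived or the
 * scatterlist has no more room.
 */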
static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
			      int offset, size_t size, int flags)
{
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	struct sk_msg *msg_pl;
	struct tls_rec *rec;
	int num_async = 0;
	ssize_t copied = 0;
	bool full_record;
	int record_room;
	int ret = 0;
	bool eor;

	eor = !(flags & MSG_SENDPAGE_NOTLAST);
	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	/* Call the sk_stream functions to manage the sndbuf mem. */
	while (size > 0) {
		size_t copy, required_size;

		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto sendpage_end;
		}

		if (ctx->open_rec)
			rec = ctx->open_rec;
		else
			rec = ctx->open_rec = tls_get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto sendpage_end;
		}

		msg_pl = &rec->msg_plaintext;

		full_record = false;
		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
		copy = size;
		if (copy >= record_room) {
			copy = record_room;
			full_record = true;
		}

		required_size = msg_pl->sg.size + copy + prot->overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_payload:
		ret = tls_alloc_encrypted_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			copy -= required_size - msg_pl->sg.size;
			full_record = true;
		}

		sk_msg_page_add(msg_pl, page, copy, offset);
		sk_mem_charge(sk, copy);

		offset += copy;
		size -= copy;
		copied += copy;

		tls_ctx->pending_open_record_frags = true;
		if (full_record || eor || sk_msg_full(msg_pl)) {
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied, flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ret != -EAGAIN) {
					if (ret == -ENOSPC)
						ret = 0;
					goto sendpage_end;
				}
			}
		}
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
			if (ctx->open_rec)
				tls_trim_both_msgs(sk, msg_pl->sg.size);
			goto sendpage_end;
		}

		if (ctx->open_rec)
			goto alloc_payload;
	}

	if (num_async) {
		/* Transmit if any encryptions have completed */
		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
			cancel_delayed_work(&ctx->tx_work.work);
			tls_tx_records(sk, flags);
		}
	}
sendpage_end:
	ret = sk_stream_error(sk, flags, ret);
	return copied > 0 ? copied : ret;
}

int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
			   int offset, size_t size, int flags)
{
	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
		      MSG_NO_SHARED_FRAGS))
		return -EOPNOTSUPP;

	return tls_sw_do_sendpage(sk, page, offset, size, flags);
}

int tls_sw_sendpage(struct sock *sk, struct page *page,
		    int offset, size_t size, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	int ret;

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
		return -EOPNOTSUPP;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);
	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return ret;
}

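/* Wait until the strparser has produced a complete record (ctx->recv_pkt)
 * or the psock ingress queue has data. Returns the record's skb, or NULL
 * on error, shutdown, timeout or signal, with *err set where applicable.
 */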
static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
				     bool nonblock, long timeo, int *err)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct sk_buff *skb;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
		if (sk->sk_err) {
			*err = sock_error(sk);
			return NULL;
		}

		if (!skb_queue_empty(&sk->sk_receive_queue)) {
			__strp_unpause(&ctx->strp);
			if (ctx->recv_pkt)
				return ctx->recv_pkt;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			return NULL;

		if (sock_flag(sk, SOCK_DONE))
			return NULL;

		if (nonblock || !timeo) {
			*err = -EAGAIN;
			return NULL;
		}

		add_wait_queue(sk_sleep(sk), &wait);
		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		sk_wait_event(sk, &timeo,
			      ctx->recv_pkt != skb ||
			      !sk_psock_queue_empty(psock),
			      &wait);
		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		remove_wait_queue(sk_sleep(sk), &wait);

		/* Handle signals */
		if (signal_pending(current)) {
			*err = sock_intr_errno(timeo);
			return NULL;
		}
	}

	return skb;
}

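/* Pin the user pages backing 'from' and describe them in the 'to'
 * scatterlist (at most to_max_pages entries), so that decryption can
 * write straight into the user buffer. On failure the iterator is
 * rewound so the caller can fall back to a copying path.
 */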
static int tls_setup_from_iter(struct iov_iter *from,
			       int length, int *pages_used,
			       struct scatterlist *to,
			       int to_max_pages)
{
	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
	struct page *pages[MAX_SKB_FRAGS];
	unsigned int size = 0;
	ssize_t copied, use;
	size_t offset;

	while (length > 0) {
		i = 0;
		maxpages = to_max_pages - num_elem;
		if (maxpages == 0) {
			rc = -EFAULT;
			goto out;
		}
		copied = iov_iter_get_pages(from, pages,
					    length,
					    maxpages, &offset);
		if (copied <= 0) {
			rc = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);

		length -= copied;
		size += copied;
		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);

			sg_set_page(&to[num_elem],
				    pages[i], use, offset);
			sg_unmark_end(&to[num_elem]);
			/* We do not uncharge memory from this API */

			offset = 0;
			copied -= use;

			i++;
			num_elem++;
		}
	}
	/* Mark the end in the last sg entry if newly added */
	if (num_elem > *pages_used)
		sg_mark_end(&to[num_elem - 1]);
out:
	if (rc)
		iov_iter_revert(from, size);
	*pages_used = num_elem;

	return rc;
}

/* This function decrypts the input skb into either out_iov, out_sg or
 * the skb buffers themselves. The input parameter 'darg->zc' indicates
 * whether zero-copy mode should be tried. With zero-copy mode, either
 * out_iov or out_sg must be non-NULL. If both out_iov and out_sg are
 * NULL, the decryption happens inside the skb buffers themselves, i.e.
 * zero-copy gets disabled and 'darg->zc' is updated.
 */

static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			    struct iov_iter *out_iov,
			    struct scatterlist *out_sg,
			    struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
	struct aead_request *aead_req;
	struct sk_buff *unused;
	u8 *aad, *iv, *mem = NULL;
	struct scatterlist *sgin = NULL;
	struct scatterlist *sgout = NULL;
	const int data_len = rxm->full_len - prot->overhead_size +
			     prot->tail_size;
	int iv_offset = 0;

	if (darg->zc && (out_iov || out_sg)) {
		if (out_iov)
			n_sgout = 1 +
				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
		else
			n_sgout = sg_nents(out_sg);
		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
				 rxm->full_len - prot->prepend_size);
	} else {
		n_sgout = 0;
		darg->zc = false;
		n_sgin = skb_cow_data(skb, 0, &unused);
	}

	if (n_sgin < 1)
		return -EBADMSG;

	/* Increment to accommodate AAD */
	n_sgin = n_sgin + 1;

	nsg = n_sgin + n_sgout;

	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
	mem_size = mem_size + prot->aad_size;
	mem_size = mem_size + MAX_IV_SIZE;

	/* Allocate a single block of memory which contains
	 * aead_req || sgin[] || sgout[] || aad || iv.
	 * This order achieves correct alignment for aead_req, sgin, sgout.
	 */
	mem = kmalloc(mem_size, sk->sk_allocation);
	if (!mem)
		return -ENOMEM;

	/* Segment the allocated memory */
	aead_req = (struct aead_request *)mem;
	sgin = (struct scatterlist *)(mem + aead_size);
	sgout = sgin + n_sgin;
	aad = (u8 *)(sgout + n_sgout);
	iv = aad + prot->aad_size;

	/* For CCM based ciphers, first byte of nonce+iv is a constant */
	switch (prot->cipher_type) {
	case TLS_CIPHER_AES_CCM_128:
		iv[0] = TLS_AES_CCM_IV_B0_BYTE;
		iv_offset = 1;
		break;
	case TLS_CIPHER_SM4_CCM:
		iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
		iv_offset = 1;
		break;
	}

	/* Prepare IV */
	if (prot->version == TLS_1_3_VERSION ||
	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
		memcpy(iv + iv_offset, tls_ctx->rx.iv,
		       prot->iv_size + prot->salt_size);
	} else {
		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
				    iv + iv_offset + prot->salt_size,
				    prot->iv_size);
		if (err < 0) {
			kfree(mem);
			return err;
		}
		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
	}
	xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);

	/* Prepare AAD */
	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
		     prot->tail_size,
		     tls_ctx->rx.rec_seq, tlm->control, prot);

	/* Prepare sgin */
	sg_init_table(sgin, n_sgin);
	sg_set_buf(&sgin[0], aad, prot->aad_size);
	err = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + prot->prepend_size,
			   rxm->full_len - prot->prepend_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}

	if (n_sgout) {
		if (out_iov) {
			sg_init_table(sgout, n_sgout);
			sg_set_buf(&sgout[0], aad, prot->aad_size);

			err = tls_setup_from_iter(out_iov, data_len,
						  &pages, &sgout[1],
						  (n_sgout - 1));
			if (err < 0)
				goto fallback_to_reg_recv;
		} else if (out_sg) {
			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
		} else {
			goto fallback_to_reg_recv;
		}
	} else {
fallback_to_reg_recv:
		sgout = sgin;
		pages = 0;
		darg->zc = false;
	}

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
				data_len, aead_req, darg);
	if (darg->async)
		return 0;

	/* Release the pages in case iov was mapped to pages */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));

	kfree(mem);
	return err;
}

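/* Decrypt a record at most once: records already decrypted (e.g. by
 * device offload) are skipped, otherwise decrypt_internal() is invoked
 * and, on synchronous completion, the TLS 1.3 padding is trimmed and the
 * header stripped so that rxm describes only the cleartext payload. The
 * RX record sequence number is advanced in the async case as well.
 */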
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			      struct iov_iter *dest,
			      struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	int pad, err;

	if (tlm->decrypted) {
		darg->zc = false;
		return 0;
	}

	if (tls_ctx->rx_conf == TLS_HW) {
		err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
		if (err < 0)
			return err;
		if (err > 0) {
			tlm->decrypted = 1;
			darg->zc = false;
			goto decrypt_done;
		}
	}

	err = decrypt_internal(sk, skb, dest, NULL, darg);
	if (err < 0)
		return err;
	if (darg->async)
		goto decrypt_next;

decrypt_done:
	pad = padding_length(prot, skb);
	if (pad < 0)
		return pad;

	rxm->full_len -= pad;
	rxm->offset += prot->prepend_size;
	rxm->full_len -= prot->overhead_size;
	tlm->decrypted = 1;
decrypt_next:
	tls_advance_record_sn(sk, prot, &tls_ctx->rx);

	return 0;
}

int decrypt_skb(struct sock *sk, struct sk_buff *skb,
		struct scatterlist *sgout)
{
	struct tls_decrypt_arg darg = { .zc = true, };

	return decrypt_internal(sk, skb, NULL, sgout, &darg);
}

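/* Check that the record's content type matches the type of the message
 * being assembled, emitting the TLS_GET_RECORD_TYPE cmsg for the first
 * record. Returns 1 to continue processing, 0 if a record of a different
 * type ends the message, or a negative error.
 */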
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
				   u8 *control)
{
	int err;

	if (!*control) {
		*control = tlm->control;
		if (!*control)
			return -EBADMSG;

		err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
			       sizeof(*control), control);
		if (*control != TLS_RECORD_TYPE_DATA) {
			if (err || msg->msg_flags & MSG_CTRUNC)
				return -EIO;
		}
	} else if (*control != tlm->control) {
		return 0;
	}

	return 1;
}

/* This function traverses the rx_list in the TLS receive context and
 * copies the decrypted records into the buffer provided by the caller
 * when zero copy is not true. Further, records are removed from the
 * rx_list if this is not a peek case and the record has been consumed
 * completely.
 */
static int process_rx_list(struct tls_sw_context_rx *ctx,
			   struct msghdr *msg,
			   u8 *control,
			   size_t skip,
			   size_t len,
			   bool zc,
			   bool is_peek)
{
	struct sk_buff *skb = skb_peek(&ctx->rx_list);
	struct tls_msg *tlm;
	ssize_t copied = 0;
	int err;

	while (skip && skb) {
		struct strp_msg *rxm = strp_msg(skb);
		tlm = tls_msg(skb);

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			goto out;

		if (skip < rxm->full_len)
			break;

		skip = skip - rxm->full_len;
		skb = skb_peek_next(skb, &ctx->rx_list);
	}

	while (len && skb) {
		struct sk_buff *next_skb;
		struct strp_msg *rxm = strp_msg(skb);
		int chunk = min_t(unsigned int, rxm->full_len - skip, len);

		tlm = tls_msg(skb);

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			goto out;

		if (!zc || (rxm->full_len - skip) > len) {
			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
						    msg, chunk);
			if (err < 0)
				goto out;
		}

		len = len - chunk;
		copied = copied + chunk;

		/* Consume the data from the record if this is not a peek case */
		if (!is_peek) {
			rxm->offset = rxm->offset + chunk;
			rxm->full_len = rxm->full_len - chunk;

			/* Return if there is unconsumed data in the record */
			if (rxm->full_len - skip)
				break;
		}

		/* The remaining skip-bytes must lie in 1st record in rx_list.
		 * So from the 2nd record, 'skip' should be 0.
		 */
		skip = 0;

		if (msg)
			msg->msg_flags |= MSG_EOR;

		next_skb = skb_peek_next(skb, &ctx->rx_list);

		if (!is_peek) {
1706
			__skb_unlink(skb, &ctx->rx_list);
1707
			consume_skb(skb);
1708 1709 1710 1711
		}

		skb = next_skb;
	}
	err = 0;

out:
	return copied ? : err;
}

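/* recvmsg() for a TLS_SW receive socket. Already-decrypted records on
 * rx_list are drained first; then records are pulled from the strparser
 * and decrypted either in place or, when possible, straight into the
 * user buffer (zero copy). Asynchronous decryptions are waited for at
 * the end and their results copied out via process_rx_list().
 */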
int tls_sw_recvmsg(struct sock *sk,
		   struct msghdr *msg,
		   size_t len,
		   int flags,
		   int *addr_len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1725
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1726
	struct tls_prot_info *prot = &tls_ctx->prot_info;
1727
	struct sk_psock *psock;
1728 1729
	unsigned char control = 0;
	ssize_t decrypted = 0;
D
Dave Watson 已提交
1730
	struct strp_msg *rxm;
1731
	struct tls_msg *tlm;
D
Dave Watson 已提交
1732 1733
	struct sk_buff *skb;
	ssize_t copied = 0;
1734
	bool async = false;
1735
	int target, err = 0;
D
Dave Watson 已提交
1736
	long timeo;
D
David Howells 已提交
1737
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1738
	bool is_peek = flags & MSG_PEEK;
1739
	bool bpf_strp_enabled;
1740
	bool zc_capable;
D
Dave Watson 已提交
1741 1742 1743 1744

	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

1745
	psock = sk_psock_get(sk);
D
Dave Watson 已提交
1746
	lock_sock(sk);
1747
	bpf_strp_enabled = sk_psock_strp_enabled(psock);
D
Dave Watson 已提交
1748

1749 1750 1751 1752 1753
	/* If crypto failed the connection is broken */
	err = ctx->async_wait.err;
	if (err)
		goto end;

1754
	/* Process pending decrypted records. It must be non-zero-copy */
1755
	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
1756
	if (err < 0)
1757 1758
		goto end;

1759
	copied = err;
1760
	if (len <= copied)
1761
		goto end;
1762 1763 1764 1765

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	len = len - copied;
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1766

1767 1768
	zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
		     prot->version != TLS_1_3_VERSION;
1769
	decrypted = 0;
1770
	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
1771
		struct tls_decrypt_arg darg = {};
1772
		int to_decrypt, chunk;
D
Dave Watson 已提交
1773

1774
		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
1775 1776
		if (!skb) {
			if (psock) {
1777 1778 1779 1780
				chunk = sk_msg_recvmsg(sk, psock, msg, len,
						       flags);
				if (chunk > 0)
					goto leave_on_list;
1781
			}
D
Dave Watson 已提交
1782
			goto recv_end;
1783
		}
D
Dave Watson 已提交
1784 1785

		rxm = strp_msg(skb);
1786
		tlm = tls_msg(skb);
1787

1788
		to_decrypt = rxm->full_len - prot->overhead_size;
1789

1790 1791
		if (zc_capable && to_decrypt <= len &&
		    tlm->control == TLS_RECORD_TYPE_DATA)
1792
			darg.zc = true;
1793

1794
		/* Do not use async mode if record is non-data */
1795
		if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1796
			darg.async = ctx->async_capable;
1797
		else
1798
			darg.async = false;
1799

1800
		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto recv_end;
		}

		async |= darg.async;

		/* If the type of records being processed is not known yet,
		 * set it to record type just dequeued. If it is already known,
		 * but does not match the record type just dequeued, go to end.
		 * We always get record type here since for tls1.2, record type
		 * is known just after record is dequeued from stream parser.
		 * For tls1.3, we disable async.
		 */
		err = tls_record_content_type(msg, tlm, &control);
		if (err <= 0)
			goto recv_end;

		ctx->recv_pkt = NULL;
		__strp_unpause(&ctx->strp);
		__skb_queue_tail(&ctx->rx_list, skb);

		if (async) {
			/* TLS 1.2-only, to_decrypt must be text length */
			chunk = min_t(int, to_decrypt, len);
leave_on_list:
			decrypted += chunk;
			len -= chunk;
			continue;
		}
		/* TLS 1.3 may have updated the length by more than overhead */
		chunk = rxm->full_len;

		if (!darg.zc) {
			bool partially_consumed = chunk > len;

			if (bpf_strp_enabled) {
				err = sk_psock_tls_strp_read(psock, skb);
				if (err != __SK_PASS) {
					rxm->offset = rxm->offset + rxm->full_len;
					rxm->full_len = 0;
					__skb_unlink(skb, &ctx->rx_list);
					if (err == __SK_DROP)
						consume_skb(skb);
					continue;
				}
			}

			if (partially_consumed)
				chunk = len;

			err = skb_copy_datagram_msg(skb, rxm->offset,
						    msg, chunk);
			if (err < 0)
				goto recv_end;
			if (is_peek)
				goto leave_on_list;

			if (partially_consumed) {
				rxm->offset += chunk;
				rxm->full_len -= chunk;
				goto leave_on_list;
			}
		}

		decrypted += chunk;
		len -= chunk;

		__skb_unlink(skb, &ctx->rx_list);
		consume_skb(skb);

		/* Return full control message to userspace before trying
		 * to parse another message type
		 */
		msg->msg_flags |= MSG_EOR;
		if (control != TLS_RECORD_TYPE_DATA)
			break;
	}

recv_end:
	if (async) {
		int ret, pending;

		/* Wait for all previously submitted records to be decrypted */
		spin_lock_bh(&ctx->decrypt_compl_lock);
		reinit_completion(&ctx->async_wait.completion);
		pending = atomic_read(&ctx->decrypt_pending);
		spin_unlock_bh(&ctx->decrypt_compl_lock);
		if (pending) {
			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			if (ret) {
				if (err >= 0 || err == -EINPROGRESS)
					err = ret;
				decrypted = 0;
				goto end;
			}
		}

		/* Drain records from the rx_list & copy if required */
		if (is_peek || is_kvec)
			err = process_rx_list(ctx, msg, &control, copied,
					      decrypted, false, is_peek);
		else
			err = process_rx_list(ctx, msg, &control, 0,
					      decrypted, true, is_peek);
		decrypted = max(err, 0);
	}

	copied += decrypted;

end:
	release_sock(sk);
	if (psock)
		sk_psock_put(sk, psock);
	return copied ? : err;
}

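/* splice() handler for software TLS RX.  Decrypts at most one record
 * per call and feeds its plaintext into the pipe; only data records
 * can be spliced.  A partially consumed record is requeued at the head
 * of rx_list for the next reader.
 */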
ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
			   struct pipe_inode_info *pipe,
			   size_t len, unsigned int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = NULL;
	struct sock *sk = sock->sk;
	struct tls_msg *tlm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool from_queue;
	int err = 0;
	long timeo;
	int chunk;

	lock_sock(sk);

	timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);

	from_queue = !skb_queue_empty(&ctx->rx_list);
	if (from_queue) {
		skb = __skb_dequeue(&ctx->rx_list);
	} else {
		struct tls_decrypt_arg darg = {};

		skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
				    &err);
		if (!skb)
			goto splice_read_end;

		err = decrypt_skb_update(sk, skb, NULL, &darg);
		if (err < 0) {
			tls_err_abort(sk, -EBADMSG);
			goto splice_read_end;
		}
	}

	rxm = strp_msg(skb);
	tlm = tls_msg(skb);

	/* splice does not support reading control messages */
	if (tlm->control != TLS_RECORD_TYPE_DATA) {
		err = -EINVAL;
		goto splice_read_end;
	}

	chunk = min_t(unsigned int, rxm->full_len, len);
	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
	if (copied < 0)
		goto splice_read_end;

	if (!from_queue) {
		ctx->recv_pkt = NULL;
		__strp_unpause(&ctx->strp);
	}
	if (chunk < rxm->full_len) {
		__skb_queue_head(&ctx->rx_list, skb);
		rxm->offset += len;
		rxm->full_len -= len;
	} else {
		consume_skb(skb);
	}

splice_read_end:
	release_sock(sk);
	return copied ? : err;
}

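/* Readability test for poll: data is available if the psock ingress
 * queue, the parsed-but-unread record (recv_pkt) or the decrypted
 * rx_list is non-empty.
 */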
bool tls_sw_sock_is_readable(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	bool ingress_empty = true;
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock)
		ingress_empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !ingress_empty || ctx->recv_pkt ||
		!skb_queue_empty(&ctx->rx_list);
}

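/* strparser parse_msg callback: validate the TLS record header at the
 * current offset and return the total record length, 0 if the header
 * is not complete yet, or a negative error (which also aborts the
 * socket).
 */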
static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	size_t cipher_overhead;
	size_t data_len = 0;
	int ret;

	/* Verify that we have a full TLS header, or wait for more data */
	if (rxm->offset + prot->prepend_size > skb->len)
		return 0;

	/* Sanity-check size of on-stack buffer. */
	if (WARN_ON(prot->prepend_size > sizeof(header))) {
		ret = -EINVAL;
		goto read_failure;
	}

	/* Linearize header to local buffer */
	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
	if (ret < 0)
		goto read_failure;

	tlm->decrypted = 0;
	tlm->control = header[0];

	data_len = ((header[4] & 0xFF) | (header[3] << 8));

	cipher_overhead = prot->tag_size;
	if (prot->version != TLS_1_3_VERSION &&
	    prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
		cipher_overhead += prot->iv_size;

	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
	    prot->tail_size) {
		ret = -EMSGSIZE;
		goto read_failure;
	}
	if (data_len < cipher_overhead) {
		ret = -EBADMSG;
		goto read_failure;
	}

	/* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
	if (header[1] != TLS_1_2_VERSION_MINOR ||
	    header[2] != TLS_1_2_VERSION_MAJOR) {
		ret = -EINVAL;
		goto read_failure;
	}

	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
				     TCP_SKB_CB(skb)->seq + rxm->offset);
	return data_len + TLS_HEADER_SIZE;

read_failure:
	tls_err_abort(strp->sk, ret);

	return ret;
}

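/* strparser rcv_msg callback: hold the completed record in recv_pkt and
 * pause the parser; recvmsg()/splice_read() take the skb and unpause.
 */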
static void tls_queue(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	ctx->recv_pkt = skb;
	strp_pause(strp);

	ctx->saved_data_ready(strp->sk);
}

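/* Replacement sk_data_ready callback, installed by
 * tls_sw_strparser_arm(): feed the stream parser, and wake readers
 * directly when a psock already has queued ingress messages.
 */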
static void tls_data_ready(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct sk_psock *psock;

	strp_data_ready(&ctx->strp);

	psock = sk_psock_get(sk);
	if (psock) {
		if (!list_empty(&psock->ingress_msg))
			ctx->saved_data_ready(sk);
		sk_psock_put(sk, psock);
	}
}

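/* Mark the TX path as closing so tx_work_handler() becomes a no-op,
 * then cancel any already-scheduled transmission work.
 */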
void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
{
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
	cancel_delayed_work_sync(&ctx->tx_work.work);
}

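/* TX teardown: wait for in-flight async encryptions, transmit what is
 * already encrypted, then free partially sent and un-sent records and
 * the send AEAD.
 */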
void tls_sw_release_resources_tx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	int pending;

	/* Wait for any pending async encryptions to complete */
	spin_lock_bh(&ctx->encrypt_compl_lock);
	ctx->async_notify = true;
	pending = atomic_read(&ctx->encrypt_pending);
	spin_unlock_bh(&ctx->encrypt_compl_lock);

	if (pending)
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);

	tls_tx_records(sk, -1);

	/* Free up un-sent records in tx_list. First, free
	 * the partially sent record if any at head of tx_list.
	 */
	if (tls_ctx->partially_sent_record) {
		tls_free_partial_record(sk, tls_ctx);
		rec = list_first_entry(&ctx->tx_list,
				       struct tls_rec, list);
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_encrypted);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	crypto_free_aead(ctx->aead_send);
	tls_free_open_rec(sk);
}

void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
{
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

	kfree(ctx);
}

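/* RX teardown: free key material, drop the pending record and rx_list,
 * stop the stream parser and restore the original sk_data_ready
 * callback if it was replaced.
 */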
void tls_sw_release_resources_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	kfree(tls_ctx->rx.rec_seq);
	kfree(tls_ctx->rx.iv);

	if (ctx->aead_recv) {
		kfree_skb(ctx->recv_pkt);
		ctx->recv_pkt = NULL;
		__skb_queue_purge(&ctx->rx_list);
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		/* If tls_sw_strparser_arm() was not called (cleanup paths)
		 * we still want to strp_stop(), but sk->sk_data_ready was
		 * never swapped.
		 */
		if (ctx->saved_data_ready) {
			write_lock_bh(&sk->sk_callback_lock);
			sk->sk_data_ready = ctx->saved_data_ready;
			write_unlock_bh(&sk->sk_callback_lock);
		}
	}
}

void tls_sw_strparser_done(struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	strp_done(&ctx->strp);
}

void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	kfree(ctx);
}

void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);

	tls_sw_release_resources_rx(sk);
	tls_sw_free_ctx_rx(tls_ctx);
}

/* The work handler to transmit the encrypted records in tx_list */
static void tx_work_handler(struct work_struct *work)
{
	struct delayed_work *delayed_work = to_delayed_work(work);
	struct tx_work *tx_work = container_of(delayed_work,
					       struct tx_work, work);
	struct sock *sk = tx_work->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx;

	if (unlikely(!tls_ctx))
		return;

	ctx = tls_sw_ctx_tx(tls_ctx);
	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
		return;

	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		return;
	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);
	tls_tx_records(sk, -1);
	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
}

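/* write_space hook: if encrypted records are waiting in tx_list,
 * schedule tx_work_handler() to push them out without blocking the
 * caller.
 */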
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
{
	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);

	/* Schedule the transmission if tx list is ready */
	if (is_tx_ready(tx_ctx) &&
	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}

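/* Hook the RX stream parser into the socket: save the original
 * sk_data_ready callback, install tls_data_ready(), and poke the
 * parser in case data has already arrived.
 */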
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);

	write_lock_bh(&sk->sk_callback_lock);
	rx_ctx->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = tls_data_ready;
	write_unlock_bh(&sk->sk_callback_lock);

	strp_check_rcv(&rx_ctx->strp);
}

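/* Set up software TLS for one direction of a socket (tx != 0 for the
 * transmit path).  Allocates the per-direction context, copies the
 * salt/IV and record sequence out of the userspace-provided
 * crypto_info, allocates and keys the AEAD, and for RX also
 * initializes the stream parser.
 */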
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_crypto_info *crypto_info;
	struct tls_sw_context_tx *sw_ctx_tx = NULL;
	struct tls_sw_context_rx *sw_ctx_rx = NULL;
	struct cipher_context *cctx;
	struct crypto_aead **aead;
	struct strp_callbacks cb;
	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
	struct crypto_tfm *tfm;
	char *iv, *rec_seq, *key, *salt, *cipher_name;
	size_t keysize;
	int rc = 0;

	if (!ctx) {
		rc = -EINVAL;
		goto out;
	}

	if (tx) {
		if (!ctx->priv_ctx_tx) {
			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
			if (!sw_ctx_tx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_tx = sw_ctx_tx;
		} else {
			sw_ctx_tx =
				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
		}
	} else {
		if (!ctx->priv_ctx_rx) {
			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
			if (!sw_ctx_rx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_rx = sw_ctx_rx;
		} else {
			sw_ctx_rx =
				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
		}
	}

	if (tx) {
		crypto_init_wait(&sw_ctx_tx->async_wait);
		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
		aead = &sw_ctx_tx->aead_send;
		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
		sw_ctx_tx->tx_work.sk = sk;
	} else {
		crypto_init_wait(&sw_ctx_rx->async_wait);
		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
		skb_queue_head_init(&sw_ctx_rx->rx_list);
		aead = &sw_ctx_rx->aead_recv;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;

		gcm_128_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		iv = gcm_128_info->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
		rec_seq = gcm_128_info->rec_seq;
		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
		key = gcm_128_info->key;
		salt = gcm_128_info->salt;
		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
		cipher_name = "gcm(aes)";
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;

		gcm_256_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
		iv = gcm_256_info->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
		rec_seq = gcm_256_info->rec_seq;
		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
		key = gcm_256_info->key;
		salt = gcm_256_info->salt;
		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
		cipher_name = "gcm(aes)";
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;

		ccm_128_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
		iv = ccm_128_info->iv;
		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
		rec_seq = ccm_128_info->rec_seq;
		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
		key = ccm_128_info->key;
		salt = ccm_128_info->salt;
		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
		cipher_name = "ccm(aes)";
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;

		chacha20_poly1305_info = (void *)crypto_info;
		nonce_size = 0;
		tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
		iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
		iv = chacha20_poly1305_info->iv;
		rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
		rec_seq = chacha20_poly1305_info->rec_seq;
		keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
		key = chacha20_poly1305_info->key;
		salt = chacha20_poly1305_info->salt;
		salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
		cipher_name = "rfc7539(chacha20,poly1305)";
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;

		sm4_gcm_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
		tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
		iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
		iv = sm4_gcm_info->iv;
		rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
		rec_seq = sm4_gcm_info->rec_seq;
		keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
		key = sm4_gcm_info->key;
		salt = sm4_gcm_info->salt;
		salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
		cipher_name = "gcm(sm4)";
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;

		sm4_ccm_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
		tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
		iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
		iv = sm4_ccm_info->iv;
		rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
		rec_seq = sm4_ccm_info->rec_seq;
		keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
		key = sm4_ccm_info->key;
		salt = sm4_ccm_info->salt;
		salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
		cipher_name = "ccm(sm4)";
		break;
	}
	default:
		rc = -EINVAL;
		goto free_priv;
	}

	/* Sanity-check the sizes for stack allocations. */
	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
		rc = -EINVAL;
		goto free_priv;
	}

	if (crypto_info->version == TLS_1_3_VERSION) {
		nonce_size = 0;
		prot->aad_size = TLS_HEADER_SIZE;
		prot->tail_size = 1;
	} else {
		prot->aad_size = TLS_AAD_SPACE_SIZE;
		prot->tail_size = 0;
	}

	prot->version = crypto_info->version;
	prot->cipher_type = crypto_info->cipher_type;
	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
	prot->tag_size = tag_size;
	prot->overhead_size = prot->prepend_size +
			      prot->tag_size + prot->tail_size;
	prot->iv_size = iv_size;
	prot->salt_size = salt_size;
	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
	if (!cctx->iv) {
		rc = -ENOMEM;
		goto free_priv;
	}
	/* Note: 128 & 256 bit salt are the same size */
	prot->rec_seq_size = rec_seq_size;
	memcpy(cctx->iv, salt, salt_size);
	memcpy(cctx->iv + salt_size, iv, iv_size);
	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
	if (!cctx->rec_seq) {
		rc = -ENOMEM;
		goto free_iv;
	}

	if (!*aead) {
		*aead = crypto_alloc_aead(cipher_name, 0, 0);
		if (IS_ERR(*aead)) {
			rc = PTR_ERR(*aead);
			*aead = NULL;
			goto free_rec_seq;
		}
	}

	ctx->push_pending_record = tls_sw_push_pending_record;

	rc = crypto_aead_setkey(*aead, key, keysize);

	if (rc)
		goto free_aead;

	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
	if (rc)
		goto free_aead;

	if (sw_ctx_rx) {
		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);

		if (crypto_info->version == TLS_1_3_VERSION)
			sw_ctx_rx->async_capable = 0;
		else
			sw_ctx_rx->async_capable =
				!!(tfm->__crt_alg->cra_flags &
				   CRYPTO_ALG_ASYNC);

		/* Set up strparser */
		memset(&cb, 0, sizeof(cb));
		cb.rcv_msg = tls_queue;
		cb.parse_msg = tls_read_size;

		strp_init(&sw_ctx_rx->strp, sk, &cb);
	}

	goto out;

free_aead:
	crypto_free_aead(*aead);
	*aead = NULL;
free_rec_seq:
	kfree(cctx->rec_seq);
	cctx->rec_seq = NULL;
free_iv:
	kfree(cctx->iv);
	cctx->iv = NULL;
free_priv:
	if (tx) {
		kfree(ctx->priv_ctx_tx);
		ctx->priv_ctx_tx = NULL;
	} else {
		kfree(ctx->priv_ctx_rx);
		ctx->priv_ctx_rx = NULL;
	}
out:
	return rc;
}