tls_sw.c 44.2 KB
Newer Older
D
Dave Watson 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

D
Dave Watson 已提交
37
#include <linux/sched/signal.h>
D
Dave Watson 已提交
38 39 40
#include <linux/module.h>
#include <crypto/aead.h>

D
Dave Watson 已提交
41
#include <net/strparser.h>
D
Dave Watson 已提交
42 43
#include <net/tls.h>

K
Kees Cook 已提交
44 45
/* Size of the IV buffer used by this file.  NOTE(review): defined as the
 * AES-GCM-128 IV size, which assumes that cipher is the only (or largest)
 * one supported here - confirm when adding ciphers.
 */
#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
/* Count the scatterlist entries needed to map @len bytes of @skb starting
 * at @offset.  The linear head contributes one entry if it is touched, each
 * touched paged fragment contributes one entry, and each touched skb on the
 * frag_list contributes whatever a recursive count of it yields.  Nesting
 * deeper than 24 levels returns -EMSGSIZE.
 */
static int __skb_nsg(struct sk_buff *skb, int offset, int len,
		     unsigned int recursion_level)
{
	int start = skb_headlen(skb);
	int span = start - offset;
	struct sk_buff *iter;
	int nents = 0;
	int f;

	if (unlikely(recursion_level >= 24))
		return -EMSGSIZE;

	/* Linear (head) part of the skb. */
	if (span > 0) {
		if (span > len)
			span = len;
		nents++;
		len -= span;
		if (!len)
			return nents;
		offset += span;
	}

	/* Paged fragments. */
	for (f = 0; f < skb_shinfo(skb)->nr_frags; f++) {
		int stop;

		WARN_ON(start > offset + len);

		stop = start + skb_frag_size(&skb_shinfo(skb)->frags[f]);
		span = stop - offset;
		if (span > 0) {
			if (span > len)
				span = len;
			nents++;
			len -= span;
			if (!len)
				return nents;
			offset += span;
		}
		start = stop;
	}

	/* Chained skbs on the frag_list: count each one recursively. */
	if (unlikely(skb_has_frag_list(skb))) {
		skb_walk_frags(skb, iter) {
			int stop, sub;

			WARN_ON(start > offset + len);

			stop = start + iter->len;
			span = stop - offset;
			if (span > 0) {
				if (span > len)
					span = len;
				sub = __skb_nsg(iter, offset - start, span,
						recursion_level + 1);
				if (unlikely(sub < 0))
					return sub;
				nents += sub;
				len -= span;
				if (!len)
					return nents;
				offset += span;
			}
			start = stop;
		}
	}

	BUG_ON(len);
	return nents;
}

/* skb_nsg - number of scatterlist entries required to map @len bytes of
 * @skb starting at @offset, or -EMSGSIZE if the skb's frag_list nesting is
 * too deep.  Starts the recursive count at depth zero.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
	const unsigned int top_level = 0;

	return __skb_nsg(skb, offset, len, top_level);
}

122 123 124 125
static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct scatterlist *sgout = aead_req->dst;
126 127
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
128
	struct scatterlist *sg;
129
	struct sk_buff *skb;
130
	unsigned int pages;
131 132 133 134 135 136
	int pending;

	skb = (struct sk_buff *)req->data;
	tls_ctx = tls_get_ctx(skb->sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	pending = atomic_dec_return(&ctx->decrypt_pending);
137 138 139 140

	/* Propagate if there was an err */
	if (err) {
		ctx->async_wait.err = err;
141
		tls_err_abort(skb->sk, err);
142 143
	}

144 145 146 147 148
	/* After using skb->sk to propagate sk through crypto async callback
	 * we need to NULL it again.
	 */
	skb->sk = NULL;

149
	/* Release the skb, pages and memory allocated for crypto req */
150
	kfree_skb(skb);
151 152 153 154 155 156 157 158 159 160 161 162 163 164

	/* Skip the first S/G entry as it points to AAD */
	for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
		if (!sg)
			break;
		put_page(sg_page(sg));
	}

	kfree(aead_req);

	if (!pending && READ_ONCE(ctx->async_notify))
		complete(&ctx->async_wait.completion);
}

D
Dave Watson 已提交
165
/* Submit @aead_req to decrypt one record from @sgin into @sgout.
 *
 * @sk:       socket owning the rx context
 * @skb:      record being decrypted; used as the async callback cookie
 * @sgin:     source scatterlist
 * @sgout:    destination scatterlist
 * @iv_recv:  IV for this record
 * @data_len: ciphertext length not counting the tag (tag_size is added here)
 * @aead_req: request memory; in async mode it is released by
 *            tls_decrypt_done()
 * @async:    complete through tls_decrypt_done() instead of waiting here
 *
 * Returns 0 on success, or -EINPROGRESS when an async request was queued
 * (tls_decrypt_done() finishes it). Returns a negative error on failure.
 */
static int tls_do_decryption(struct sock *sk,
			     struct sk_buff *skb,
			     struct scatterlist *sgin,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
			     struct aead_request *aead_req,
			     bool async)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	int ret;

	aead_request_set_tfm(aead_req, ctx->aead_recv);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	aead_request_set_crypt(aead_req, sgin, sgout,
			       data_len + tls_ctx->rx.tag_size,
			       (u8 *)iv_recv);

	if (async) {
		/* Using skb->sk to push sk through to crypto async callback
		 * handler. This allows propagating errors up to the socket
		 * if needed. It _must_ be cleared in the async handler
		 * before kfree_skb is called. We _know_ skb->sk is NULL
		 * because it is a clone from strparser.
		 */
		skb->sk = sk;
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  tls_decrypt_done, skb);
		atomic_inc(&ctx->decrypt_pending);
	} else {
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  crypto_req_done, &ctx->async_wait);
	}

	ret = crypto_aead_decrypt(aead_req);

	/* With CRYPTO_TFM_REQ_MAY_BACKLOG, -EBUSY means the request was put
	 * on the backlog and will still complete through the callback. It
	 * is not a failure. Handle it exactly like -EINPROGRESS so the request
	 * memory is not reused while the crypto layer owns it.
	 */
	if (ret == -EBUSY)
		ret = -EINPROGRESS;

	if (ret == -EINPROGRESS) {
		if (async)
			return ret;

		ret = crypto_wait_req(ret, &ctx->async_wait);
	}

	if (async) {
		/* Finished synchronously: tls_decrypt_done() will not run, so
		 * undo its bookkeeping here, including restoring skb->sk.
		 */
		skb->sk = NULL;
		atomic_dec(&ctx->decrypt_pending);
	}

	return ret;
}

D
Dave Watson 已提交
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
/* Shrink @sg so that it describes exactly @target_size bytes.
 *
 * Trailing entries lying entirely beyond the target are released: their
 * memory is uncharged from @sk and their page reference is put. The last
 * surviving entry is shortened in place. *@sg_num_elem and *@sg_size are
 * updated. Growing is not supported: a target above the current size only
 * triggers a WARN and leaves the list untouched.
 */
static void trim_sg(struct sock *sk, struct scatterlist *sg,
		    int *sg_num_elem, unsigned int *sg_size, int target_size)
{
	int excess = *sg_size - target_size;
	int last = *sg_num_elem - 1;

	if (excess <= 0) {
		WARN_ON(excess < 0);
		return;
	}

	*sg_size = target_size;

	/* Drop whole entries covered by the excess, newest first. */
	while (excess >= sg[last].length) {
		excess -= sg[last].length;
		sk_mem_uncharge(sk, sg[last].length);
		put_page(sg_page(&sg[last]));

		if (--last < 0)
			break;
	}

	/* Cut whatever excess remains off the new last entry. */
	if (last >= 0) {
		sg[last].length -= excess;
		sk_mem_uncharge(sk, excess);
	}

	*sg_num_elem = last + 1;
}

static void trim_both_sgl(struct sock *sk, int target_size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
248
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
249
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
250

251 252 253
	trim_sg(sk, rec->sg_plaintext_data,
		&rec->sg_plaintext_num_elem,
		&rec->sg_plaintext_size,
D
Dave Watson 已提交
254 255 256
		target_size);

	if (target_size > 0)
257
		target_size += tls_ctx->tx.overhead_size;
D
Dave Watson 已提交
258

259 260 261
	trim_sg(sk, rec->sg_encrypted_data,
		&rec->sg_encrypted_num_elem,
		&rec->sg_encrypted_size,
D
Dave Watson 已提交
262 263 264 265 266 267
		target_size);
}

static int alloc_encrypted_sg(struct sock *sk, int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
268
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
269
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
270 271
	int rc = 0;

272
	rc = sk_alloc_sg(sk, len,
273 274 275
			 rec->sg_encrypted_data, 0,
			 &rec->sg_encrypted_num_elem,
			 &rec->sg_encrypted_size, 0);
D
Dave Watson 已提交
276

277
	if (rc == -ENOSPC)
278
		rec->sg_encrypted_num_elem = ARRAY_SIZE(rec->sg_encrypted_data);
279

D
Dave Watson 已提交
280 281 282 283 284 285
	return rc;
}

static int alloc_plaintext_sg(struct sock *sk, int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
286
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
287
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
288 289
	int rc = 0;

290 291
	rc = sk_alloc_sg(sk, len, rec->sg_plaintext_data, 0,
			 &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size,
292
			 tls_ctx->pending_open_record_frags);
D
Dave Watson 已提交
293

294
	if (rc == -ENOSPC)
295
		rec->sg_plaintext_num_elem = ARRAY_SIZE(rec->sg_plaintext_data);
296

D
Dave Watson 已提交
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
	return rc;
}

/* Release every entry of @sg (uncharge its length from @sk and put its
 * page), then reset the element count and byte size to zero.
 */
static void free_sg(struct sock *sk, struct scatterlist *sg,
		    int *sg_num_elem, unsigned int *sg_size)
{
	int count = *sg_num_elem;
	int idx = 0;

	while (idx < count) {
		struct scatterlist *ent = &sg[idx++];

		sk_mem_uncharge(sk, ent->length);
		put_page(sg_page(ent));
	}

	*sg_num_elem = 0;
	*sg_size = 0;
}

313
static void tls_free_open_rec(struct sock *sk)
D
Dave Watson 已提交
314 315
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
316
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
317
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
318

319 320 321 322 323 324 325 326 327 328 329
	/* Return if there is no open record */
	if (!rec)
		return;

	free_sg(sk, rec->sg_encrypted_data,
		&rec->sg_encrypted_num_elem,
		&rec->sg_encrypted_size);

	free_sg(sk, rec->sg_plaintext_data,
		&rec->sg_plaintext_num_elem,
		&rec->sg_plaintext_size);
330 331

	kfree(rec);
332 333 334 335 336 337 338 339 340 341
}

int tls_tx_records(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	int tx_flags, rc = 0;

	if (tls_is_partially_sent_record(tls_ctx)) {
342
		rec = list_first_entry(&ctx->tx_list,
343 344 345 346 347 348 349 350 351 352 353 354
				       struct tls_rec, list);

		if (flags == -1)
			tx_flags = rec->tx_flags;
		else
			tx_flags = flags;

		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
		if (rc)
			goto tx_err;

		/* Full record has been transmitted.
355
		 * Remove the head of tx_list
356 357
		 */
		list_del(&rec->list);
358 359 360
		free_sg(sk, rec->sg_plaintext_data,
			&rec->sg_plaintext_num_elem, &rec->sg_plaintext_size);

361 362 363
		kfree(rec);
	}

364 365 366
	/* Tx all ready records */
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		if (READ_ONCE(rec->tx_ready)) {
367 368 369 370 371 372 373 374 375 376 377 378
			if (flags == -1)
				tx_flags = rec->tx_flags;
			else
				tx_flags = flags;

			rc = tls_push_sg(sk, tls_ctx,
					 &rec->sg_encrypted_data[0],
					 0, tx_flags);
			if (rc)
				goto tx_err;

			list_del(&rec->list);
379 380 381 382
			free_sg(sk, rec->sg_plaintext_data,
				&rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size);

383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415
			kfree(rec);
		} else {
			break;
		}
	}

tx_err:
	if (rc < 0 && rc != -EAGAIN)
		tls_err_abort(sk, EBADMSG);

	return rc;
}

static void tls_encrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct sock *sk = req->data;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec;
	bool ready = false;
	int pending;

	rec = container_of(aead_req, struct tls_rec, aead_req);

	rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
	rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;


	/* Free the record if error is previously set on socket */
	if (err || sk->sk_err) {
		free_sg(sk, rec->sg_encrypted_data,
			&rec->sg_encrypted_num_elem, &rec->sg_encrypted_size);
D
Dave Watson 已提交
416

417 418 419 420 421 422 423 424 425 426 427 428
		kfree(rec);
		rec = NULL;

		/* If err is already set on socket, return the same code */
		if (sk->sk_err) {
			ctx->async_wait.err = sk->sk_err;
		} else {
			ctx->async_wait.err = err;
			tls_err_abort(sk, err);
		}
	}

429 430 431 432 433 434 435 436 437 438 439 440
	if (rec) {
		struct tls_rec *first_rec;

		/* Mark the record as ready for transmission */
		smp_store_mb(rec->tx_ready, true);

		/* If received record is at head of tx_list, schedule tx */
		first_rec = list_first_entry(&ctx->tx_list,
					     struct tls_rec, list);
		if (rec == first_rec)
			ready = true;
	}
441 442 443 444 445 446 447 448 449 450 451 452

	pending = atomic_dec_return(&ctx->encrypt_pending);

	if (!pending && READ_ONCE(ctx->async_notify))
		complete(&ctx->async_wait.completion);

	if (!ready)
		return;

	/* Schedule the transmission */
	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		schedule_delayed_work(&ctx->tx_work.work, 1);
D
Dave Watson 已提交
453 454
}

455 456
/* Encrypt the open record and queue it on ctx->tx_list.
 *
 * @data_len: plaintext length to encrypt
 *
 * Return values and effects:
 * - 0: encryption completed synchronously and the record is marked
 *   tx_ready.
 * - -EINPROGRESS: encryption will complete in tls_encrypt_done().
 * - Negative error: the record is unlinked from tx_list and stays the open
 *   record.
 *
 * On 0 or -EINPROGRESS the record is detached from ctx->open_rec and the TX
 * record sequence number is advanced.
 */
static int tls_do_encryption(struct sock *sk,
			     struct tls_context *tls_ctx,
			     struct tls_sw_context_tx *ctx,
			     struct aead_request *aead_req,
			     size_t data_len)
{
	struct tls_rec *rec = ctx->open_rec;
	int rc;

	/* Skip the space reserved for the record header. It is restored
	 * below when no callback will run, or by tls_encrypt_done().
	 */
	rec->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
	rec->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;

	aead_request_set_tfm(aead_req, ctx->aead_send);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	/* NOTE(review): tls_ctx->tx.iv is shared across records and is
	 * advanced below while an async request may still be reading it;
	 * consider a per-record IV copy.
	 */
	aead_request_set_crypt(aead_req, rec->sg_aead_in,
			       rec->sg_aead_out,
			       data_len, tls_ctx->tx.iv);

	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				  tls_encrypt_done, sk);

	/* Add the record in tx_list */
	list_add_tail(&rec->list, &ctx->tx_list);
	atomic_inc(&ctx->encrypt_pending);

	rc = crypto_aead_encrypt(aead_req);

	/* With CRYPTO_TFM_REQ_MAY_BACKLOG, -EBUSY means the request is on
	 * the backlog and tls_encrypt_done() will still run. Treating it as
	 * a failure would unlink and reuse a record the crypto layer owns.
	 */
	if (rc == -EBUSY)
		rc = -EINPROGRESS;

	if (rc != -EINPROGRESS) {
		/* No callback will run: undo what it would have done. */
		atomic_dec(&ctx->encrypt_pending);
		rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
		rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
	}

	if (!rc) {
		WRITE_ONCE(rec->tx_ready, true);
	} else if (rc != -EINPROGRESS) {
		list_del(&rec->list);
		return rc;
	}

	/* Unhook the record from context if encryption is not failure */
	ctx->open_rec = NULL;
	tls_advance_record_sn(sk, &tls_ctx->tx);
	return rc;
}

static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
504
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
505
	struct tls_rec *rec = ctx->open_rec;
506
	struct aead_request *req;
D
Dave Watson 已提交
507 508
	int rc;

509 510
	if (!rec)
		return 0;
511

512 513
	rec->tx_flags = flags;
	req = &rec->aead_req;
D
Dave Watson 已提交
514

515 516 517 518
	sg_mark_end(rec->sg_plaintext_data + rec->sg_plaintext_num_elem - 1);
	sg_mark_end(rec->sg_encrypted_data + rec->sg_encrypted_num_elem - 1);

	tls_make_aad(rec->aad_space, rec->sg_plaintext_size,
519
		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
D
Dave Watson 已提交
520 521 522
		     record_type);

	tls_fill_prepend(tls_ctx,
523 524 525
			 page_address(sg_page(&rec->sg_encrypted_data[0])) +
			 rec->sg_encrypted_data[0].offset,
			 rec->sg_plaintext_size, record_type);
D
Dave Watson 已提交
526 527 528

	tls_ctx->pending_open_record_frags = 0;

529 530 531
	rc = tls_do_encryption(sk, tls_ctx, ctx, req, rec->sg_plaintext_size);
	if (rc == -EINPROGRESS)
		return -EINPROGRESS;
D
Dave Watson 已提交
532

533
	if (rc < 0) {
534
		tls_err_abort(sk, EBADMSG);
535 536
		return rc;
	}
D
Dave Watson 已提交
537

538
	return tls_tx_records(sk, flags);
D
Dave Watson 已提交
539 540 541 542 543 544 545 546
}

/* Push the open record, if any, as an application-data record. */
static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
	const unsigned char type = TLS_RECORD_TYPE_DATA;

	return tls_push_record(sk, flags, type);
}

static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
D
Dave Watson 已提交
547 548 549
			      int length, int *pages_used,
			      unsigned int *size_used,
			      struct scatterlist *to, int to_max_pages,
550
			      bool charge)
D
Dave Watson 已提交
551 552 553 554 555 556
{
	struct page *pages[MAX_SKB_FRAGS];

	size_t offset;
	ssize_t copied, use;
	int i = 0;
D
Dave Watson 已提交
557 558
	unsigned int size = *size_used;
	int num_elem = *pages_used;
D
Dave Watson 已提交
559 560 561 562 563
	int rc = 0;
	int maxpages;

	while (length > 0) {
		i = 0;
D
Dave Watson 已提交
564
		maxpages = to_max_pages - num_elem;
D
Dave Watson 已提交
565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583
		if (maxpages == 0) {
			rc = -EFAULT;
			goto out;
		}
		copied = iov_iter_get_pages(from, pages,
					    length,
					    maxpages, &offset);
		if (copied <= 0) {
			rc = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);

		length -= copied;
		size += copied;
		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);

D
Dave Watson 已提交
584
			sg_set_page(&to[num_elem],
D
Dave Watson 已提交
585
				    pages[i], use, offset);
D
Dave Watson 已提交
586 587 588
			sg_unmark_end(&to[num_elem]);
			if (charge)
				sk_mem_charge(sk, use);
D
Dave Watson 已提交
589 590 591 592 593 594 595 596 597

			offset = 0;
			copied -= use;

			++i;
			++num_elem;
		}
	}

598 599 600
	/* Mark the end in the last sg entry if newly added */
	if (num_elem > *pages_used)
		sg_mark_end(&to[num_elem - 1]);
D
Dave Watson 已提交
601
out:
602 603
	if (rc)
		iov_iter_revert(from, size - *size_used);
D
Dave Watson 已提交
604 605 606
	*size_used = size;
	*pages_used = num_elem;

D
Dave Watson 已提交
607 608 609 610 611 612 613
	return rc;
}

static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     int bytes)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
614
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
615 616
	struct tls_rec *rec = ctx->open_rec;
	struct scatterlist *sg = rec->sg_plaintext_data;
D
Dave Watson 已提交
617 618 619
	int copy, i, rc = 0;

	for (i = tls_ctx->pending_open_record_frags;
620
	     i < rec->sg_plaintext_num_elem; ++i) {
D
Dave Watson 已提交
621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639
		copy = sg[i].length;
		if (copy_from_iter(
				page_address(sg_page(&sg[i])) + sg[i].offset,
				copy, from) != copy) {
			rc = -EFAULT;
			goto out;
		}
		bytes -= copy;

		++tls_ctx->pending_open_record_frags;

		if (!bytes)
			break;
	}

out:
	return rc;
}

640
static struct tls_rec *get_rec(struct sock *sk)
D
Dave Watson 已提交
641 642
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
643
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680
	struct tls_rec *rec;
	int mem_size;

	/* Return if we already have an open record */
	if (ctx->open_rec)
		return ctx->open_rec;

	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);

	rec = kzalloc(mem_size, sk->sk_allocation);
	if (!rec)
		return NULL;

	sg_init_table(&rec->sg_plaintext_data[0],
		      ARRAY_SIZE(rec->sg_plaintext_data));
	sg_init_table(&rec->sg_encrypted_data[0],
		      ARRAY_SIZE(rec->sg_encrypted_data));

	sg_init_table(rec->sg_aead_in, 2);
	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
		   sizeof(rec->aad_space));
	sg_unmark_end(&rec->sg_aead_in[1]);
	sg_chain(rec->sg_aead_in, 2, rec->sg_plaintext_data);

	sg_init_table(rec->sg_aead_out, 2);
	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
		   sizeof(rec->aad_space));
	sg_unmark_end(&rec->sg_aead_out[1]);
	sg_chain(rec->sg_aead_out, 2, rec->sg_encrypted_data);

	ctx->open_rec = rec;

	return rec;
}

/* tls_sw_sendmsg - encrypt user data into TLS records and transmit them.
 *
 * Data is appended to the open record (see get_rec()). A record is pushed
 * in any of these cases:
 * - it reaches TLS_MAX_PAYLOAD_SIZE;
 * - the record's scatterlist arrays run out of entries (-ENOSPC);
 * - the data is used up and MSG_MORE is not set.
 *
 * The record is filled by pinning the caller's pages (zero-copy) when
 * possible, and otherwise by copying into pages allocated for the record.
 *
 * Only MSG_MORE, MSG_DONTWAIT and MSG_NOSIGNAL are accepted; anything else
 * returns -ENOTSUPP.
 *
 * Returns the number of bytes consumed from @msg if non-zero, otherwise the
 * error as filtered by sk_stream_error().
 */
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send);
	bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t try_to_copy, copied = 0;
	struct tls_rec *rec;
	int required_size;
	int num_async = 0;
	bool full_record;
	int record_room;
	int num_zc = 0;
	int orig_size;
	int ret = 0;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -ENOTSUPP;

	lock_sock(sk);

	/* Wait till there is any pending write on socket */
	if (unlikely(sk->sk_write_pending)) {
		ret = wait_on_pending_writer(sk, &timeo);
		if (unlikely(ret))
			goto send_end;
	}

	/* Control messages may select the record type (and may push). */
	if (unlikely(msg->msg_controllen)) {
		ret = tls_proccess_cmsg(sk, msg, &record_type);
		if (ret) {
			if (ret == -EINPROGRESS)
				num_async++;
			else if (ret != -EAGAIN)
				goto send_end;
		}
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		rec = get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto send_end;
		}

		/* Size this chunk so the record never exceeds the TLS
		 * payload limit.
		 */
		orig_size = rec->sg_plaintext_size;
		full_record = false;
		try_to_copy = msg_data_left(msg);
		record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
		if (try_to_copy >= record_room) {
			try_to_copy = record_room;
			full_record = true;
		}

		/* The encrypted buffer also holds the TLS overhead. */
		required_size = rec->sg_plaintext_size + try_to_copy +
				tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_encrypted:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - rec->sg_encrypted_size;
			full_record = true;
		}

		/* Zero-copy maps the caller's pages straight into the record.
		 * It is only tried when all of these hold:
		 * - the iterator is not a kvec;
		 * - the record will be pushed right away;
		 * - the cipher is synchronous.
		 */
		if (!is_kvec && (full_record || eor) && !async_capable) {
			ret = zerocopy_from_iter(sk, &msg->msg_iter,
				try_to_copy, &rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size,
				rec->sg_plaintext_data,
				ARRAY_SIZE(rec->sg_plaintext_data),
				true);
			if (ret)
				goto fallback_to_reg_send;

			num_zc++;
			copied += try_to_copy;
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto send_end;
			}
			continue;

fallback_to_reg_send:
			/* Drop whatever zero-copy mapped; copy instead. */
			trim_sg(sk, rec->sg_plaintext_data,
				&rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size,
				orig_size);
		}

		required_size = rec->sg_plaintext_size + try_to_copy;
alloc_plaintext:
		ret = alloc_plaintext_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - rec->sg_plaintext_size;
			full_record = true;

			/* Keep the encrypted buffer matched to the smaller
			 * plaintext.
			 */
			trim_sg(sk, rec->sg_encrypted_data,
				&rec->sg_encrypted_num_elem,
				&rec->sg_encrypted_size,
				rec->sg_plaintext_size +
				tls_ctx->tx.overhead_size);
		}

		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
		if (ret)
			goto trim_sgl;

		copied += try_to_copy;
		if (full_record || eor) {
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto send_end;
			}
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			/* Roll the open record back to its size at the
			 * start of this iteration.
			 */
			trim_both_sgl(sk, orig_size);
			goto send_end;
		}

		/* Resume whichever allocation step is still short. */
		if (rec->sg_encrypted_size < required_size)
			goto alloc_encrypted;

		goto alloc_plaintext;
	}

	if (!num_async) {
		goto send_end;
	} else if (num_zc) {
		/* Wait for pending encryptions to get completed */
		smp_store_mb(ctx->async_notify, true);

		if (atomic_read(&ctx->encrypt_pending))
			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		else
			reinit_completion(&ctx->async_wait.completion);

		WRITE_ONCE(ctx->async_notify, false);

		if (ctx->async_wait.err) {
			ret = ctx->async_wait.err;
			copied = 0;
		}
	}

	/* Transmit if any encryptions have completed */
	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
		cancel_delayed_work(&ctx->tx_work.work);
		tls_tx_records(sk, msg->msg_flags);
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	return copied ? copied : ret;
}

/* tls_sw_sendpage - append @size bytes of @page at @offset to TLS records.
 *
 * The page is referenced (get_page) rather than copied, one plaintext
 * scatterlist entry per chunk. A record is pushed in any of these cases:
 * - it reaches TLS_MAX_PAYLOAD_SIZE;
 * - its plaintext entries are exhausted;
 * - the data is used up and neither MSG_MORE nor MSG_SENDPAGE_NOTLAST is
 *   set.
 *
 * Only MSG_MORE, MSG_DONTWAIT, MSG_NOSIGNAL and MSG_SENDPAGE_NOTLAST are
 * accepted; anything else returns -ENOTSUPP.
 *
 * Returns the number of bytes consumed if any, otherwise the error as
 * filtered by sk_stream_error().
 */
int tls_sw_sendpage(struct sock *sk, struct page *page,
		    int offset, size_t size, int flags)
{
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	size_t orig_size = size;
	struct scatterlist *sg;
	struct tls_rec *rec;
	int num_async = 0;
	bool full_record;
	int record_room;
	int ret = 0;
	bool eor;

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST))
		return -ENOTSUPP;

	/* No MSG_EOR from splice, only look at MSG_MORE */
	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));

	lock_sock(sk);

	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	/* Wait till there is any pending write on socket */
	if (unlikely(sk->sk_write_pending)) {
		ret = wait_on_pending_writer(sk, &timeo);
		if (unlikely(ret))
			goto sendpage_end;
	}

	/* Call the sk_stream functions to manage the sndbuf mem. */
	while (size > 0) {
		size_t copy, required_size;

		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto sendpage_end;
		}

		rec = get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto sendpage_end;
		}

		full_record = false;
		record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
		copy = size;
		if (copy >= record_room) {
			copy = record_room;
			full_record = true;
		}
		required_size = rec->sg_plaintext_size + copy +
			      tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_payload:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit. The shortfall must be
			 * measured against the encrypted buffer just
			 * allocated, as tls_sw_sendmsg() does. Measuring it
			 * against the plaintext size would subtract
			 * copy + overhead and wrap the unsigned 'copy'.
			 */
			copy -= required_size - rec->sg_encrypted_size;
			full_record = true;
		}

		get_page(page);
		sg = rec->sg_plaintext_data + rec->sg_plaintext_num_elem;
		sg_set_page(sg, page, copy, offset);
		sg_unmark_end(sg);

		rec->sg_plaintext_num_elem++;

		sk_mem_charge(sk, copy);
		offset += copy;
		size -= copy;
		rec->sg_plaintext_size += copy;
		tls_ctx->pending_open_record_frags = rec->sg_plaintext_num_elem;

		if (full_record || eor ||
		    rec->sg_plaintext_num_elem ==
		    ARRAY_SIZE(rec->sg_plaintext_data)) {
			ret = tls_push_record(sk, flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto sendpage_end;
			}
		}
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
			trim_both_sgl(sk, rec->sg_plaintext_size);
			goto sendpage_end;
		}

		goto alloc_payload;
	}

	if (num_async) {
		/* Transmit if any encryptions have completed */
		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
			cancel_delayed_work(&ctx->tx_work.work);
			tls_tx_records(sk, flags);
		}
	}
sendpage_end:
	if (orig_size > size)
		ret = orig_size - size;
	else
		ret = sk_stream_error(sk, flags, ret);

	release_sock(sk);
	return ret;
}

D
Dave Watson 已提交
1005 1006 1007 1008
static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
				     long timeo, int *err)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1009
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1010 1011 1012 1013 1014 1015 1016 1017 1018
	struct sk_buff *skb;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	while (!(skb = ctx->recv_pkt)) {
		if (sk->sk_err) {
			*err = sock_error(sk);
			return NULL;
		}

1019 1020 1021
		if (sk->sk_shutdown & RCV_SHUTDOWN)
			return NULL;

D
Dave Watson 已提交
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045
		if (sock_flag(sk, SOCK_DONE))
			return NULL;

		if ((flags & MSG_DONTWAIT) || !timeo) {
			*err = -EAGAIN;
			return NULL;
		}

		add_wait_queue(sk_sleep(sk), &wait);
		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		remove_wait_queue(sk_sleep(sk), &wait);

		/* Handle signals */
		if (signal_pending(current)) {
			*err = sock_intr_errno(timeo);
			return NULL;
		}
	}

	return skb;
}

1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074
/* Decrypt the TLS record held in @skb.
 *
 * Where the plaintext goes depends on *zc and the output arguments:
 *  - *zc set and @out_iov non-NULL: zero-copy into the pages of @out_iov;
 *  - *zc set and @out_sg non-NULL: zero-copy into the caller's scatterlist;
 *  - otherwise, or if mapping @out_iov fails: in place, inside the skb
 *    buffers, and *zc is cleared.
 *
 * *chunk is reset to 0 on the in-place path. On the @out_iov path it is
 * filled in by zerocopy_from_iter() (presumably the number of bytes mapped;
 * confirm against that helper). It is not written on the @out_sg path.
 *
 * Returns -EBADMSG, -ENOMEM or a copy/mapping error if the request could
 * not be built. Otherwise returns the result of tls_do_decryption(). For
 * -EINPROGRESS the request memory and any pinned user pages are not released
 * here; presumably the async completion does that.
 */

static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			    struct iov_iter *out_iov,
			    struct scatterlist *out_sg,
			    int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
	struct aead_request *aead_req;
	struct sk_buff *unused;
	u8 *aad, *iv, *mem = NULL;
	struct scatterlist *sgin = NULL;
	struct scatterlist *sgout = NULL;
	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;

	if (*zc && (out_iov || out_sg)) {
		/* For an iov, one extra output entry holds the AAD (sgout[0]) */
		if (out_iov)
			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
		else
			n_sgout = sg_nents(out_sg);
		n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
				 rxm->full_len - tls_ctx->rx.prepend_size);
	} else {
		n_sgout = 0;
		*zc = false;
		/* In-place decryption: make the skb data private and writable */
		n_sgin = skb_cow_data(skb, 0, &unused);
	}

	if (n_sgin < 1)
		return -EBADMSG;

	/* Increment to accommodate AAD */
	n_sgin = n_sgin + 1;

	nsg = n_sgin + n_sgout;

	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);

	/* Allocate a single block of memory which contains
	 * aead_req || sgin[] || sgout[] || aad || iv.
	 * This order achieves correct alignment for aead_req, sgin, sgout.
	 */
	mem = kmalloc(mem_size, sk->sk_allocation);
	if (!mem)
		return -ENOMEM;

	/* Segment the allocated memory */
	aead_req = (struct aead_request *)mem;
	sgin = (struct scatterlist *)(mem + aead_size);
	sgout = sgin + n_sgin;
	aad = (u8 *)(sgout + n_sgout);
	iv = aad + TLS_AAD_SPACE_SIZE;

	/* Prepare IV: salt from the key material, explicit part from the record */
	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			    tls_ctx->rx.iv_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}
	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

	/* Prepare AAD */
	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
		     ctx->control);

	/* Prepare sgin */
	sg_init_table(sgin, n_sgin);
	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
	err = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + tls_ctx->rx.prepend_size,
			   rxm->full_len - tls_ctx->rx.prepend_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}

	if (n_sgout) {
		if (out_iov) {
			sg_init_table(sgout, n_sgout);
			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);

			*chunk = 0;
			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
						 chunk, &sgout[1],
						 (n_sgout - 1), false);
			if (err < 0) {
				/* 'pages' may already count user pages placed
				 * in sgout[1..pages] before the failure. The
				 * fallback below resets 'pages' to 0, so drop
				 * those references now or they would leak.
				 */
				for (; pages > 0; pages--)
					put_page(sg_page(&sgout[pages]));
				goto fallback_to_reg_recv;
			}
		} else if (out_sg) {
			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
		} else {
			goto fallback_to_reg_recv;
		}
	} else {
fallback_to_reg_recv:
		/* NOTE(review): when reached from the zero-copy path, sgin was
		 * built without skb_cow_data(), so in-place decryption writes
		 * into skb data that may be shared. Confirm this is safe.
		 */
		sgout = sgin;
		pages = 0;
		*chunk = 0;
		*zc = false;
	}

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
				data_len, aead_req, *zc);
	if (err == -EINPROGRESS)
		return err;

	/* Release the pages in case iov was mapped to pages */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));

	kfree(mem);
	return err;
}

1175
/* Decrypt the current record (unless that already happened) and strip its
 * TLS framing.
 *
 * On success the strparser window of @skb is narrowed to the plaintext
 * payload. The receive record sequence number is then advanced, the record
 * is marked decrypted and the saved data_ready callback is invoked. If
 * decryption was queued asynchronously (-EINPROGRESS), only the sequence
 * number is advanced before returning. @chunk and @zc are passed through to
 * decrypt_internal(). *zc is cleared when the record was already decrypted.
 */
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			      struct iov_iter *dest, int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int err = 0;

#ifdef CONFIG_TLS_DEVICE
	err = tls_device_decrypted(sk, skb);
	if (err < 0)
		return err;
#endif
	if (ctx->decrypted) {
		/* Nothing to decrypt, so nothing can be zero-copied */
		*zc = false;
	} else {
		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
		if (err == -EINPROGRESS) {
			/* In flight: this record's sequence number is used */
			tls_advance_record_sn(sk, &tls_ctx->rx);
			return err;
		}
		if (err < 0)
			return err;
	}

	/* Expose only the payload: skip header+IV, drop header+IV+tag */
	rxm->offset += tls_ctx->rx.prepend_size;
	rxm->full_len -= tls_ctx->rx.overhead_size;
	tls_advance_record_sn(sk, &tls_ctx->rx);
	ctx->decrypted = true;
	ctx->saved_data_ready(sk);

	return err;
}

int decrypt_skb(struct sock *sk, struct sk_buff *skb,
		struct scatterlist *sgout)
D
Dave Watson 已提交
1211
{
1212 1213
	bool zc = true;
	int chunk;
D
Dave Watson 已提交
1214

1215
	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
D
Dave Watson 已提交
1216 1217 1218 1219 1220 1221
}

static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
			       unsigned int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1222
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1223

1224 1225
	if (skb) {
		struct strp_msg *rxm = strp_msg(skb);
D
Dave Watson 已提交
1226

1227 1228 1229 1230 1231 1232
		if (len < rxm->full_len) {
			rxm->offset += len;
			rxm->full_len -= len;
			return false;
		}
		kfree_skb(skb);
D
Dave Watson 已提交
1233 1234 1235 1236
	}

	/* Finished with message */
	ctx->recv_pkt = NULL;
1237
	__strp_unpause(&ctx->strp);
D
Dave Watson 已提交
1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249

	return true;
}

/* recvmsg() for a socket with software TLS receive enabled.
 *
 * Copies the decrypted payload of consecutive records into @msg. The
 * content type of the first record is reported once via a
 * TLS_GET_RECORD_TYPE cmsg. If that record is not application data and the
 * cmsg could not be delivered in full, the call fails with -EIO.
 *
 * Copying stops when:
 *  - @len bytes have been copied;
 *  - the low-water mark is met and no further record is queued;
 *  - a record of a different content type is reached;
 *  - a non-data record has been returned in full.
 *
 * Zero-copy decryption into the user buffer is attempted when the payload
 * fits, the buffer is not a kvec and MSG_PEEK is not set. Such decryptions
 * may complete asynchronously and are all waited for before returning. If
 * one fails, the connection is aborted and the error is returned.
 * MSG_PEEK reads from the current record only, without consuming it.
 *
 * Returns the number of bytes copied, or a negative error if none were.
 */
int tls_sw_recvmsg(struct sock *sk,
		   struct msghdr *msg,
		   size_t len,
		   int nonblock,
		   int flags,
		   int *addr_len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	unsigned char control;	/* content type of the first record read */
	struct strp_msg *rxm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool cmsg = false;
	int target, err = 0;
	long timeo;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
	int num_async = 0;

	flags |= nonblock;

	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

	lock_sock(sk);

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
	do {
		bool zc = false;
		bool async = false;
		int chunk = 0;

		skb = tls_wait_data(sk, flags, timeo, &err);
		if (!skb)
			goto recv_end;

		rxm = strp_msg(skb);

		if (!cmsg) {
			int cerr;

			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					sizeof(ctx->control), &ctx->control);
			cmsg = true;
			control = ctx->control;
			if (ctx->control != TLS_RECORD_TYPE_DATA) {
				if (cerr || msg->msg_flags & MSG_CTRUNC) {
					err = -EIO;
					goto recv_end;
				}
			}
		} else if (control != ctx->control) {
			/* Never mix content types in a single read */
			goto recv_end;
		}

		if (!ctx->decrypted) {
			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;

			/* Zero-copy only if the whole payload fits */
			if (!is_kvec && to_copy <= len &&
			    likely(!(flags & MSG_PEEK)))
				zc = true;

			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
						 &chunk, &zc);
			if (err < 0 && err != -EINPROGRESS) {
				tls_err_abort(sk, EBADMSG);
				goto recv_end;
			}

			if (err == -EINPROGRESS) {
				async = true;
				num_async++;
				goto pick_next_record;
			}

			ctx->decrypted = true;
		}

		if (!zc) {
			chunk = min_t(unsigned int, rxm->full_len, len);

			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
						    chunk);
			if (err < 0)
				goto recv_end;
		}

pick_next_record:
		copied += chunk;
		len -= chunk;
		if (likely(!(flags & MSG_PEEK))) {
			/* Sample the type before advancing. Unpausing the
			 * strparser may parse the next record and overwrite
			 * ctx->control.
			 */
			u8 rec_type = ctx->control;

			/* For async, drop current skb reference */
			if (async)
				skb = NULL;

			if (tls_sw_advance_skb(sk, skb, chunk)) {
				/* Return full control message to
				 * userspace before trying to parse
				 * another message type
				 */
				msg->msg_flags |= MSG_EOR;
				if (rec_type != TLS_RECORD_TYPE_DATA)
					goto recv_end;
			} else {
				break;
			}
		} else {
			/* MSG_PEEK right now cannot look beyond current skb
			 * from strparser, meaning we cannot advance skb here
			 * and thus unpause strparser since we'd lose the
			 * original one.
			 */
			break;
		}

		/* If we have a new message from strparser, continue now. */
		if (copied >= target && !ctx->recv_pkt)
			break;
	} while (len);

recv_end:
	if (num_async) {
		/* Wait for all previously submitted records to be decrypted */
		smp_store_mb(ctx->async_notify, true);
		if (atomic_read(&ctx->decrypt_pending)) {
			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			if (err) {
				/* one of async decrypt failed */
				tls_err_abort(sk, err);
				copied = 0;
			}
		} else {
			reinit_completion(&ctx->async_wait.completion);
		}
		WRITE_ONCE(ctx->async_notify, false);
	}

	release_sock(sk);
	return copied ? : err;
}

/* splice_read() for a socket with software TLS receive enabled.
 *
 * Splices (part of) the payload of the current application-data record into
 * @pipe. Control records cannot be spliced and fail with -ENOTSUPP.
 *
 * @flags carries SPLICE_F_* bits, not MSG_* bits, so it must not be tested
 * against MSG_* constants. SPLICE_F_NONBLOCK has the same value as MSG_PEEK.
 * Treating it as a peek would make non-blocking splices re-read the same
 * bytes forever. Non-blocking mode is therefore translated explicitly, and
 * the spliced bytes are always consumed, since splice has no peek mode.
 *
 * Returns the number of bytes spliced, or a negative error.
 */
ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
			   struct pipe_inode_info *pipe,
			   size_t len, unsigned int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	bool nonblock = flags & SPLICE_F_NONBLOCK;
	struct strp_msg *rxm = NULL;
	struct sock *sk = sock->sk;
	struct sk_buff *skb;
	ssize_t copied = 0;
	int err = 0;
	long timeo;
	int chunk;
	bool zc = false;

	lock_sock(sk);

	timeo = sock_rcvtimeo(sk, nonblock);

	/* tls_wait_data() expects MSG_* flags */
	skb = tls_wait_data(sk, nonblock ? MSG_DONTWAIT : 0, timeo, &err);
	if (!skb)
		goto splice_read_end;

	/* splice does not support reading control messages */
	if (ctx->control != TLS_RECORD_TYPE_DATA) {
		err = -ENOTSUPP;
		goto splice_read_end;
	}

	if (!ctx->decrypted) {
		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);

		if (err < 0) {
			tls_err_abort(sk, EBADMSG);
			goto splice_read_end;
		}
		ctx->decrypted = true;
	}
	rxm = strp_msg(skb);

	chunk = min_t(unsigned int, rxm->full_len, len);
	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
	if (copied < 0)
		goto splice_read_end;

	/* No peek mode for splice: always consume what was spliced */
	tls_sw_advance_skb(sk, skb, copied);

splice_read_end:
	release_sock(sk);
	return copied ? : err;
}

1439 1440
unsigned int tls_sw_poll(struct file *file, struct socket *sock,
			 struct poll_table_struct *wait)
D
Dave Watson 已提交
1441
{
1442
	unsigned int ret;
D
Dave Watson 已提交
1443 1444
	struct sock *sk = sock->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1445
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1446

1447 1448
	/* Grab POLLOUT and POLLHUP from the underlying socket */
	ret = ctx->sk_poll(file, sock, wait);
D
Dave Watson 已提交
1449

1450 1451
	/* Clear POLLIN bits, and set based on recv_pkt */
	ret &= ~(POLLIN | POLLRDNORM);
D
Dave Watson 已提交
1452
	if (ctx->recv_pkt)
1453
		ret |= POLLIN | POLLRDNORM;
D
Dave Watson 已提交
1454

1455
	return ret;
D
Dave Watson 已提交
1456 1457 1458 1459 1460
}

static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
B
Boris Pismenny 已提交
1461
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
K
Kees Cook 已提交
1462
	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
D
Dave Watson 已提交
1463 1464 1465 1466 1467 1468 1469 1470 1471
	struct strp_msg *rxm = strp_msg(skb);
	size_t cipher_overhead;
	size_t data_len = 0;
	int ret;

	/* Verify that we have a full TLS header, or wait for more data */
	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
		return 0;

K
Kees Cook 已提交
1472 1473 1474 1475 1476 1477
	/* Sanity-check size of on-stack buffer. */
	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
		ret = -EINVAL;
		goto read_failure;
	}

D
Dave Watson 已提交
1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498
	/* Linearize header to local buffer */
	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);

	if (ret < 0)
		goto read_failure;

	ctx->control = header[0];

	data_len = ((header[4] & 0xFF) | (header[3] << 8));

	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;

	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
		ret = -EMSGSIZE;
		goto read_failure;
	}
	if (data_len < cipher_overhead) {
		ret = -EBADMSG;
		goto read_failure;
	}

1499 1500
	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
D
Dave Watson 已提交
1501 1502 1503 1504
		ret = -EINVAL;
		goto read_failure;
	}

1505 1506 1507 1508
#ifdef CONFIG_TLS_DEVICE
	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
			     *(u64*)tls_ctx->rx.rec_seq);
#endif
D
Dave Watson 已提交
1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519
	return data_len + TLS_HEADER_SIZE;

read_failure:
	tls_err_abort(strp->sk, ret);

	return ret;
}

static void tls_queue(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
B
Boris Pismenny 已提交
1520
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1521 1522 1523 1524 1525 1526

	ctx->decrypted = false;

	ctx->recv_pkt = skb;
	strp_pause(strp);

1527
	ctx->saved_data_ready(strp->sk);
D
Dave Watson 已提交
1528 1529 1530 1531 1532
}

static void tls_data_ready(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1533
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1534 1535 1536 1537

	strp_data_ready(&ctx->strp);
}

B
Boris Pismenny 已提交
1538
void tls_sw_free_resources_tx(struct sock *sk)
D
Dave Watson 已提交
1539 1540
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1541
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553
	struct tls_rec *rec, *tmp;

	/* Wait for any pending async encryptions to complete */
	smp_store_mb(ctx->async_notify, true);
	if (atomic_read(&ctx->encrypt_pending))
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);

	cancel_delayed_work_sync(&ctx->tx_work.work);

	/* Tx whatever records we can transmit and abandon the rest */
	tls_tx_records(sk, -1);

1554
	/* Free up un-sent records in tx_list. First, free
1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570
	 * the partially sent record if any at head of tx_list.
	 */
	if (tls_ctx->partially_sent_record) {
		struct scatterlist *sg = tls_ctx->partially_sent_record;

		while (1) {
			put_page(sg_page(sg));
			sk_mem_uncharge(sk, sg->length);

			if (sg_is_last(sg))
				break;
			sg++;
		}

		tls_ctx->partially_sent_record = NULL;

1571
		rec = list_first_entry(&ctx->tx_list,
1572
				       struct tls_rec, list);
1573 1574 1575 1576 1577

		free_sg(sk, rec->sg_plaintext_data,
			&rec->sg_plaintext_num_elem,
			&rec->sg_plaintext_size);

1578 1579 1580 1581
		list_del(&rec->list);
		kfree(rec);
	}

1582
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
1583 1584 1585 1586
		free_sg(sk, rec->sg_encrypted_data,
			&rec->sg_encrypted_num_elem,
			&rec->sg_encrypted_size);

1587 1588 1589 1590
		free_sg(sk, rec->sg_plaintext_data,
			&rec->sg_plaintext_num_elem,
			&rec->sg_plaintext_size);

1591 1592 1593
		list_del(&rec->list);
		kfree(rec);
	}
D
Dave Watson 已提交
1594

1595
	crypto_free_aead(ctx->aead_send);
1596
	tls_free_open_rec(sk);
B
Boris Pismenny 已提交
1597 1598 1599 1600

	kfree(ctx);
}

1601
void tls_sw_release_resources_rx(struct sock *sk)
B
Boris Pismenny 已提交
1602 1603 1604 1605
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

D
Dave Watson 已提交
1606
	if (ctx->aead_recv) {
1607 1608
		kfree_skb(ctx->recv_pkt);
		ctx->recv_pkt = NULL;
D
Dave Watson 已提交
1609 1610 1611 1612 1613 1614 1615 1616 1617
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		write_lock_bh(&sk->sk_callback_lock);
		sk->sk_data_ready = ctx->saved_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);
		release_sock(sk);
		strp_done(&ctx->strp);
		lock_sock(sk);
	}
1618 1619 1620 1621 1622 1623 1624 1625
}

/* Release the software TLS receive state of @sk and free its context. */
void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_sw_context_rx *rx = tls_sw_ctx_rx(tls_get_ctx(sk));

	tls_sw_release_resources_rx(sk);
	kfree(rx);
}

1630
/* The work handler to transmitt the encrypted records in tx_list */
1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647
static void tx_work_handler(struct work_struct *work)
{
	struct delayed_work *delayed_work = to_delayed_work(work);
	struct tx_work *tx_work = container_of(delayed_work,
					       struct tx_work, work);
	struct sock *sk = tx_work->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		return;

	lock_sock(sk);
	tls_tx_records(sk, -1);
	release_sock(sk);
}

D
Dave Watson 已提交
1648
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
D
Dave Watson 已提交
1649 1650 1651
{
	struct tls_crypto_info *crypto_info;
	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
B
Boris Pismenny 已提交
1652 1653
	struct tls_sw_context_tx *sw_ctx_tx = NULL;
	struct tls_sw_context_rx *sw_ctx_rx = NULL;
D
Dave Watson 已提交
1654 1655 1656
	struct cipher_context *cctx;
	struct crypto_aead **aead;
	struct strp_callbacks cb;
D
Dave Watson 已提交
1657 1658 1659 1660 1661 1662 1663 1664 1665
	u16 nonce_size, tag_size, iv_size, rec_seq_size;
	char *iv, *rec_seq;
	int rc = 0;

	if (!ctx) {
		rc = -EINVAL;
		goto out;
	}

B
Boris Pismenny 已提交
1666
	if (tx) {
1667 1668 1669 1670 1671 1672 1673 1674 1675 1676
		if (!ctx->priv_ctx_tx) {
			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
			if (!sw_ctx_tx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_tx = sw_ctx_tx;
		} else {
			sw_ctx_tx =
				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
D
Dave Watson 已提交
1677 1678
		}
	} else {
1679 1680 1681 1682 1683 1684 1685 1686 1687 1688
		if (!ctx->priv_ctx_rx) {
			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
			if (!sw_ctx_rx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_rx = sw_ctx_rx;
		} else {
			sw_ctx_rx =
				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
B
Boris Pismenny 已提交
1689
		}
D
Dave Watson 已提交
1690 1691
	}

D
Dave Watson 已提交
1692
	if (tx) {
1693
		crypto_init_wait(&sw_ctx_tx->async_wait);
1694
		crypto_info = &ctx->crypto_send.info;
D
Dave Watson 已提交
1695
		cctx = &ctx->tx;
B
Boris Pismenny 已提交
1696
		aead = &sw_ctx_tx->aead_send;
1697
		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
1698 1699
		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
		sw_ctx_tx->tx_work.sk = sk;
D
Dave Watson 已提交
1700
	} else {
1701
		crypto_init_wait(&sw_ctx_rx->async_wait);
1702
		crypto_info = &ctx->crypto_recv.info;
D
Dave Watson 已提交
1703
		cctx = &ctx->rx;
B
Boris Pismenny 已提交
1704
		aead = &sw_ctx_rx->aead_recv;
D
Dave Watson 已提交
1705 1706
	}

D
Dave Watson 已提交
1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
		rec_seq =
		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
		gcm_128_info =
			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
		break;
	}
	default:
		rc = -EINVAL;
S
Sabrina Dubroca 已提交
1722
		goto free_priv;
D
Dave Watson 已提交
1723 1724
	}

K
Kees Cook 已提交
1725
	/* Sanity-check the IV size for stack allocations. */
K
Kees Cook 已提交
1726
	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
K
Kees Cook 已提交
1727 1728 1729 1730
		rc = -EINVAL;
		goto free_priv;
	}

D
Dave Watson 已提交
1731 1732 1733 1734 1735 1736 1737
	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
	cctx->tag_size = tag_size;
	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
	cctx->iv_size = iv_size;
	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			   GFP_KERNEL);
	if (!cctx->iv) {
D
Dave Watson 已提交
1738
		rc = -ENOMEM;
S
Sabrina Dubroca 已提交
1739
		goto free_priv;
D
Dave Watson 已提交
1740
	}
D
Dave Watson 已提交
1741 1742 1743
	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
	cctx->rec_seq_size = rec_seq_size;
1744
	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
D
Dave Watson 已提交
1745
	if (!cctx->rec_seq) {
D
Dave Watson 已提交
1746 1747 1748
		rc = -ENOMEM;
		goto free_iv;
	}
D
Dave Watson 已提交
1749 1750 1751 1752 1753 1754

	if (!*aead) {
		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
		if (IS_ERR(*aead)) {
			rc = PTR_ERR(*aead);
			*aead = NULL;
D
Dave Watson 已提交
1755 1756 1757 1758 1759 1760
			goto free_rec_seq;
		}
	}

	ctx->push_pending_record = tls_sw_push_pending_record;

1761
	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
D
Dave Watson 已提交
1762 1763 1764 1765
				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	if (rc)
		goto free_aead;

D
Dave Watson 已提交
1766 1767 1768 1769
	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
	if (rc)
		goto free_aead;

B
Boris Pismenny 已提交
1770
	if (sw_ctx_rx) {
D
Dave Watson 已提交
1771 1772 1773 1774 1775
		/* Set up strparser */
		memset(&cb, 0, sizeof(cb));
		cb.rcv_msg = tls_queue;
		cb.parse_msg = tls_read_size;

B
Boris Pismenny 已提交
1776
		strp_init(&sw_ctx_rx->strp, sk, &cb);
D
Dave Watson 已提交
1777 1778

		write_lock_bh(&sk->sk_callback_lock);
B
Boris Pismenny 已提交
1779
		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
D
Dave Watson 已提交
1780 1781 1782
		sk->sk_data_ready = tls_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);

1783
		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
D
Dave Watson 已提交
1784

B
Boris Pismenny 已提交
1785
		strp_check_rcv(&sw_ctx_rx->strp);
D
Dave Watson 已提交
1786 1787 1788
	}

	goto out;
D
Dave Watson 已提交
1789 1790

free_aead:
D
Dave Watson 已提交
1791 1792
	crypto_free_aead(*aead);
	*aead = NULL;
D
Dave Watson 已提交
1793
free_rec_seq:
D
Dave Watson 已提交
1794 1795
	kfree(cctx->rec_seq);
	cctx->rec_seq = NULL;
D
Dave Watson 已提交
1796
free_iv:
B
Boris Pismenny 已提交
1797 1798
	kfree(cctx->iv);
	cctx->iv = NULL;
S
Sabrina Dubroca 已提交
1799
free_priv:
B
Boris Pismenny 已提交
1800 1801 1802 1803 1804 1805 1806
	if (tx) {
		kfree(ctx->priv_ctx_tx);
		ctx->priv_ctx_tx = NULL;
	} else {
		kfree(ctx->priv_ctx_rx);
		ctx->priv_ctx_rx = NULL;
	}
D
Dave Watson 已提交
1807 1808 1809
out:
	return rc;
}