tls_sw.c 45.4 KB
Newer Older
D
Dave Watson 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

D
Dave Watson 已提交
37
#include <linux/sched/signal.h>
D
Dave Watson 已提交
38 39 40
#include <linux/module.h>
#include <crypto/aead.h>

D
Dave Watson 已提交
41
#include <net/strparser.h>
D
Dave Watson 已提交
42 43
#include <net/tls.h>

K
Kees Cook 已提交
44 45
#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
/* Count the scatterlist entries needed to map @len bytes of @skb starting at
 * @offset: one for the linear head (if the range touches it), one per page
 * fragment the range overlaps, plus whatever each overlapping frag_list child
 * needs (counted recursively, relative to the child's own start).
 *
 * Returns the entry count, or -EMSGSIZE once @recursion_level reaches 24.
 */
static int __skb_nsg(struct sk_buff *skb, int offset, int len,
		     unsigned int recursion_level)
{
	struct sk_buff *child;
	int pos = skb_headlen(skb);
	int nents = 0;
	int take, f;

	if (unlikely(recursion_level >= 24))
		return -EMSGSIZE;

	/* Part of the range that lives in the linear head. */
	take = pos - offset;
	if (take > 0) {
		if (take > len)
			take = len;
		nents++;
		len -= take;
		if (len == 0)
			return nents;
		offset += take;
	}

	/* Page fragments: one entry for each one the range overlaps. */
	for (f = 0; f < skb_shinfo(skb)->nr_frags; f++) {
		int frag_end;

		WARN_ON(pos > offset + len);

		frag_end = pos + skb_frag_size(&skb_shinfo(skb)->frags[f]);
		take = frag_end - offset;
		if (take > 0) {
			if (take > len)
				take = len;
			nents++;
			len -= take;
			if (len == 0)
				return nents;
			offset += take;
		}
		pos = frag_end;
	}

	/* Chained skbs: recurse into each child that overlaps the range. */
	if (unlikely(skb_has_frag_list(skb))) {
		skb_walk_frags(skb, child) {
			int child_end, sub;

			WARN_ON(pos > offset + len);

			child_end = pos + child->len;
			take = child_end - offset;
			if (take > 0) {
				if (take > len)
					take = len;
				sub = __skb_nsg(child, offset - pos, take,
						recursion_level + 1);
				if (unlikely(sub < 0))
					return sub;
				nents += sub;
				len -= take;
				if (len == 0)
					return nents;
				offset += take;
			}
			pos = child_end;
		}
	}

	/* The skb must have covered the whole requested range. */
	BUG_ON(len);
	return nents;
}

/* skb_nsg - count the scatterlist entries required to completely map @len
 * bytes of @skb starting at @offset, including any frag_list children.
 *
 * Returns that count, or -EMSGSIZE if the frag_list nesting exceeds the
 * recursion limit of __skb_nsg(). This is the depth-0 entry point.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
	const unsigned int depth = 0;

	return __skb_nsg(skb, offset, len, depth);
}

122 123 124 125
static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct scatterlist *sgout = aead_req->dst;
126 127
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
128
	struct scatterlist *sg;
129
	struct sk_buff *skb;
130
	unsigned int pages;
131 132 133 134 135 136
	int pending;

	skb = (struct sk_buff *)req->data;
	tls_ctx = tls_get_ctx(skb->sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	pending = atomic_dec_return(&ctx->decrypt_pending);
137 138 139 140

	/* Propagate if there was an err */
	if (err) {
		ctx->async_wait.err = err;
141
		tls_err_abort(skb->sk, err);
142 143
	}

144 145 146 147 148
	/* After using skb->sk to propagate sk through crypto async callback
	 * we need to NULL it again.
	 */
	skb->sk = NULL;

149
	/* Release the skb, pages and memory allocated for crypto req */
150
	kfree_skb(skb);
151 152 153 154 155 156 157 158 159 160 161 162 163 164

	/* Skip the first S/G entry as it points to AAD */
	for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
		if (!sg)
			break;
		put_page(sg_page(sg));
	}

	kfree(aead_req);

	if (!pending && READ_ONCE(ctx->async_notify))
		complete(&ctx->async_wait.completion);
}

D
Dave Watson 已提交
165
static int tls_do_decryption(struct sock *sk,
166
			     struct sk_buff *skb,
D
Dave Watson 已提交
167 168 169 170
			     struct scatterlist *sgin,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
171 172
			     struct aead_request *aead_req,
			     bool async)
D
Dave Watson 已提交
173 174
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
175
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
176 177
	int ret;

178
	aead_request_set_tfm(aead_req, ctx->aead_recv);
D
Dave Watson 已提交
179 180 181 182 183
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	aead_request_set_crypt(aead_req, sgin, sgout,
			       data_len + tls_ctx->rx.tag_size,
			       (u8 *)iv_recv);

184
	if (async) {
185 186 187 188 189 190 191
		/* Using skb->sk to push sk through to crypto async callback
		 * handler. This allows propagating errors up to the socket
		 * if needed. It _must_ be cleared in the async handler
		 * before kfree_skb is called. We _know_ skb->sk is NULL
		 * because it is a clone from strparser.
		 */
		skb->sk = sk;
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  tls_decrypt_done, skb);
		atomic_inc(&ctx->decrypt_pending);
	} else {
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  crypto_req_done, &ctx->async_wait);
	}

	ret = crypto_aead_decrypt(aead_req);
	if (ret == -EINPROGRESS) {
		if (async)
			return ret;

		ret = crypto_wait_req(ret, &ctx->async_wait);
	}

	if (async)
		atomic_dec(&ctx->decrypt_pending);

D
Dave Watson 已提交
213 214 215
	return ret;
}

D
Dave Watson 已提交
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
/* Shrink a scatterlist from its tail so that it describes exactly
 * @target_size bytes. Entries lying wholly past the target are released
 * (page reference dropped, memory uncharged from @sk). The entry straddling
 * the new end is shortened in place. *sg_num_elem and *sg_size are updated
 * to match. Growing is not supported: a target above the current size
 * triggers a WARN_ON and changes nothing.
 */
static void trim_sg(struct sock *sk, struct scatterlist *sg,
		    int *sg_num_elem, unsigned int *sg_size, int target_size)
{
	int excess = *sg_size - target_size;
	int last = *sg_num_elem - 1;

	if (excess <= 0) {
		WARN_ON(excess < 0);
		return;
	}

	*sg_size = target_size;

	/* Drop whole entries from the tail while they fit in the excess. */
	for (; excess >= sg[last].length; last--) {
		excess -= sg[last].length;
		sk_mem_uncharge(sk, sg[last].length);
		put_page(sg_page(&sg[last]));

		if (last == 0) {
			*sg_num_elem = 0;
			return;
		}
	}

	/* Shorten the entry that crosses the new end. */
	sg[last].length -= excess;
	sk_mem_uncharge(sk, excess);
	*sg_num_elem = last + 1;
}

static void trim_both_sgl(struct sock *sk, int target_size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
248
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
249
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
250

251
	trim_sg(sk, &rec->sg_plaintext_data[1],
252 253
		&rec->sg_plaintext_num_elem,
		&rec->sg_plaintext_size,
D
Dave Watson 已提交
254 255 256
		target_size);

	if (target_size > 0)
257
		target_size += tls_ctx->tx.overhead_size;
D
Dave Watson 已提交
258

259
	trim_sg(sk, &rec->sg_encrypted_data[1],
260 261
		&rec->sg_encrypted_num_elem,
		&rec->sg_encrypted_size,
D
Dave Watson 已提交
262 263 264 265 266 267
		target_size);
}

static int alloc_encrypted_sg(struct sock *sk, int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
268
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
269
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
270 271
	int rc = 0;

272
	rc = sk_alloc_sg(sk, len,
273
			 &rec->sg_encrypted_data[1], 0,
274 275
			 &rec->sg_encrypted_num_elem,
			 &rec->sg_encrypted_size, 0);
D
Dave Watson 已提交
276

277
	if (rc == -ENOSPC)
278 279
		rec->sg_encrypted_num_elem =
			ARRAY_SIZE(rec->sg_encrypted_data) - 1;
280

D
Dave Watson 已提交
281 282 283
	return rc;
}

284
static int move_to_plaintext_sg(struct sock *sk, int required_size)
D
Dave Watson 已提交
285 286
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
287
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
288
	struct tls_rec *rec = ctx->open_rec;
289 290 291 292
	struct scatterlist *plain_sg = &rec->sg_plaintext_data[1];
	struct scatterlist *enc_sg = &rec->sg_encrypted_data[1];
	int enc_sg_idx = 0;
	int skip, len;
D
Dave Watson 已提交
293

294 295
	if (rec->sg_plaintext_num_elem == MAX_SKB_FRAGS)
		return -ENOSPC;
D
Dave Watson 已提交
296

297 298 299 300 301
	/* We add page references worth len bytes from enc_sg at the
	 * end of plain_sg. It is guaranteed that sg_encrypted_data
	 * has enough required room (ensured by caller).
	 */
	len = required_size - rec->sg_plaintext_size;
302

303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
	/* Skip initial bytes in sg_encrypted_data to be able
	 * to use same offset of both plain and encrypted data.
	 */
	skip = tls_ctx->tx.prepend_size + rec->sg_plaintext_size;

	while (enc_sg_idx < rec->sg_encrypted_num_elem) {
		if (enc_sg[enc_sg_idx].length > skip)
			break;

		skip -= enc_sg[enc_sg_idx].length;
		enc_sg_idx++;
	}

	/* unmark the end of plain_sg*/
	sg_unmark_end(plain_sg + rec->sg_plaintext_num_elem - 1);

	while (len) {
		struct page *page = sg_page(&enc_sg[enc_sg_idx]);
		int bytes = enc_sg[enc_sg_idx].length - skip;
		int offset = enc_sg[enc_sg_idx].offset + skip;

		if (bytes > len)
			bytes = len;
		else
			enc_sg_idx++;

		/* Skipping is required only one time */
		skip = 0;

		/* Increment page reference */
		get_page(page);

		sg_set_page(&plain_sg[rec->sg_plaintext_num_elem], page,
			    bytes, offset);

		sk_mem_charge(sk, bytes);

		len -= bytes;
		rec->sg_plaintext_size += bytes;

		rec->sg_plaintext_num_elem++;

		if (rec->sg_plaintext_num_elem == MAX_SKB_FRAGS)
			return -ENOSPC;
	}

	return 0;
D
Dave Watson 已提交
350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
}

static void free_sg(struct sock *sk, struct scatterlist *sg,
		    int *sg_num_elem, unsigned int *sg_size)
{
	int i, n = *sg_num_elem;

	for (i = 0; i < n; ++i) {
		sk_mem_uncharge(sk, sg[i].length);
		put_page(sg_page(&sg[i]));
	}
	*sg_num_elem = 0;
	*sg_size = 0;
}

365
static void tls_free_open_rec(struct sock *sk)
D
Dave Watson 已提交
366 367
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
368
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
369
	struct tls_rec *rec = ctx->open_rec;
D
Dave Watson 已提交
370

371 372 373 374
	/* Return if there is no open record */
	if (!rec)
		return;

375
	free_sg(sk, &rec->sg_encrypted_data[1],
376 377 378
		&rec->sg_encrypted_num_elem,
		&rec->sg_encrypted_size);

379
	free_sg(sk, &rec->sg_plaintext_data[1],
380 381
		&rec->sg_plaintext_num_elem,
		&rec->sg_plaintext_size);
382 383

	kfree(rec);
384 385 386 387 388 389 390 391 392 393
}

int tls_tx_records(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	int tx_flags, rc = 0;

	if (tls_is_partially_sent_record(tls_ctx)) {
394
		rec = list_first_entry(&ctx->tx_list,
395 396 397 398 399 400 401 402 403 404 405 406
				       struct tls_rec, list);

		if (flags == -1)
			tx_flags = rec->tx_flags;
		else
			tx_flags = flags;

		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
		if (rc)
			goto tx_err;

		/* Full record has been transmitted.
407
		 * Remove the head of tx_list
408 409
		 */
		list_del(&rec->list);
410
		free_sg(sk, &rec->sg_plaintext_data[1],
411 412
			&rec->sg_plaintext_num_elem, &rec->sg_plaintext_size);

413 414 415
		kfree(rec);
	}

416 417 418
	/* Tx all ready records */
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		if (READ_ONCE(rec->tx_ready)) {
419 420 421 422 423 424
			if (flags == -1)
				tx_flags = rec->tx_flags;
			else
				tx_flags = flags;

			rc = tls_push_sg(sk, tls_ctx,
425
					 &rec->sg_encrypted_data[1],
426 427 428 429 430
					 0, tx_flags);
			if (rc)
				goto tx_err;

			list_del(&rec->list);
431
			free_sg(sk, &rec->sg_plaintext_data[1],
432 433 434
				&rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size);

435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
			kfree(rec);
		} else {
			break;
		}
	}

tx_err:
	if (rc < 0 && rc != -EAGAIN)
		tls_err_abort(sk, EBADMSG);

	return rc;
}

static void tls_encrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct sock *sk = req->data;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec;
	bool ready = false;
	int pending;

	rec = container_of(aead_req, struct tls_rec, aead_req);

460 461
	rec->sg_encrypted_data[1].offset -= tls_ctx->tx.prepend_size;
	rec->sg_encrypted_data[1].length += tls_ctx->tx.prepend_size;
462 463


464
	/* Check if error is previously set on socket */
465 466 467 468 469 470 471 472 473 474 475 476
	if (err || sk->sk_err) {
		rec = NULL;

		/* If err is already set on socket, return the same code */
		if (sk->sk_err) {
			ctx->async_wait.err = sk->sk_err;
		} else {
			ctx->async_wait.err = err;
			tls_err_abort(sk, err);
		}
	}

477 478 479 480 481 482 483 484 485 486 487 488
	if (rec) {
		struct tls_rec *first_rec;

		/* Mark the record as ready for transmission */
		smp_store_mb(rec->tx_ready, true);

		/* If received record is at head of tx_list, schedule tx */
		first_rec = list_first_entry(&ctx->tx_list,
					     struct tls_rec, list);
		if (rec == first_rec)
			ready = true;
	}
489 490 491 492 493 494 495 496 497 498 499

	pending = atomic_dec_return(&ctx->encrypt_pending);

	if (!pending && READ_ONCE(ctx->async_notify))
		complete(&ctx->async_wait.completion);

	if (!ready)
		return;

	/* Schedule the transmission */
	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
500
		schedule_delayed_work(&ctx->tx_work.work, 2);
D
Dave Watson 已提交
501 502
}

503 504
static int tls_do_encryption(struct sock *sk,
			     struct tls_context *tls_ctx,
505 506 507
			     struct tls_sw_context_tx *ctx,
			     struct aead_request *aead_req,
			     size_t data_len)
D
Dave Watson 已提交
508
{
509
	struct tls_rec *rec = ctx->open_rec;
510 511
	struct scatterlist *plain_sg = rec->sg_plaintext_data;
	struct scatterlist *enc_sg = rec->sg_encrypted_data;
D
Dave Watson 已提交
512 513
	int rc;

514 515 516
	/* Skip the first index as it contains AAD data */
	rec->sg_encrypted_data[1].offset += tls_ctx->tx.prepend_size;
	rec->sg_encrypted_data[1].length -= tls_ctx->tx.prepend_size;
D
Dave Watson 已提交
517

518 519 520 521
	/* If it is inplace crypto, then pass same SG list as both src, dst */
	if (rec->inplace_crypto)
		plain_sg = enc_sg;

D
Dave Watson 已提交
522 523
	aead_request_set_tfm(aead_req, ctx->aead_send);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
524
	aead_request_set_crypt(aead_req, plain_sg, enc_sg,
525
			       data_len, tls_ctx->tx.iv);
526 527

	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
528 529
				  tls_encrypt_done, sk);

530 531
	/* Add the record in tx_list */
	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
532
	atomic_inc(&ctx->encrypt_pending);
533

534 535 536
	rc = crypto_aead_encrypt(aead_req);
	if (!rc || rc != -EINPROGRESS) {
		atomic_dec(&ctx->encrypt_pending);
537 538
		rec->sg_encrypted_data[1].offset -= tls_ctx->tx.prepend_size;
		rec->sg_encrypted_data[1].length += tls_ctx->tx.prepend_size;
539
	}
D
Dave Watson 已提交
540

541 542 543 544
	if (!rc) {
		WRITE_ONCE(rec->tx_ready, true);
	} else if (rc != -EINPROGRESS) {
		list_del(&rec->list);
545
		return rc;
546
	}
D
Dave Watson 已提交
547

548 549 550
	/* Unhook the record from context if encryption is not failure */
	ctx->open_rec = NULL;
	tls_advance_record_sn(sk, &tls_ctx->tx);
D
Dave Watson 已提交
551 552 553 554 555 556 557
	return rc;
}

static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
558
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
559
	struct tls_rec *rec = ctx->open_rec;
560
	struct aead_request *req;
D
Dave Watson 已提交
561 562
	int rc;

563 564
	if (!rec)
		return 0;
565

566 567
	rec->tx_flags = flags;
	req = &rec->aead_req;
D
Dave Watson 已提交
568

569 570
	sg_mark_end(rec->sg_plaintext_data + rec->sg_plaintext_num_elem);
	sg_mark_end(rec->sg_encrypted_data + rec->sg_encrypted_num_elem);
571 572

	tls_make_aad(rec->aad_space, rec->sg_plaintext_size,
573
		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
D
Dave Watson 已提交
574 575 576
		     record_type);

	tls_fill_prepend(tls_ctx,
577 578
			 page_address(sg_page(&rec->sg_encrypted_data[1])) +
			 rec->sg_encrypted_data[1].offset,
579
			 rec->sg_plaintext_size, record_type);
D
Dave Watson 已提交
580 581 582

	tls_ctx->pending_open_record_frags = 0;

583 584 585
	rc = tls_do_encryption(sk, tls_ctx, ctx, req, rec->sg_plaintext_size);
	if (rc == -EINPROGRESS)
		return -EINPROGRESS;
D
Dave Watson 已提交
586

587
	if (rc < 0) {
588
		tls_err_abort(sk, EBADMSG);
589 590
		return rc;
	}
D
Dave Watson 已提交
591

592
	return tls_tx_records(sk, flags);
D
Dave Watson 已提交
593 594 595 596 597 598 599 600
}

/* Push the currently open record as an application-data record. */
static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
	const unsigned char type = TLS_RECORD_TYPE_DATA;

	return tls_push_record(sk, flags, type);
}

static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
D
Dave Watson 已提交
601 602 603
			      int length, int *pages_used,
			      unsigned int *size_used,
			      struct scatterlist *to, int to_max_pages,
604
			      bool charge)
D
Dave Watson 已提交
605 606 607 608 609 610
{
	struct page *pages[MAX_SKB_FRAGS];

	size_t offset;
	ssize_t copied, use;
	int i = 0;
D
Dave Watson 已提交
611 612
	unsigned int size = *size_used;
	int num_elem = *pages_used;
D
Dave Watson 已提交
613 614 615 616 617
	int rc = 0;
	int maxpages;

	while (length > 0) {
		i = 0;
D
Dave Watson 已提交
618
		maxpages = to_max_pages - num_elem;
D
Dave Watson 已提交
619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637
		if (maxpages == 0) {
			rc = -EFAULT;
			goto out;
		}
		copied = iov_iter_get_pages(from, pages,
					    length,
					    maxpages, &offset);
		if (copied <= 0) {
			rc = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);

		length -= copied;
		size += copied;
		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);

D
Dave Watson 已提交
638
			sg_set_page(&to[num_elem],
D
Dave Watson 已提交
639
				    pages[i], use, offset);
D
Dave Watson 已提交
640 641 642
			sg_unmark_end(&to[num_elem]);
			if (charge)
				sk_mem_charge(sk, use);
D
Dave Watson 已提交
643 644 645 646 647 648 649 650 651

			offset = 0;
			copied -= use;

			++i;
			++num_elem;
		}
	}

652 653 654
	/* Mark the end in the last sg entry if newly added */
	if (num_elem > *pages_used)
		sg_mark_end(&to[num_elem - 1]);
D
Dave Watson 已提交
655
out:
656 657
	if (rc)
		iov_iter_revert(from, size - *size_used);
D
Dave Watson 已提交
658 659 660
	*size_used = size;
	*pages_used = num_elem;

D
Dave Watson 已提交
661 662 663 664 665 666 667
	return rc;
}

static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     int bytes)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
668
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
669
	struct tls_rec *rec = ctx->open_rec;
670
	struct scatterlist *sg = &rec->sg_plaintext_data[1];
D
Dave Watson 已提交
671 672 673
	int copy, i, rc = 0;

	for (i = tls_ctx->pending_open_record_frags;
674
	     i < rec->sg_plaintext_num_elem; ++i) {
D
Dave Watson 已提交
675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693
		copy = sg[i].length;
		if (copy_from_iter(
				page_address(sg_page(&sg[i])) + sg[i].offset,
				copy, from) != copy) {
			rc = -EFAULT;
			goto out;
		}
		bytes -= copy;

		++tls_ctx->pending_open_record_frags;

		if (!bytes)
			break;
	}

out:
	return rc;
}

694
static struct tls_rec *get_rec(struct sock *sk)
D
Dave Watson 已提交
695 696
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
697
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715
	struct tls_rec *rec;
	int mem_size;

	/* Return if we already have an open record */
	if (ctx->open_rec)
		return ctx->open_rec;

	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);

	rec = kzalloc(mem_size, sk->sk_allocation);
	if (!rec)
		return NULL;

	sg_init_table(&rec->sg_plaintext_data[0],
		      ARRAY_SIZE(rec->sg_plaintext_data));
	sg_init_table(&rec->sg_encrypted_data[0],
		      ARRAY_SIZE(rec->sg_encrypted_data));

716
	sg_set_buf(&rec->sg_plaintext_data[0], rec->aad_space,
717
		   sizeof(rec->aad_space));
718
	sg_set_buf(&rec->sg_encrypted_data[0], rec->aad_space,
719 720 721
		   sizeof(rec->aad_space));

	ctx->open_rec = rec;
722
	rec->inplace_crypto = 1;
723 724 725 726 727 728

	return rec;
}

/* sendmsg() for a software kTLS socket.
 *
 * Splits the user data into TLS records of at most TLS_MAX_PAYLOAD_SIZE
 * bytes. A record is pushed when it is full, or when MSG_MORE is not set.
 * Zero-copy is used when all of these hold: the iterator is not a kvec, the
 * record is being closed, and the AEAD transform is not asynchronous. In
 * that mode user pages are mapped directly as plaintext. Otherwise the data
 * is copied into pages shared with the ciphertext buffer.
 *
 * Returns the number of bytes queued, or a negative error if none were.
 */
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send);
	bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t to_copy, copied = 0;
	struct tls_rec *rec;
	int required_size;
	int async_count = 0;
	int zc_count = 0;
	bool record_full;
	int start_size;
	int room;
	int ret = 0;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -ENOTSUPP;

	lock_sock(sk);

	/* Let any writer already waiting on the socket go first. */
	if (unlikely(sk->sk_write_pending)) {
		ret = wait_on_pending_writer(sk, &timeo);
		if (unlikely(ret))
			goto send_end;
	}

	if (unlikely(msg->msg_controllen)) {
		ret = tls_proccess_cmsg(sk, msg, &record_type);
		if (ret == -EINPROGRESS)
			async_count++;
		else if (ret && ret != -EAGAIN)
			goto send_end;
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		rec = get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto send_end;
		}

		start_size = rec->sg_plaintext_size;

		/* Take as much as still fits in the current record. */
		room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
		to_copy = msg_data_left(msg);
		record_full = to_copy >= room;
		if (record_full)
			to_copy = room;

		required_size = rec->sg_plaintext_size + to_copy +
				tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_encrypted:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Scatterlist entries ran out before the whole
			 * ciphertext buffer was allocated: shrink this chunk
			 * by the shortfall and close the record.
			 */
			to_copy -= required_size - rec->sg_encrypted_size;
			record_full = true;
		}

		if (!is_kvec && (record_full || eor) && !async_capable) {
			ret = zerocopy_from_iter(sk, &msg->msg_iter,
				to_copy, &rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size,
				&rec->sg_plaintext_data[1],
				ARRAY_SIZE(rec->sg_plaintext_data) - 1,
				true);
			if (ret)
				goto fallback_to_reg_send;

			rec->inplace_crypto = 0;
			zc_count++;
			copied += to_copy;

			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret == -EINPROGRESS)
				async_count++;
			else if (ret && ret != -EAGAIN)
				goto send_end;
			continue;

fallback_to_reg_send:
			/* Drop whatever zero-copy managed to map. */
			trim_sg(sk, &rec->sg_plaintext_data[1],
				&rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size,
				start_size);
		}

		/* Copy mode: borrow pages from the ciphertext buffer for the
		 * plaintext and copy the user data into them.
		 */
		required_size = rec->sg_plaintext_size + to_copy;

		ret = move_to_plaintext_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto send_end;

			/* Plaintext entries ran out: shrink the chunk to what
			 * was mapped, close the record and trim the ciphertext
			 * buffer to match.
			 */
			to_copy -= required_size - rec->sg_plaintext_size;
			record_full = true;

			trim_sg(sk, &rec->sg_encrypted_data[1],
				&rec->sg_encrypted_num_elem,
				&rec->sg_encrypted_size,
				rec->sg_plaintext_size +
				tls_ctx->tx.overhead_size);
		}

		ret = memcopy_from_iter(sk, &msg->msg_iter, to_copy);
		if (ret)
			goto trim_sgl;

		copied += to_copy;
		if (record_full || eor) {
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret == -EINPROGRESS)
				async_count++;
			else if (ret && ret != -EAGAIN)
				goto send_end;
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			trim_both_sgl(sk, start_size);
			goto send_end;
		}

		if (rec->sg_encrypted_size < required_size)
			goto alloc_encrypted;
	}

	if (async_count) {
		if (zc_count) {
			/* Wait for every in-flight encryption; an error from
			 * any of them fails the whole call.
			 */
			smp_store_mb(ctx->async_notify, true);

			if (atomic_read(&ctx->encrypt_pending))
				crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			else
				reinit_completion(&ctx->async_wait.completion);

			WRITE_ONCE(ctx->async_notify, false);

			if (ctx->async_wait.err) {
				ret = ctx->async_wait.err;
				copied = 0;
			}
		}

		/* Transmit if any encryptions have completed */
		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
			cancel_delayed_work(&ctx->tx_work.work);
			tls_tx_records(sk, msg->msg_flags);
		}
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	return copied ? copied : ret;
}

/* sendpage() for a software kTLS socket.
 *
 * Appends @size bytes of @page, starting at @offset, to the open record(s)
 * by reference, without copying. A record is pushed when it is full, when
 * its plaintext scatterlist is full, or when neither MSG_MORE nor
 * MSG_SENDPAGE_NOTLAST is set.
 *
 * Returns the number of bytes queued, or a negative error if none were.
 */
int tls_sw_sendpage(struct sock *sk, struct page *page,
		    int offset, size_t size, int flags)
{
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	size_t orig_size = size;
	struct scatterlist *sg;
	struct tls_rec *rec;
	int num_async = 0;
	bool full_record;
	int record_room;
	int ret = 0;
	bool eor;

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST))
		return -ENOTSUPP;

	/* No MSG_EOR from splice, only look at MSG_MORE */
	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));

	lock_sock(sk);

	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	/* Wait till there is any pending write on socket */
	if (unlikely(sk->sk_write_pending)) {
		ret = wait_on_pending_writer(sk, &timeo);
		if (unlikely(ret))
			goto sendpage_end;
	}

	/* Call the sk_stream functions to manage the sndbuf mem. */
	while (size > 0) {
		size_t copy, required_size;

		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto sendpage_end;
		}

		rec = get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto sendpage_end;
		}

		full_record = false;
		record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
		copy = size;
		if (copy >= record_room) {
			copy = record_room;
			full_record = true;
		}
		required_size = rec->sg_plaintext_size + copy +
				tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_payload:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit. The shortfall is measured
			 * against the ciphertext size, as in tls_sw_sendmsg().
			 * Measuring it against the plaintext size would
			 * subtract copy + overhead and wrap the unsigned copy.
			 */
			copy -= required_size - rec->sg_encrypted_size;
			full_record = true;
		}

		get_page(page);
		sg = &rec->sg_plaintext_data[1] + rec->sg_plaintext_num_elem;
		sg_set_page(sg, page, copy, offset);
		sg_unmark_end(sg);

		rec->sg_plaintext_num_elem++;

		sk_mem_charge(sk, copy);
		offset += copy;
		size -= copy;
		rec->sg_plaintext_size += copy;
		tls_ctx->pending_open_record_frags = rec->sg_plaintext_num_elem;

		if (full_record || eor ||
		    rec->sg_plaintext_num_elem ==
		    ARRAY_SIZE(rec->sg_plaintext_data) - 1) {
			rec->inplace_crypto = 0;
			ret = tls_push_record(sk, flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto sendpage_end;
			}
		}
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
			trim_both_sgl(sk, rec->sg_plaintext_size);
			goto sendpage_end;
		}

		goto alloc_payload;
	}

	if (num_async) {
		/* Transmit if any encryptions have completed */
		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
			cancel_delayed_work(&ctx->tx_work.work);
			tls_tx_records(sk, flags);
		}
	}
sendpage_end:
	if (orig_size > size)
		ret = orig_size - size;
	else
		ret = sk_stream_error(sk, flags, ret);

	release_sock(sk);
	return ret;
}

D
Dave Watson 已提交
1054 1055 1056 1057
static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
				     long timeo, int *err)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1058
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1059 1060 1061 1062 1063 1064 1065 1066 1067
	struct sk_buff *skb;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	while (!(skb = ctx->recv_pkt)) {
		if (sk->sk_err) {
			*err = sock_error(sk);
			return NULL;
		}

1068 1069 1070
		if (sk->sk_shutdown & RCV_SHUTDOWN)
			return NULL;

D
Dave Watson 已提交
1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094
		if (sock_flag(sk, SOCK_DONE))
			return NULL;

		if ((flags & MSG_DONTWAIT) || !timeo) {
			*err = -EAGAIN;
			return NULL;
		}

		add_wait_queue(sk_sleep(sk), &wait);
		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		remove_wait_queue(sk_sleep(sk), &wait);

		/* Handle signals */
		if (signal_pending(current)) {
			*err = sock_intr_errno(timeo);
			return NULL;
		}
	}

	return skb;
}

1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123
/* This function decrypts the input skb into either out_iov or in out_sg
 * or in skb buffers itself. The input parameter 'zc' indicates if
 * zero-copy mode needs to be tried or not. With zero-copy mode, either
 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
 * NULL, then the decryption happens inside skb buffers itself, i.e.
 * zero-copy gets disabled and 'zc' is updated.
 */

static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			    struct iov_iter *out_iov,
			    struct scatterlist *out_sg,
			    int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
	struct aead_request *aead_req;
	struct sk_buff *unused;
	u8 *aad, *iv, *mem = NULL;
	struct scatterlist *sgin = NULL;
	struct scatterlist *sgout = NULL;
	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;

	if (*zc && (out_iov || out_sg)) {
		if (out_iov)
			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
		else
			n_sgout = sg_nents(out_sg);
1124 1125
		n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
				 rxm->full_len - tls_ctx->rx.prepend_size);
1126 1127 1128
	} else {
		n_sgout = 0;
		*zc = false;
1129
		n_sgin = skb_cow_data(skb, 0, &unused);
1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210
	}

	if (n_sgin < 1)
		return -EBADMSG;

	/* Increment to accommodate AAD */
	n_sgin = n_sgin + 1;

	nsg = n_sgin + n_sgout;

	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);

	/* Allocate a single block of memory which contains
	 * aead_req || sgin[] || sgout[] || aad || iv.
	 * This order achieves correct alignment for aead_req, sgin, sgout.
	 */
	mem = kmalloc(mem_size, sk->sk_allocation);
	if (!mem)
		return -ENOMEM;

	/* Segment the allocated memory */
	aead_req = (struct aead_request *)mem;
	sgin = (struct scatterlist *)(mem + aead_size);
	sgout = sgin + n_sgin;
	aad = (u8 *)(sgout + n_sgout);
	iv = aad + TLS_AAD_SPACE_SIZE;

	/* Prepare IV */
	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			    tls_ctx->rx.iv_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}
	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

	/* Prepare AAD */
	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
		     ctx->control);

	/* Prepare sgin */
	sg_init_table(sgin, n_sgin);
	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
	err = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + tls_ctx->rx.prepend_size,
			   rxm->full_len - tls_ctx->rx.prepend_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}

	if (n_sgout) {
		if (out_iov) {
			sg_init_table(sgout, n_sgout);
			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);

			*chunk = 0;
			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
						 chunk, &sgout[1],
						 (n_sgout - 1), false);
			if (err < 0)
				goto fallback_to_reg_recv;
		} else if (out_sg) {
			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
		} else {
			goto fallback_to_reg_recv;
		}
	} else {
fallback_to_reg_recv:
		sgout = sgin;
		pages = 0;
		*chunk = 0;
		*zc = false;
	}

	/* Prepare and submit AEAD request */
1211 1212 1213 1214
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
				data_len, aead_req, *zc);
	if (err == -EINPROGRESS)
		return err;
1215 1216 1217 1218 1219 1220 1221 1222 1223

	/* Release the pages in case iov was mapped to pages */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));

	kfree(mem);
	return err;
}

1224
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
1225
			      struct iov_iter *dest, int *chunk, bool *zc)
1226 1227 1228 1229 1230 1231
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int err = 0;

1232 1233
#ifdef CONFIG_TLS_DEVICE
	err = tls_device_decrypted(sk, skb);
1234 1235
	if (err < 0)
		return err;
1236 1237
#endif
	if (!ctx->decrypted) {
1238
		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
1239 1240 1241 1242
		if (err < 0) {
			if (err == -EINPROGRESS)
				tls_advance_record_sn(sk, &tls_ctx->rx);

1243
			return err;
1244
		}
1245 1246 1247
	} else {
		*zc = false;
	}
1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259

	rxm->offset += tls_ctx->rx.prepend_size;
	rxm->full_len -= tls_ctx->rx.overhead_size;
	tls_advance_record_sn(sk, &tls_ctx->rx);
	ctx->decrypted = true;
	ctx->saved_data_ready(sk);

	return err;
}

int decrypt_skb(struct sock *sk, struct sk_buff *skb,
		struct scatterlist *sgout)
D
Dave Watson 已提交
1260
{
1261 1262
	bool zc = true;
	int chunk;
D
Dave Watson 已提交
1263

1264
	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
D
Dave Watson 已提交
1265 1266 1267 1268 1269 1270
}

static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
			       unsigned int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1271
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1272

1273 1274
	if (skb) {
		struct strp_msg *rxm = strp_msg(skb);
D
Dave Watson 已提交
1275

1276 1277 1278 1279 1280 1281
		if (len < rxm->full_len) {
			rxm->offset += len;
			rxm->full_len -= len;
			return false;
		}
		kfree_skb(skb);
D
Dave Watson 已提交
1282 1283 1284 1285
	}

	/* Finished with message */
	ctx->recv_pkt = NULL;
1286
	__strp_unpause(&ctx->strp);
D
Dave Watson 已提交
1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298

	return true;
}

/* recvmsg() for a software-TLS socket.
 *
 * The record type of the first record is reported once as a
 * SOL_TLS/TLS_GET_RECORD_TYPE cmsg. For non-data records a failed or
 * truncated cmsg is an error (-EIO). A record of a different type than the
 * first one ends the call. MSG_EOR is set whenever a record is fully
 * consumed, and the call returns after a fully consumed non-data record.
 *
 * Records that fit in a non-kvec user buffer are decrypted straight into
 * it (zero-copy) when not peeking. Such decryptions may complete
 * asynchronously, and all of them are waited for before returning.
 *
 * Returns the number of bytes copied, or a negative errno if nothing was
 * copied.
 */
int tls_sw_recvmsg(struct sock *sk,
		   struct msghdr *msg,
		   size_t len,
		   int nonblock,
		   int flags,
		   int *addr_len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	unsigned char control;
	struct strp_msg *rxm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool cmsg = false;
	int target, err = 0;
	long timeo;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
	int num_async = 0;

	flags |= nonblock;

	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

	lock_sock(sk);

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
	do {
		bool zc = false;
		bool async = false;
		int chunk = 0;

		skb = tls_wait_data(sk, flags, timeo, &err);
		if (!skb)
			goto recv_end;

		rxm = strp_msg(skb);

		if (!cmsg) {
			int cerr;

			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					sizeof(ctx->control), &ctx->control);
			cmsg = true;
			control = ctx->control;
			if (ctx->control != TLS_RECORD_TYPE_DATA) {
				if (cerr || msg->msg_flags & MSG_CTRUNC) {
					err = -EIO;
					goto recv_end;
				}
			}
		} else if (control != ctx->control) {
			/* Never mix record types within one call */
			goto recv_end;
		}

		if (!ctx->decrypted) {
			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;

			/* Zero-copy only if the whole record fits in the
			 * user (non-kvec) buffer and we are not peeking.
			 */
			if (!is_kvec && to_copy <= len &&
			    likely(!(flags & MSG_PEEK)))
				zc = true;

			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
						 &chunk, &zc);
			if (err < 0 && err != -EINPROGRESS) {
				tls_err_abort(sk, EBADMSG);
				goto recv_end;
			}

			if (err == -EINPROGRESS) {
				async = true;
				num_async++;
				goto pick_next_record;
			}

			ctx->decrypted = true;
		}

		if (!zc) {
			chunk = min_t(unsigned int, rxm->full_len, len);

			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
						    chunk);
			if (err < 0)
				goto recv_end;
		}

pick_next_record:
		copied += chunk;
		len -= chunk;
		if (likely(!(flags & MSG_PEEK))) {
			/* Snapshot the type before advancing: unpausing the
			 * strparser can let it parse the next record header.
			 */
			u8 rec_type = ctx->control;

			/* For async, drop current skb reference */
			if (async)
				skb = NULL;

			if (tls_sw_advance_skb(sk, skb, chunk)) {
				/* Return full control message to
				 * userspace before trying to parse
				 * another message type
				 */
				msg->msg_flags |= MSG_EOR;
				if (rec_type != TLS_RECORD_TYPE_DATA)
					goto recv_end;
			} else {
				break;
			}
		} else {
			/* MSG_PEEK right now cannot look beyond current skb
			 * from strparser, meaning we cannot advance skb here
			 * and thus unpause strparser since we'd lose the
			 * original one.
			 */
			break;
		}

		/* Stop once the low-water mark is met and no further record
		 * is already queued by the strparser.
		 */
		if (copied >= target && !ctx->recv_pkt)
			break;
	} while (len);

recv_end:
	if (num_async) {
		/* Wait for all previously submitted records to be decrypted */
		smp_store_mb(ctx->async_notify, true);
		if (atomic_read(&ctx->decrypt_pending)) {
			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			if (err) {
				/* one of async decrypt failed */
				tls_err_abort(sk, err);
				copied = 0;
			}
		} else {
			reinit_completion(&ctx->async_wait.completion);
		}
		WRITE_ONCE(ctx->async_notify, false);
	}

	release_sock(sk);
	return copied ? : err;
}

/* splice_read() for a software-TLS socket.
 *
 * Splices up to @len bytes of the current decrypted data record into
 * @pipe and consumes them from the record. Control records are rejected
 * with -ENOTSUPP.
 *
 * @flags carries SPLICE_F_* bits, not MSG_* bits. SPLICE_F_NONBLOCK
 * requests a non-blocking wait. Splice has no peek mode, so spliced bytes
 * are always consumed.
 *
 * Returns the number of bytes spliced, or a negative errno.
 */
ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
			   struct pipe_inode_info *pipe,
			   size_t len, unsigned int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	bool nonblock = flags & SPLICE_F_NONBLOCK;
	struct strp_msg *rxm = NULL;
	struct sock *sk = sock->sk;
	struct sk_buff *skb;
	ssize_t copied = 0;
	int err = 0;
	long timeo;
	int chunk;
	bool zc = false;

	lock_sock(sk);

	/* Translate the splice non-blocking request into the MSG_DONTWAIT
	 * semantics tls_wait_data() understands. Testing MSG_* bits directly
	 * against splice flags is wrong: SPLICE_F_NONBLOCK has the same value
	 * as MSG_PEEK, so non-blocking splices would never consume the record.
	 */
	timeo = sock_rcvtimeo(sk, nonblock);

	skb = tls_wait_data(sk, nonblock ? MSG_DONTWAIT : 0, timeo, &err);
	if (!skb)
		goto splice_read_end;

	/* splice does not support reading control messages */
	if (ctx->control != TLS_RECORD_TYPE_DATA) {
		err = -ENOTSUPP;
		goto splice_read_end;
	}

	if (!ctx->decrypted) {
		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);

		if (err < 0) {
			tls_err_abort(sk, EBADMSG);
			goto splice_read_end;
		}
		ctx->decrypted = true;
	}
	rxm = strp_msg(skb);

	chunk = min_t(unsigned int, rxm->full_len, len);
	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
	if (copied < 0)
		goto splice_read_end;

	tls_sw_advance_skb(sk, skb, copied);

splice_read_end:
	release_sock(sk);
	return copied ? : err;
}

1488 1489
unsigned int tls_sw_poll(struct file *file, struct socket *sock,
			 struct poll_table_struct *wait)
D
Dave Watson 已提交
1490
{
1491
	unsigned int ret;
D
Dave Watson 已提交
1492 1493
	struct sock *sk = sock->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1494
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1495

1496 1497
	/* Grab POLLOUT and POLLHUP from the underlying socket */
	ret = ctx->sk_poll(file, sock, wait);
D
Dave Watson 已提交
1498

1499 1500
	/* Clear POLLIN bits, and set based on recv_pkt */
	ret &= ~(POLLIN | POLLRDNORM);
D
Dave Watson 已提交
1501
	if (ctx->recv_pkt)
1502
		ret |= POLLIN | POLLRDNORM;
D
Dave Watson 已提交
1503

1504
	return ret;
D
Dave Watson 已提交
1505 1506 1507 1508 1509
}

static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
B
Boris Pismenny 已提交
1510
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
K
Kees Cook 已提交
1511
	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
D
Dave Watson 已提交
1512 1513 1514 1515 1516 1517 1518 1519 1520
	struct strp_msg *rxm = strp_msg(skb);
	size_t cipher_overhead;
	size_t data_len = 0;
	int ret;

	/* Verify that we have a full TLS header, or wait for more data */
	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
		return 0;

K
Kees Cook 已提交
1521 1522 1523 1524 1525 1526
	/* Sanity-check size of on-stack buffer. */
	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
		ret = -EINVAL;
		goto read_failure;
	}

D
Dave Watson 已提交
1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547
	/* Linearize header to local buffer */
	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);

	if (ret < 0)
		goto read_failure;

	ctx->control = header[0];

	data_len = ((header[4] & 0xFF) | (header[3] << 8));

	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;

	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
		ret = -EMSGSIZE;
		goto read_failure;
	}
	if (data_len < cipher_overhead) {
		ret = -EBADMSG;
		goto read_failure;
	}

1548 1549
	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
D
Dave Watson 已提交
1550 1551 1552 1553
		ret = -EINVAL;
		goto read_failure;
	}

1554 1555 1556 1557
#ifdef CONFIG_TLS_DEVICE
	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
			     *(u64*)tls_ctx->rx.rec_seq);
#endif
D
Dave Watson 已提交
1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568
	return data_len + TLS_HEADER_SIZE;

read_failure:
	tls_err_abort(strp->sk, ret);

	return ret;
}

static void tls_queue(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
B
Boris Pismenny 已提交
1569
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1570 1571 1572 1573 1574 1575

	ctx->decrypted = false;

	ctx->recv_pkt = skb;
	strp_pause(strp);

1576
	ctx->saved_data_ready(strp->sk);
D
Dave Watson 已提交
1577 1578 1579 1580 1581
}

static void tls_data_ready(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1582
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
D
Dave Watson 已提交
1583 1584 1585 1586

	strp_data_ready(&ctx->strp);
}

B
Boris Pismenny 已提交
1587
void tls_sw_free_resources_tx(struct sock *sk)
D
Dave Watson 已提交
1588 1589
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
B
Boris Pismenny 已提交
1590
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602
	struct tls_rec *rec, *tmp;

	/* Wait for any pending async encryptions to complete */
	smp_store_mb(ctx->async_notify, true);
	if (atomic_read(&ctx->encrypt_pending))
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);

	cancel_delayed_work_sync(&ctx->tx_work.work);

	/* Tx whatever records we can transmit and abandon the rest */
	tls_tx_records(sk, -1);

1603
	/* Free up un-sent records in tx_list. First, free
1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619
	 * the partially sent record if any at head of tx_list.
	 */
	if (tls_ctx->partially_sent_record) {
		struct scatterlist *sg = tls_ctx->partially_sent_record;

		while (1) {
			put_page(sg_page(sg));
			sk_mem_uncharge(sk, sg->length);

			if (sg_is_last(sg))
				break;
			sg++;
		}

		tls_ctx->partially_sent_record = NULL;

1620
		rec = list_first_entry(&ctx->tx_list,
1621
				       struct tls_rec, list);
1622

1623
		free_sg(sk, &rec->sg_plaintext_data[1],
1624 1625 1626
			&rec->sg_plaintext_num_elem,
			&rec->sg_plaintext_size);

1627 1628 1629 1630
		list_del(&rec->list);
		kfree(rec);
	}

1631
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
1632
		free_sg(sk, &rec->sg_encrypted_data[1],
1633 1634 1635
			&rec->sg_encrypted_num_elem,
			&rec->sg_encrypted_size);

1636
		free_sg(sk, &rec->sg_plaintext_data[1],
1637 1638 1639
			&rec->sg_plaintext_num_elem,
			&rec->sg_plaintext_size);

1640 1641 1642
		list_del(&rec->list);
		kfree(rec);
	}
D
Dave Watson 已提交
1643

1644
	crypto_free_aead(ctx->aead_send);
1645
	tls_free_open_rec(sk);
B
Boris Pismenny 已提交
1646 1647 1648 1649

	kfree(ctx);
}

1650
void tls_sw_release_resources_rx(struct sock *sk)
B
Boris Pismenny 已提交
1651 1652 1653 1654
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

D
Dave Watson 已提交
1655
	if (ctx->aead_recv) {
1656 1657
		kfree_skb(ctx->recv_pkt);
		ctx->recv_pkt = NULL;
D
Dave Watson 已提交
1658 1659 1660 1661 1662 1663 1664 1665 1666
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		write_lock_bh(&sk->sk_callback_lock);
		sk->sk_data_ready = ctx->saved_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);
		release_sock(sk);
		strp_done(&ctx->strp);
		lock_sock(sk);
	}
1667 1668 1669 1670 1671 1672 1673 1674
}

/* Release software RX state for @sk and free its RX context. */
void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_get_ctx(sk));

	tls_sw_release_resources_rx(sk);
	kfree(rx_ctx);
}

1679
/* The work handler to transmitt the encrypted records in tx_list */
1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696
static void tx_work_handler(struct work_struct *work)
{
	struct delayed_work *delayed_work = to_delayed_work(work);
	struct tx_work *tx_work = container_of(delayed_work,
					       struct tx_work, work);
	struct sock *sk = tx_work->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		return;

	lock_sock(sk);
	tls_tx_records(sk, -1);
	release_sock(sk);
}

D
Dave Watson 已提交
1697
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
D
Dave Watson 已提交
1698 1699 1700
{
	struct tls_crypto_info *crypto_info;
	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
B
Boris Pismenny 已提交
1701 1702
	struct tls_sw_context_tx *sw_ctx_tx = NULL;
	struct tls_sw_context_rx *sw_ctx_rx = NULL;
D
Dave Watson 已提交
1703 1704 1705
	struct cipher_context *cctx;
	struct crypto_aead **aead;
	struct strp_callbacks cb;
D
Dave Watson 已提交
1706 1707 1708 1709 1710 1711 1712 1713 1714
	u16 nonce_size, tag_size, iv_size, rec_seq_size;
	char *iv, *rec_seq;
	int rc = 0;

	if (!ctx) {
		rc = -EINVAL;
		goto out;
	}

B
Boris Pismenny 已提交
1715
	if (tx) {
1716 1717 1718 1719 1720 1721 1722 1723 1724 1725
		if (!ctx->priv_ctx_tx) {
			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
			if (!sw_ctx_tx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_tx = sw_ctx_tx;
		} else {
			sw_ctx_tx =
				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
D
Dave Watson 已提交
1726 1727
		}
	} else {
1728 1729 1730 1731 1732 1733 1734 1735 1736 1737
		if (!ctx->priv_ctx_rx) {
			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
			if (!sw_ctx_rx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_rx = sw_ctx_rx;
		} else {
			sw_ctx_rx =
				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
B
Boris Pismenny 已提交
1738
		}
D
Dave Watson 已提交
1739 1740
	}

D
Dave Watson 已提交
1741
	if (tx) {
1742
		crypto_init_wait(&sw_ctx_tx->async_wait);
1743
		crypto_info = &ctx->crypto_send.info;
D
Dave Watson 已提交
1744
		cctx = &ctx->tx;
B
Boris Pismenny 已提交
1745
		aead = &sw_ctx_tx->aead_send;
1746
		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
1747 1748
		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
		sw_ctx_tx->tx_work.sk = sk;
D
Dave Watson 已提交
1749
	} else {
1750
		crypto_init_wait(&sw_ctx_rx->async_wait);
1751
		crypto_info = &ctx->crypto_recv.info;
D
Dave Watson 已提交
1752
		cctx = &ctx->rx;
B
Boris Pismenny 已提交
1753
		aead = &sw_ctx_rx->aead_recv;
D
Dave Watson 已提交
1754 1755
	}

D
Dave Watson 已提交
1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
		rec_seq =
		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
		gcm_128_info =
			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
		break;
	}
	default:
		rc = -EINVAL;
S
Sabrina Dubroca 已提交
1771
		goto free_priv;
D
Dave Watson 已提交
1772 1773
	}

K
Kees Cook 已提交
1774
	/* Sanity-check the IV size for stack allocations. */
K
Kees Cook 已提交
1775
	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
K
Kees Cook 已提交
1776 1777 1778 1779
		rc = -EINVAL;
		goto free_priv;
	}

D
Dave Watson 已提交
1780 1781 1782 1783 1784 1785 1786
	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
	cctx->tag_size = tag_size;
	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
	cctx->iv_size = iv_size;
	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			   GFP_KERNEL);
	if (!cctx->iv) {
D
Dave Watson 已提交
1787
		rc = -ENOMEM;
S
Sabrina Dubroca 已提交
1788
		goto free_priv;
D
Dave Watson 已提交
1789
	}
D
Dave Watson 已提交
1790 1791 1792
	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
	cctx->rec_seq_size = rec_seq_size;
1793
	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
D
Dave Watson 已提交
1794
	if (!cctx->rec_seq) {
D
Dave Watson 已提交
1795 1796 1797
		rc = -ENOMEM;
		goto free_iv;
	}
D
Dave Watson 已提交
1798 1799 1800 1801 1802 1803

	if (!*aead) {
		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
		if (IS_ERR(*aead)) {
			rc = PTR_ERR(*aead);
			*aead = NULL;
D
Dave Watson 已提交
1804 1805 1806 1807 1808 1809
			goto free_rec_seq;
		}
	}

	ctx->push_pending_record = tls_sw_push_pending_record;

1810
	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
D
Dave Watson 已提交
1811 1812 1813 1814
				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	if (rc)
		goto free_aead;

D
Dave Watson 已提交
1815 1816 1817 1818
	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
	if (rc)
		goto free_aead;

B
Boris Pismenny 已提交
1819
	if (sw_ctx_rx) {
D
Dave Watson 已提交
1820 1821 1822 1823 1824
		/* Set up strparser */
		memset(&cb, 0, sizeof(cb));
		cb.rcv_msg = tls_queue;
		cb.parse_msg = tls_read_size;

B
Boris Pismenny 已提交
1825
		strp_init(&sw_ctx_rx->strp, sk, &cb);
D
Dave Watson 已提交
1826 1827

		write_lock_bh(&sk->sk_callback_lock);
B
Boris Pismenny 已提交
1828
		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
D
Dave Watson 已提交
1829 1830 1831
		sk->sk_data_ready = tls_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);

1832
		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
D
Dave Watson 已提交
1833

B
Boris Pismenny 已提交
1834
		strp_check_rcv(&sw_ctx_rx->strp);
D
Dave Watson 已提交
1835 1836 1837
	}

	goto out;
D
Dave Watson 已提交
1838 1839

free_aead:
D
Dave Watson 已提交
1840 1841
	crypto_free_aead(*aead);
	*aead = NULL;
D
Dave Watson 已提交
1842
free_rec_seq:
D
Dave Watson 已提交
1843 1844
	kfree(cctx->rec_seq);
	cctx->rec_seq = NULL;
D
Dave Watson 已提交
1845
free_iv:
B
Boris Pismenny 已提交
1846 1847
	kfree(cctx->iv);
	cctx->iv = NULL;
S
Sabrina Dubroca 已提交
1848
free_priv:
B
Boris Pismenny 已提交
1849 1850 1851 1852 1853 1854 1855
	if (tx) {
		kfree(ctx->priv_ctx_tx);
		ctx->priv_ctx_tx = NULL;
	} else {
		kfree(ctx->priv_ctx_rx);
		ctx->priv_ctx_rx = NULL;
	}
D
Dave Watson 已提交
1856 1857 1858
out:
	return rc;
}