diff --git a/include/net/tls.h b/include/net/tls.h index b89d397dd62fc7268a792b7f6d58b070db8cee0a..70becd0a92996e7e7a30121af50a847e019ca186 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -83,6 +83,8 @@ struct tls_context { void *priv_ctx; + u8 tx_conf:2; + u16 prepend_size; u16 tag_size; u16 overhead_size; @@ -97,7 +99,6 @@ struct tls_context { u16 pending_open_record_frags; int (*push_pending_record)(struct sock *sk, int flags); - void (*free_resources)(struct sock *sk); void (*sk_write_space)(struct sock *sk); void (*sk_proto_close)(struct sock *sk, long timeout); @@ -122,6 +123,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags); void tls_sw_close(struct sock *sk, long timeout); +void tls_sw_free_tx_resources(struct sock *sk); void tls_sk_destruct(struct sock *sk, struct tls_context *ctx); void tls_icsk_clean_acked(struct sock *sk); @@ -212,6 +214,21 @@ static inline void tls_fill_prepend(struct tls_context *ctx, ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv_size); } +static inline void tls_make_aad(char *buf, + size_t size, + char *record_sequence, + int record_sequence_size, + unsigned char record_type) +{ + memcpy(buf, record_sequence, record_sequence_size); + + buf[8] = record_type; + buf[9] = TLS_1_2_VERSION_MAJOR; + buf[10] = TLS_1_2_VERSION_MINOR; + buf[11] = size >> 8; + buf[12] = size & 0xFF; +} + static inline struct tls_context *tls_get_ctx(const struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 60aff60e30ad41380b58fd83b1b79ae5a876513e..e07ee3ae002300932f8fd41a907cee2fc042738f 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -45,8 +45,18 @@ MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); MODULE_LICENSE("Dual BSD/GPL"); -static struct proto tls_base_prot; -static struct proto tls_sw_prot; +enum { + TLS_BASE_TX, + TLS_SW_TX, + TLS_NUM_CONFIG, +}; + +static struct proto tls_prots[TLS_NUM_CONFIG]; + +static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) +{ + sk->sk_prot = &tls_prots[ctx->tx_conf]; +} int wait_on_pending_writer(struct sock *sk, long *timeo) { @@ -216,6 +226,12 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) void (*sk_proto_close)(struct sock *sk, long timeout); lock_sock(sk); + sk_proto_close = ctx->sk_proto_close; + + if (ctx->tx_conf == TLS_BASE_TX) { + kfree(ctx); + goto skip_tx_cleanup; + } if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) tls_handle_open_record(sk, 0); @@ -232,13 +248,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) sg++; } } - ctx->free_resources(sk); + kfree(ctx->rec_seq); kfree(ctx->iv); - sk_proto_close = ctx->sk_proto_close; - kfree(ctx); + if (ctx->tx_conf == TLS_SW_TX) + tls_sw_free_tx_resources(sk); +skip_tx_cleanup: release_sock(sk); sk_proto_close(sk, timeout); } @@ -338,46 +355,41 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, unsigned int optlen) { - struct tls_crypto_info *crypto_info, tmp_crypto_info; + struct tls_crypto_info *crypto_info; struct tls_context *ctx = tls_get_ctx(sk); - struct proto *prot = NULL; int rc = 0; + int tx_conf; if (!optval || (optlen < sizeof(*crypto_info))) { rc = -EINVAL; goto out; } - rc = copy_from_user(&tmp_crypto_info, optval, sizeof(*crypto_info)); + crypto_info = &ctx->crypto_send; + /* Currently we don't support set crypto info more than one time */ + if (TLS_CRYPTO_INFO_READY(crypto_info)) + goto out; + + rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info)); if (rc) { rc = -EFAULT; goto out; } /* check version */ - if (tmp_crypto_info.version != TLS_1_2_VERSION) { + if (crypto_info->version != TLS_1_2_VERSION) { rc = -ENOTSUPP; - goto out; + goto err_crypto_info; } - /* get user crypto info */ - crypto_info = &ctx->crypto_send; - - /* Currently we don't support set crypto info more than one time */ - if (TLS_CRYPTO_INFO_READY(crypto_info)) - goto out; - - switch (tmp_crypto_info.cipher_type) { + switch (crypto_info->cipher_type) { case TLS_CIPHER_AES_GCM_128: { if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) { rc = -EINVAL; goto out; } - rc = copy_from_user( - crypto_info, - optval, - sizeof(struct tls12_crypto_info_aes_gcm_128)); - + rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info), + optlen - sizeof(*crypto_info)); if (rc) { rc = -EFAULT; goto err_crypto_info; @@ -389,18 +401,16 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, goto out; } - ctx->sk_write_space = sk->sk_write_space; - sk->sk_write_space = tls_write_space; - - ctx->sk_proto_close = sk->sk_prot->close; - /* currently SW is default, we will have ethtool in future */ rc = tls_set_sw_offload(sk, ctx); - prot = &tls_sw_prot; + tx_conf = TLS_SW_TX; if (rc) goto err_crypto_info; - sk->sk_prot = prot; + ctx->tx_conf = tx_conf; + update_sk_prot(sk, ctx); + ctx->sk_write_space = sk->sk_write_space; + sk->sk_write_space = tls_write_space; goto out; err_crypto_info: @@ -453,7 +463,10 @@ static int tls_init(struct sock *sk) icsk->icsk_ulp_data = ctx; ctx->setsockopt = sk->sk_prot->setsockopt; ctx->getsockopt = sk->sk_prot->getsockopt; - sk->sk_prot = &tls_base_prot; + ctx->sk_proto_close = sk->sk_prot->close; + + ctx->tx_conf = TLS_BASE_TX; + update_sk_prot(sk, ctx); out: return rc; } @@ -464,16 +477,21 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .init = tls_init, }; +static void build_protos(struct proto *prot, struct proto *base) +{ + prot[TLS_BASE_TX] = *base; + prot[TLS_BASE_TX].setsockopt = tls_setsockopt; + prot[TLS_BASE_TX].getsockopt = tls_getsockopt; + prot[TLS_BASE_TX].close = tls_sk_proto_close; + + prot[TLS_SW_TX] = prot[TLS_BASE_TX]; + prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; + prot[TLS_SW_TX].sendpage = tls_sw_sendpage; +} + static int __init tls_register(void) { - tls_base_prot = tcp_prot; - tls_base_prot.setsockopt = tls_setsockopt; - tls_base_prot.getsockopt = tls_getsockopt; - - tls_sw_prot = tls_base_prot; - tls_sw_prot.sendmsg = tls_sw_sendmsg; - tls_sw_prot.sendpage = tls_sw_sendpage; - tls_sw_prot.close = tls_sk_proto_close; + build_protos(tls_prots, &tcp_prot); tcp_register_ulp(&tcp_tls_ulp_ops); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 7d80040a37b6d54901c05964a2f8a9ab0851a2cb..73d19210dd497193ff545ee15305113e6090f731 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -39,22 +39,6 @@ #include -static inline void tls_make_aad(int recv, - char *buf, - size_t size, - char *record_sequence, - int record_sequence_size, - unsigned char record_type) -{ - memcpy(buf, record_sequence, record_sequence_size); - - buf[8] = record_type; - buf[9] = TLS_1_2_VERSION_MAJOR; - buf[10] = TLS_1_2_VERSION_MINOR; - buf[11] = size >> 8; - buf[12] = size & 0xFF; -} - static void trim_sg(struct sock *sk, struct scatterlist *sg, int *sg_num_elem, unsigned int *sg_size, int target_size) { @@ -219,7 +203,7 @@ static int tls_do_encryption(struct tls_context *tls_ctx, struct aead_request *aead_req; int rc; - aead_req = kmalloc(req_size, flags); + aead_req = kzalloc(req_size, flags); if (!aead_req) return -ENOMEM; @@ -249,7 +233,7 @@ static int tls_push_record(struct sock *sk, int flags, sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); - tls_make_aad(0, ctx->aad_space, ctx->sg_plaintext_size, + tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size, tls_ctx->rec_seq, tls_ctx->rec_seq_size, record_type); @@ -639,7 +623,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, return ret; } -static void tls_sw_free_resources(struct sock *sk) +void tls_sw_free_tx_resources(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); @@ -650,6 +634,7 @@ static void tls_sw_free_resources(struct sock *sk) tls_free_both_sg(sk); kfree(ctx); + kfree(tls_ctx); } int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) @@ -679,7 +664,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) } ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; - ctx->free_resources = tls_sw_free_resources; crypto_info = &ctx->crypto_send; switch (crypto_info->cipher_type) {