diff --git a/include/net/tls.h b/include/net/tls.h index 95a8c60b36bee88fd07b04870c03019fb1d4a066..8c56809eb384a2de8c4de46bd6ce97c9b943354a 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -116,6 +116,37 @@ struct tls_sw_context_rx { bool decrypted; }; +struct tls_record_info { + struct list_head list; + u32 end_seq; + int len; + int num_frags; + skb_frag_t frags[MAX_SKB_FRAGS]; +}; + +struct tls_offload_context { + struct crypto_aead *aead_send; + spinlock_t lock; /* protects records list */ + struct list_head records_list; + struct tls_record_info *open_record; + struct tls_record_info *retransmit_hint; + u64 hint_record_sn; + u64 unacked_record_sn; + + struct scatterlist sg_tx_data[MAX_SKB_FRAGS]; + void (*sk_destruct)(struct sock *sk); + u8 driver_state[]; + /* The TLS layer reserves room for driver specific state + * Currently the belief is that there is not enough + * driver specific state to justify another layer of indirection + */ +#define TLS_DRIVER_STATE_SIZE (max_t(size_t, 8, sizeof(void *))) +}; + +#define TLS_OFFLOAD_CONTEXT_SIZE \ + (ALIGN(sizeof(struct tls_offload_context), sizeof(void *)) + \ + TLS_DRIVER_STATE_SIZE) + enum { TLS_PENDING_CLOSED_RECORD }; @@ -195,9 +226,28 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, struct pipe_inode_info *pipe, size_t len, unsigned int flags); -void tls_sk_destruct(struct sock *sk, struct tls_context *ctx); -void tls_icsk_clean_acked(struct sock *sk); +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int tls_device_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +void tls_device_sk_destruct(struct sock *sk); +void tls_device_init(void); +void tls_device_cleanup(void); + +struct tls_record_info *tls_get_record(struct tls_offload_context *context, + u32 seq, u64 *p_record_sn); + +static inline bool tls_record_is_start_marker(struct tls_record_info *rec) +{ + return rec->len == 0; +} + +static inline u32 tls_record_start_seq(struct tls_record_info *rec) +{ + return rec->end_seq - rec->len; +} +void tls_sk_destruct(struct sock *sk, struct tls_context *ctx); int tls_push_sg(struct sock *sk, struct tls_context *ctx, struct scatterlist *sg, u16 first_offset, int flags); @@ -234,6 +284,13 @@ static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) return tls_ctx->pending_open_record_frags; } +static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk) +{ + return sk_fullsock(sk) && + /* matches smp_store_release in tls_set_device_offload */ + smp_load_acquire(&sk->sk_destruct) == &tls_device_sk_destruct; +} + static inline void tls_err_abort(struct sock *sk, int err) { sk->sk_err = err; @@ -329,4 +386,12 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, void tls_register_device(struct tls_device *device); void tls_unregister_device(struct tls_device *device); +struct sk_buff *tls_validate_xmit_skb(struct sock *sk, + struct net_device *dev, + struct sk_buff *skb); + +int tls_sw_fallback_init(struct sock *sk, + struct tls_offload_context *offload_ctx, + struct tls_crypto_info *crypto_info); + #endif /* _TLS_OFFLOAD_H */ diff --git a/net/tls/Kconfig b/net/tls/Kconfig index 89b8745a986f06c00bd35f7a4fc6fff47c25120a..73f05ece53d0c955df2bac6d043e1e354d8fb7dc 100644 --- a/net/tls/Kconfig +++ b/net/tls/Kconfig @@ -14,3 +14,13 @@ config TLS encryption handling of the TLS protocol to be done in-kernel. If unsure, say N. + +config TLS_DEVICE + bool "Transport Layer Security HW offload" + depends on TLS + select SOCK_VALIDATE_XMIT + default n + help + Enable kernel support for HW offload of the TLS protocol. + + If unsure, say N. diff --git a/net/tls/Makefile b/net/tls/Makefile index a930fd1c4f7b88651d5686355cbea8ba55942941..4d6b728a67d0c122a41c1c0d5d5840598ba56472 100644 --- a/net/tls/Makefile +++ b/net/tls/Makefile @@ -5,3 +5,5 @@ obj-$(CONFIG_TLS) += tls.o tls-y := tls_main.o tls_sw.o + +tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c new file mode 100644 index 0000000000000000000000000000000000000000..ac45d6224e416d0018c73de512763e2ec3996af9 --- /dev/null +++ b/net/tls/tls_device.c @@ -0,0 +1,764 @@ +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +/* device_offload_lock is used to synchronize tls_dev_add + * against NETDEV_DOWN notifications. + */ +static DECLARE_RWSEM(device_offload_lock); + +static void tls_device_gc_task(struct work_struct *work); + +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task); +static LIST_HEAD(tls_device_gc_list); +static LIST_HEAD(tls_device_list); +static DEFINE_SPINLOCK(tls_device_lock); + +static void tls_device_free_ctx(struct tls_context *ctx) +{ + struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx); + + kfree(offload_ctx); + kfree(ctx); +} + +static void tls_device_gc_task(struct work_struct *work) +{ + struct tls_context *ctx, *tmp; + unsigned long flags; + LIST_HEAD(gc_list); + + spin_lock_irqsave(&tls_device_lock, flags); + list_splice_init(&tls_device_gc_list, &gc_list); + spin_unlock_irqrestore(&tls_device_lock, flags); + + list_for_each_entry_safe(ctx, tmp, &gc_list, list) { + struct net_device *netdev = ctx->netdev; + + if (netdev) { + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + dev_put(netdev); + } + + list_del(&ctx->list); + tls_device_free_ctx(ctx); + } +} + +static void tls_device_queue_ctx_destruction(struct tls_context *ctx) +{ + unsigned long flags; + + spin_lock_irqsave(&tls_device_lock, flags); + list_move_tail(&ctx->list, &tls_device_gc_list); + + /* schedule_work inside the spinlock + * to make sure tls_device_down waits for that work. + */ + schedule_work(&tls_device_gc_work); + + spin_unlock_irqrestore(&tls_device_lock, flags); +} + +/* We assume that the socket is already connected */ +static struct net_device *get_netdev_for_sock(struct sock *sk) +{ + struct dst_entry *dst = sk_dst_get(sk); + struct net_device *netdev = NULL; + + if (likely(dst)) { + netdev = dst->dev; + dev_hold(netdev); + } + + dst_release(dst); + + return netdev; +} + +static void destroy_record(struct tls_record_info *record) +{ + int nr_frags = record->num_frags; + skb_frag_t *frag; + + while (nr_frags-- > 0) { + frag = &record->frags[nr_frags]; + __skb_frag_unref(frag); + } + kfree(record); +} + +static void delete_all_records(struct tls_offload_context *offload_ctx) +{ + struct tls_record_info *info, *temp; + + list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) { + list_del(&info->list); + destroy_record(info); + } + + offload_ctx->retransmit_hint = NULL; +} + +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_record_info *info, *temp; + struct tls_offload_context *ctx; + u64 deleted_records = 0; + unsigned long flags; + + if (!tls_ctx) + return; + + ctx = tls_offload_ctx(tls_ctx); + + spin_lock_irqsave(&ctx->lock, flags); + info = ctx->retransmit_hint; + if (info && !before(acked_seq, info->end_seq)) { + ctx->retransmit_hint = NULL; + list_del(&info->list); + destroy_record(info); + deleted_records++; + } + + list_for_each_entry_safe(info, temp, &ctx->records_list, list) { + if (before(acked_seq, info->end_seq)) + break; + list_del(&info->list); + + destroy_record(info); + deleted_records++; + } + + ctx->unacked_record_sn += deleted_records; + spin_unlock_irqrestore(&ctx->lock, flags); +} + +/* At this point, there should be no references on this + * socket and no in-flight SKBs associated with this + * socket, so it is safe to free all the resources. + */ +void tls_device_sk_destruct(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + + if (ctx->open_record) + destroy_record(ctx->open_record); + + delete_all_records(ctx); + crypto_free_aead(ctx->aead_send); + ctx->sk_destruct(sk); + clean_acked_data_disable(inet_csk(sk)); + + if (refcount_dec_and_test(&tls_ctx->refcount)) + tls_device_queue_ctx_destruction(tls_ctx); +} +EXPORT_SYMBOL(tls_device_sk_destruct); + +static void tls_append_frag(struct tls_record_info *record, + struct page_frag *pfrag, + int size) +{ + skb_frag_t *frag; + + frag = &record->frags[record->num_frags - 1]; + if (frag->page.p == pfrag->page && + frag->page_offset + frag->size == pfrag->offset) { + frag->size += size; + } else { + ++frag; + frag->page.p = pfrag->page; + frag->page_offset = pfrag->offset; + frag->size = size; + ++record->num_frags; + get_page(pfrag->page); + } + + pfrag->offset += size; + record->len += size; +} + +static int tls_push_record(struct sock *sk, + struct tls_context *ctx, + struct tls_offload_context *offload_ctx, + struct tls_record_info *record, + struct page_frag *pfrag, + int flags, + unsigned char record_type) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct page_frag dummy_tag_frag; + skb_frag_t *frag; + int i; + + /* fill prepend */ + frag = &record->frags[0]; + tls_fill_prepend(ctx, + skb_frag_address(frag), + record->len - ctx->tx.prepend_size, + record_type); + + /* HW doesn't care about the data in the tag, because it fills it. */ + dummy_tag_frag.page = skb_frag_page(frag); + dummy_tag_frag.offset = 0; + + tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size); + record->end_seq = tp->write_seq + record->len; + spin_lock_irq(&offload_ctx->lock); + list_add_tail(&record->list, &offload_ctx->records_list); + spin_unlock_irq(&offload_ctx->lock); + offload_ctx->open_record = NULL; + set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); + tls_advance_record_sn(sk, &ctx->tx); + + for (i = 0; i < record->num_frags; i++) { + frag = &record->frags[i]; + sg_unmark_end(&offload_ctx->sg_tx_data[i]); + sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag), + frag->size, frag->page_offset); + sk_mem_charge(sk, frag->size); + get_page(skb_frag_page(frag)); + } + sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]); + + /* all ready, send */ + return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags); +} + +static int tls_create_new_record(struct tls_offload_context *offload_ctx, + struct page_frag *pfrag, + size_t prepend_size) +{ + struct tls_record_info *record; + skb_frag_t *frag; + + record = kmalloc(sizeof(*record), GFP_KERNEL); + if (!record) + return -ENOMEM; + + frag = &record->frags[0]; + __skb_frag_set_page(frag, pfrag->page); + frag->page_offset = pfrag->offset; + skb_frag_size_set(frag, prepend_size); + + get_page(pfrag->page); + pfrag->offset += prepend_size; + + record->num_frags = 1; + record->len = prepend_size; + offload_ctx->open_record = record; + return 0; +} + +static int tls_do_allocation(struct sock *sk, + struct tls_offload_context *offload_ctx, + struct page_frag *pfrag, + size_t prepend_size) +{ + int ret; + + if (!offload_ctx->open_record) { + if (unlikely(!skb_page_frag_refill(prepend_size, pfrag, + sk->sk_allocation))) { + sk->sk_prot->enter_memory_pressure(sk); + sk_stream_moderate_sndbuf(sk); + return -ENOMEM; + } + + ret = tls_create_new_record(offload_ctx, pfrag, prepend_size); + if (ret) + return ret; + + if (pfrag->size > pfrag->offset) + return 0; + } + + if (!sk_page_frag_refill(sk, pfrag)) + return -ENOMEM; + + return 0; +} + +static int tls_push_data(struct sock *sk, + struct iov_iter *msg_iter, + size_t size, int flags, + unsigned char record_type) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; + int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); + struct tls_record_info *record = ctx->open_record; + struct page_frag *pfrag; + size_t orig_size = size; + u32 max_open_record_len; + int copy, rc = 0; + bool done = false; + long timeo; + + if (flags & + ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST)) + return -ENOTSUPP; + + if (sk->sk_err) + return -sk->sk_err; + + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); + if (rc < 0) + return rc; + + pfrag = sk_page_frag(sk); + + /* TLS_HEADER_SIZE is not counted as part of the TLS record, and + * we need to leave room for an authentication tag. + */ + max_open_record_len = TLS_MAX_PAYLOAD_SIZE + + tls_ctx->tx.prepend_size; + do { + rc = tls_do_allocation(sk, ctx, pfrag, + tls_ctx->tx.prepend_size); + if (rc) { + rc = sk_stream_wait_memory(sk, &timeo); + if (!rc) + continue; + + record = ctx->open_record; + if (!record) + break; +handle_error: + if (record_type != TLS_RECORD_TYPE_DATA) { + /* avoid sending partial + * record with type != + * application_data + */ + size = orig_size; + destroy_record(record); + ctx->open_record = NULL; + } else if (record->len > tls_ctx->tx.prepend_size) { + goto last_record; + } + + break; + } + + record = ctx->open_record; + copy = min_t(size_t, size, (pfrag->size - pfrag->offset)); + copy = min_t(size_t, copy, (max_open_record_len - record->len)); + + if (copy_from_iter_nocache(page_address(pfrag->page) + + pfrag->offset, + copy, msg_iter) != copy) { + rc = -EFAULT; + goto handle_error; + } + tls_append_frag(record, pfrag, copy); + + size -= copy; + if (!size) { +last_record: + tls_push_record_flags = flags; + if (more) { + tls_ctx->pending_open_record_frags = + record->num_frags; + break; + } + + done = true; + } + + if (done || record->len >= max_open_record_len || + (record->num_frags >= MAX_SKB_FRAGS - 1)) { + rc = tls_push_record(sk, + tls_ctx, + ctx, + record, + pfrag, + tls_push_record_flags, + record_type); + if (rc < 0) + break; + } + } while (!done); + + if (orig_size - size > 0) + rc = orig_size - size; + + return rc; +} + +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + unsigned char record_type = TLS_RECORD_TYPE_DATA; + int rc; + + lock_sock(sk); + + if (unlikely(msg->msg_controllen)) { + rc = tls_proccess_cmsg(sk, msg, &record_type); + if (rc) + goto out; + } + + rc = tls_push_data(sk, &msg->msg_iter, size, + msg->msg_flags, record_type); + +out: + release_sock(sk); + return rc; +} + +int tls_device_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + struct iov_iter msg_iter; + char *kaddr = kmap(page); + struct kvec iov; + int rc; + + if (flags & MSG_SENDPAGE_NOTLAST) + flags |= MSG_MORE; + + lock_sock(sk); + + if (flags & MSG_OOB) { + rc = -ENOTSUPP; + goto out; + } + + iov.iov_base = kaddr + offset; + iov.iov_len = size; + iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size); + rc = tls_push_data(sk, &msg_iter, size, + flags, TLS_RECORD_TYPE_DATA); + kunmap(page); + +out: + release_sock(sk); + return rc; +} + +struct tls_record_info *tls_get_record(struct tls_offload_context *context, + u32 seq, u64 *p_record_sn) +{ + u64 record_sn = context->hint_record_sn; + struct tls_record_info *info; + + info = context->retransmit_hint; + if (!info || + before(seq, info->end_seq - info->len)) { + /* if retransmit_hint is irrelevant start + * from the beggining of the list + */ + info = list_first_entry(&context->records_list, + struct tls_record_info, list); + record_sn = context->unacked_record_sn; + } + + list_for_each_entry_from(info, &context->records_list, list) { + if (before(seq, info->end_seq)) { + if (!context->retransmit_hint || + after(info->end_seq, + context->retransmit_hint->end_seq)) { + context->hint_record_sn = record_sn; + context->retransmit_hint = info; + } + *p_record_sn = record_sn; + return info; + } + record_sn++; + } + + return NULL; +} +EXPORT_SYMBOL(tls_get_record); + +static int tls_device_push_pending_record(struct sock *sk, int flags) +{ + struct iov_iter msg_iter; + + iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0); + return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); +} + +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +{ + u16 nonce_size, tag_size, iv_size, rec_seq_size; + struct tls_record_info *start_marker_record; + struct tls_offload_context *offload_ctx; + struct tls_crypto_info *crypto_info; + struct net_device *netdev; + char *iv, *rec_seq; + struct sk_buff *skb; + int rc = -EINVAL; + __be64 rcd_sn; + + if (!ctx) + goto out; + + if (ctx->priv_ctx_tx) { + rc = -EEXIST; + goto out; + } + + start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL); + if (!start_marker_record) { + rc = -ENOMEM; + goto out; + } + + offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL); + if (!offload_ctx) { + rc = -ENOMEM; + goto free_marker_record; + } + + crypto_info = &ctx->crypto_send; + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: + nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; + tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; + iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; + iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; + rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; + rec_seq = + ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; + break; + default: + rc = -EINVAL; + goto free_offload_ctx; + } + + ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size; + ctx->tx.tag_size = tag_size; + ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size; + ctx->tx.iv_size = iv_size; + ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + GFP_KERNEL); + if (!ctx->tx.iv) { + rc = -ENOMEM; + goto free_offload_ctx; + } + + memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); + + ctx->tx.rec_seq_size = rec_seq_size; + ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + if (!ctx->tx.rec_seq) { + rc = -ENOMEM; + goto free_iv; + } + memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size); + + rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info); + if (rc) + goto free_rec_seq; + + /* start at rec_seq - 1 to account for the start marker record */ + memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn)); + offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1; + + start_marker_record->end_seq = tcp_sk(sk)->write_seq; + start_marker_record->len = 0; + start_marker_record->num_frags = 0; + + INIT_LIST_HEAD(&offload_ctx->records_list); + list_add_tail(&start_marker_record->list, &offload_ctx->records_list); + spin_lock_init(&offload_ctx->lock); + + clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked); + ctx->push_pending_record = tls_device_push_pending_record; + offload_ctx->sk_destruct = sk->sk_destruct; + + /* TLS offload is greatly simplified if we don't send + * SKBs where only part of the payload needs to be encrypted. + * So mark the last skb in the write queue as end of record. + */ + skb = tcp_write_queue_tail(sk); + if (skb) + TCP_SKB_CB(skb)->eor = 1; + + refcount_set(&ctx->refcount, 1); + + /* We support starting offload on multiple sockets + * concurrently, so we only need a read lock here. + * This lock must precede get_netdev_for_sock to prevent races between + * NETDEV_DOWN and setsockopt. + */ + down_read(&device_offload_lock); + netdev = get_netdev_for_sock(sk); + if (!netdev) { + pr_err_ratelimited("%s: netdev not found\n", __func__); + rc = -EINVAL; + goto release_lock; + } + + if (!(netdev->features & NETIF_F_HW_TLS_TX)) { + rc = -ENOTSUPP; + goto release_netdev; + } + + /* Avoid offloading if the device is down + * We don't want to offload new flows after + * the NETDEV_DOWN event + */ + if (!(netdev->flags & IFF_UP)) { + rc = -EINVAL; + goto release_netdev; + } + + ctx->priv_ctx_tx = offload_ctx; + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX, + &ctx->crypto_send, + tcp_sk(sk)->write_seq); + if (rc) + goto release_netdev; + + ctx->netdev = netdev; + + spin_lock_irq(&tls_device_lock); + list_add_tail(&ctx->list, &tls_device_list); + spin_unlock_irq(&tls_device_lock); + + sk->sk_validate_xmit_skb = tls_validate_xmit_skb; + /* following this assignment tls_is_sk_tx_device_offloaded + * will return true and the context might be accessed + * by the netdev's xmit function. + */ + smp_store_release(&sk->sk_destruct, + &tls_device_sk_destruct); + up_read(&device_offload_lock); + goto out; + +release_netdev: + dev_put(netdev); +release_lock: + up_read(&device_offload_lock); + clean_acked_data_disable(inet_csk(sk)); + crypto_free_aead(offload_ctx->aead_send); +free_rec_seq: + kfree(ctx->tx.rec_seq); +free_iv: + kfree(ctx->tx.iv); +free_offload_ctx: + kfree(offload_ctx); + ctx->priv_ctx_tx = NULL; +free_marker_record: + kfree(start_marker_record); +out: + return rc; +} + +static int tls_device_down(struct net_device *netdev) +{ + struct tls_context *ctx, *tmp; + unsigned long flags; + LIST_HEAD(list); + + /* Request a write lock to block new offload attempts */ + down_write(&device_offload_lock); + + spin_lock_irqsave(&tls_device_lock, flags); + list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) { + if (ctx->netdev != netdev || + !refcount_inc_not_zero(&ctx->refcount)) + continue; + + list_move(&ctx->list, &list); + } + spin_unlock_irqrestore(&tls_device_lock, flags); + + list_for_each_entry_safe(ctx, tmp, &list, list) { + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + ctx->netdev = NULL; + dev_put(netdev); + list_del_init(&ctx->list); + + if (refcount_dec_and_test(&ctx->refcount)) + tls_device_free_ctx(ctx); + } + + up_write(&device_offload_lock); + + flush_work(&tls_device_gc_work); + + return NOTIFY_DONE; +} + +static int tls_dev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!(dev->features & NETIF_F_HW_TLS_TX)) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_REGISTER: + case NETDEV_FEAT_CHANGE: + if (dev->tlsdev_ops && + dev->tlsdev_ops->tls_dev_add && + dev->tlsdev_ops->tls_dev_del) + return NOTIFY_DONE; + else + return NOTIFY_BAD; + case NETDEV_DOWN: + return tls_device_down(dev); + } + return NOTIFY_DONE; +} + +static struct notifier_block tls_dev_notifier = { + .notifier_call = tls_dev_event, +}; + +void __init tls_device_init(void) +{ + register_netdevice_notifier(&tls_dev_notifier); +} + +void __exit tls_device_cleanup(void) +{ + unregister_netdevice_notifier(&tls_dev_notifier); + flush_work(&tls_device_gc_work); +} diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c new file mode 100644 index 0000000000000000000000000000000000000000..748914abdb604e6434db7d8326e8a8480567e1da --- /dev/null +++ b/net/tls/tls_device_fallback.c @@ -0,0 +1,450 @@ +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include +#include +#include + +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk) +{ + struct scatterlist *src = walk->sg; + int diff = walk->offset - src->offset; + + sg_set_page(sg, sg_page(src), + src->length - diff, walk->offset); + + scatterwalk_crypto_chain(sg, sg_next(src), 0, 2); +} + +static int tls_enc_record(struct aead_request *aead_req, + struct crypto_aead *aead, char *aad, + char *iv, __be64 rcd_sn, + struct scatter_walk *in, + struct scatter_walk *out, int *in_len) +{ + unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE]; + struct scatterlist sg_in[3]; + struct scatterlist sg_out[3]; + u16 len; + int rc; + + len = min_t(int, *in_len, ARRAY_SIZE(buf)); + + scatterwalk_copychunks(buf, in, len, 0); + scatterwalk_copychunks(buf, out, len, 1); + + *in_len -= len; + if (!*in_len) + return 0; + + scatterwalk_pagedone(in, 0, 1); + scatterwalk_pagedone(out, 1, 1); + + len = buf[4] | (buf[3] << 8); + len -= TLS_CIPHER_AES_GCM_128_IV_SIZE; + + tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE, + (char *)&rcd_sn, sizeof(rcd_sn), buf[0]); + + memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE, + TLS_CIPHER_AES_GCM_128_IV_SIZE); + + sg_init_table(sg_in, ARRAY_SIZE(sg_in)); + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); + sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE); + chain_to_walk(sg_in + 1, in); + chain_to_walk(sg_out + 1, out); + + *in_len -= len; + if (*in_len < 0) { + *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE; + /* the input buffer doesn't contain the entire record. + * trim len accordingly. The resulting authentication tag + * will contain garbage, but we don't care, so we won't + * include any of it in the output skb + * Note that we assume the output buffer length + * is larger then input buffer length + tag size + */ + if (*in_len < 0) + len += *in_len; + + *in_len = 0; + } + + if (*in_len) { + scatterwalk_copychunks(NULL, in, len, 2); + scatterwalk_pagedone(in, 0, 1); + scatterwalk_copychunks(NULL, out, len, 2); + scatterwalk_pagedone(out, 1, 1); + } + + len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE; + aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv); + + rc = crypto_aead_encrypt(aead_req); + + return rc; +} + +static void tls_init_aead_request(struct aead_request *aead_req, + struct crypto_aead *aead) +{ + aead_request_set_tfm(aead_req, aead); + aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); +} + +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, + gfp_t flags) +{ + unsigned int req_size = sizeof(struct aead_request) + + crypto_aead_reqsize(aead); + struct aead_request *aead_req; + + aead_req = kzalloc(req_size, flags); + if (aead_req) + tls_init_aead_request(aead_req, aead); + return aead_req; +} + +static int tls_enc_records(struct aead_request *aead_req, + struct crypto_aead *aead, struct scatterlist *sg_in, + struct scatterlist *sg_out, char *aad, char *iv, + u64 rcd_sn, int len) +{ + struct scatter_walk out, in; + int rc; + + scatterwalk_start(&in, sg_in); + scatterwalk_start(&out, sg_out); + + do { + rc = tls_enc_record(aead_req, aead, aad, iv, + cpu_to_be64(rcd_sn), &in, &out, &len); + rcd_sn++; + + } while (rc == 0 && len); + + scatterwalk_done(&in, 0, 0); + scatterwalk_done(&out, 1, 0); + + return rc; +} + +/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses + * might have been changed by NAT. + */ +static void update_chksum(struct sk_buff *skb, int headln) +{ + struct tcphdr *th = tcp_hdr(skb); + int datalen = skb->len - headln; + const struct ipv6hdr *ipv6h; + const struct iphdr *iph; + + /* We only changed the payload so if we are using partial we don't + * need to update anything. + */ + if (likely(skb->ip_summed == CHECKSUM_PARTIAL)) + return; + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct tcphdr, check); + + if (skb->sk->sk_family == AF_INET6) { + ipv6h = ipv6_hdr(skb); + th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + datalen, IPPROTO_TCP, 0); + } else { + iph = ip_hdr(skb); + th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, + IPPROTO_TCP, 0); + } +} + +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln) +{ + skb_copy_header(nskb, skb); + + skb_put(nskb, skb->len); + memcpy(nskb->data, skb->data, headln); + update_chksum(nskb, headln); + + nskb->destructor = skb->destructor; + nskb->sk = skb->sk; + skb->destructor = NULL; + skb->sk = NULL; + refcount_add(nskb->truesize - skb->truesize, + &nskb->sk->sk_wmem_alloc); +} + +/* This function may be called after the user socket is already + * closed so make sure we don't use anything freed during + * tls_sk_proto_close here + */ + +static int fill_sg_in(struct scatterlist *sg_in, + struct sk_buff *skb, + struct tls_offload_context *ctx, + u64 *rcd_sn, + s32 *sync_size, + int *resync_sgs) +{ + int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + int payload_len = skb->len - tcp_payload_offset; + u32 tcp_seq = ntohl(tcp_hdr(skb)->seq); + struct tls_record_info *record; + unsigned long flags; + int remaining; + int i; + + spin_lock_irqsave(&ctx->lock, flags); + record = tls_get_record(ctx, tcp_seq, rcd_sn); + if (!record) { + spin_unlock_irqrestore(&ctx->lock, flags); + WARN(1, "Record not found for seq %u\n", tcp_seq); + return -EINVAL; + } + + *sync_size = tcp_seq - tls_record_start_seq(record); + if (*sync_size < 0) { + int is_start_marker = tls_record_is_start_marker(record); + + spin_unlock_irqrestore(&ctx->lock, flags); + /* This should only occur if the relevant record was + * already acked. In that case it should be ok + * to drop the packet and avoid retransmission. + * + * There is a corner case where the packet contains + * both an acked and a non-acked record. + * We currently don't handle that case and rely + * on TCP to retranmit a packet that doesn't contain + * already acked payload. + */ + if (!is_start_marker) + *sync_size = 0; + return -EINVAL; + } + + remaining = *sync_size; + for (i = 0; remaining > 0; i++) { + skb_frag_t *frag = &record->frags[i]; + + __skb_frag_ref(frag); + sg_set_page(sg_in + i, skb_frag_page(frag), + skb_frag_size(frag), frag->page_offset); + + remaining -= skb_frag_size(frag); + + if (remaining < 0) + sg_in[i].length += remaining; + } + *resync_sgs = i; + + spin_unlock_irqrestore(&ctx->lock, flags); + if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0) + return -EINVAL; + + return 0; +} + +static void fill_sg_out(struct scatterlist sg_out[3], void *buf, + struct tls_context *tls_ctx, + struct sk_buff *nskb, + int tcp_payload_offset, + int payload_len, + int sync_size, + void *dummy_buf) +{ + sg_set_buf(&sg_out[0], dummy_buf, sync_size); + sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len); + /* Add room for authentication tag produced by crypto */ + dummy_buf += sync_size; + sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE); +} + +static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, + struct scatterlist sg_out[3], + struct scatterlist *sg_in, + struct sk_buff *skb, + s32 sync_size, u64 rcd_sn) +{ + int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + int payload_len = skb->len - tcp_payload_offset; + void *buf, *iv, *aad, *dummy_buf; + struct aead_request *aead_req; + struct sk_buff *nskb = NULL; + int buf_len; + + aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC); + if (!aead_req) + return NULL; + + buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE + + TLS_CIPHER_AES_GCM_128_IV_SIZE + + TLS_AAD_SPACE_SIZE + + sync_size + + TLS_CIPHER_AES_GCM_128_TAG_SIZE; + buf = kmalloc(buf_len, GFP_ATOMIC); + if (!buf) + goto free_req; + + iv = buf; + memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt, + TLS_CIPHER_AES_GCM_128_SALT_SIZE); + aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE + + TLS_CIPHER_AES_GCM_128_IV_SIZE; + dummy_buf = aad + TLS_AAD_SPACE_SIZE; + + nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC); + if (!nskb) + goto free_buf; + + skb_reserve(nskb, skb_headroom(skb)); + + fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset, + payload_len, sync_size, dummy_buf); + + if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv, + rcd_sn, sync_size + payload_len) < 0) + goto free_nskb; + + complete_skb(nskb, skb, tcp_payload_offset); + + /* validate_xmit_skb_list assumes that if the skb wasn't segmented + * nskb->prev will point to the skb itself + */ + nskb->prev = nskb; + +free_buf: + kfree(buf); +free_req: + kfree(aead_req); + return nskb; +free_nskb: + kfree_skb(nskb); + nskb = NULL; + goto free_buf; +} + +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb) +{ + int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb); + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + int payload_len = skb->len - tcp_payload_offset; + struct scatterlist *sg_in, sg_out[3]; + struct sk_buff *nskb = NULL; + int sg_in_max_elements; + int resync_sgs = 0; + s32 sync_size = 0; + u64 rcd_sn; + + /* worst case is: + * MAX_SKB_FRAGS in tls_record_info + * MAX_SKB_FRAGS + 1 in SKB head and frags. + */ + sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1; + + if (!payload_len) + return skb; + + sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC); + if (!sg_in) + goto free_orig; + + sg_init_table(sg_in, sg_in_max_elements); + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); + + if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) { + /* bypass packets before kernel TLS socket option was set */ + if (sync_size < 0 && payload_len <= -sync_size) + nskb = skb_get(skb); + goto put_sg; + } + + nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn); + +put_sg: + while (resync_sgs) + put_page(sg_page(&sg_in[--resync_sgs])); + kfree(sg_in); +free_orig: + kfree_skb(skb); + return nskb; +} + +struct sk_buff *tls_validate_xmit_skb(struct sock *sk, + struct net_device *dev, + struct sk_buff *skb) +{ + if (dev == tls_get_ctx(sk)->netdev) + return skb; + + return tls_sw_fallback(sk, skb); +} + +int tls_sw_fallback_init(struct sock *sk, + struct tls_offload_context *offload_ctx, + struct tls_crypto_info *crypto_info) +{ + const u8 *key; + int rc; + + offload_ctx->aead_send = + crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(offload_ctx->aead_send)) { + rc = PTR_ERR(offload_ctx->aead_send); + pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc); + offload_ctx->aead_send = NULL; + goto err_out; + } + + key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key; + + rc = crypto_aead_setkey(offload_ctx->aead_send, key, + TLS_CIPHER_AES_GCM_128_KEY_SIZE); + if (rc) + goto free_aead; + + rc = crypto_aead_setauthsize(offload_ctx->aead_send, + TLS_CIPHER_AES_GCM_128_TAG_SIZE); + if (rc) + goto free_aead; + + return 0; +free_aead: + crypto_free_aead(offload_ctx->aead_send); +err_out: + return rc; +} diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 545bf34ed599d80de96dceab7053508460a75420..3aafb871a0a82a83d5f3c721a7aa84a8ba0d9810 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -54,6 +54,9 @@ enum { enum { TLS_BASE, TLS_SW, +#ifdef CONFIG_TLS_DEVICE + TLS_HW, +#endif TLS_HW_RECORD, TLS_NUM_CONFIG, }; @@ -280,6 +283,15 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) tls_sw_free_resources_rx(sk); } +#ifdef CONFIG_TLS_DEVICE + if (ctx->tx_conf != TLS_HW) { +#else + { +#endif + kfree(ctx); + ctx = NULL; + } + skip_tx_cleanup: release_sock(sk); sk_proto_close(sk, timeout); @@ -442,8 +454,16 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, } if (tx) { - rc = tls_set_sw_offload(sk, ctx, 1); - conf = TLS_SW; +#ifdef CONFIG_TLS_DEVICE + rc = tls_set_device_offload(sk, ctx); + conf = TLS_HW; + if (rc) { +#else + { +#endif + rc = tls_set_sw_offload(sk, ctx, 1); + conf = TLS_SW; + } } else { rc = tls_set_sw_offload(sk, ctx, 0); conf = TLS_SW; @@ -596,6 +616,16 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; +#ifdef CONFIG_TLS_DEVICE + prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage; + + prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; + prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; +#endif + prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash; prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash; @@ -630,7 +660,7 @@ static int tls_init(struct sock *sk) ctx->getsockopt = sk->sk_prot->getsockopt; ctx->sk_proto_close = sk->sk_prot->close; - /* Build IPv6 TLS whenever the address of tcpv6_prot changes */ + /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ if (ip_ver == TLSV6 && unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { mutex_lock(&tcpv6_prot_mutex); @@ -680,6 +710,9 @@ static int __init tls_register(void) tls_sw_proto_ops.poll = tls_sw_poll; tls_sw_proto_ops.splice_read = tls_sw_splice_read; +#ifdef CONFIG_TLS_DEVICE + tls_device_init(); +#endif tcp_register_ulp(&tcp_tls_ulp_ops); return 0; @@ -688,6 +721,9 @@ static int __init tls_register(void) static void __exit tls_unregister(void) { tcp_unregister_ulp(&tcp_tls_ulp_ops); +#ifdef CONFIG_TLS_DEVICE + tls_device_cleanup(); +#endif } module_init(tls_register);