diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 1f7e4e4e6f8efee40262024c493c225e197d2ae8..998bed505530df03e033719e88c5d38405d34ac6 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -282,6 +282,22 @@ void vhost_poll_queue(struct vhost_poll *poll) } EXPORT_SYMBOL_GPL(vhost_poll_queue); +static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq) +{ + int j; + + for (j = 0; j < VHOST_NUM_ADDRS; j++) + vq->meta_iotlb[j] = NULL; +} + +static void vhost_vq_meta_reset(struct vhost_dev *d) +{ + int i; + + for (i = 0; i < d->nvqs; ++i) + __vhost_vq_meta_reset(d->vqs[i]); +} + static void vhost_vq_reset(struct vhost_dev *dev, struct vhost_virtqueue *vq) { @@ -312,6 +328,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->busyloop_timeout = 0; vq->umem = NULL; vq->iotlb = NULL; + __vhost_vq_meta_reset(vq); } static int vhost_worker(void *data) @@ -691,6 +708,18 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, return 1; } +static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq, + u64 addr, unsigned int size, + int type) +{ + const struct vhost_umem_node *node = vq->meta_iotlb[type]; + + if (!node) + return NULL; + + return (void *)(uintptr_t)(node->userspace_addr + addr - node->start); +} + /* Can we switch to this memory table? */ /* Caller should have device mutex but not vq mutex */ static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, @@ -733,8 +762,14 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to, * could be access through iotlb. So -EAGAIN should * not happen in this case. */ - /* TODO: more fast path */ struct iov_iter t; + void __user *uaddr = vhost_vq_meta_fetch(vq, + (u64)(uintptr_t)to, size, + VHOST_ADDR_DESC); + + if (uaddr) + return __copy_to_user(uaddr, from, size); + ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov, ARRAY_SIZE(vq->iotlb_iov), VHOST_ACCESS_WO); @@ -762,8 +797,14 @@ static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to, * could be access through iotlb. So -EAGAIN should * not happen in this case. */ - /* TODO: more fast path */ + void __user *uaddr = vhost_vq_meta_fetch(vq, + (u64)(uintptr_t)from, size, + VHOST_ADDR_DESC); struct iov_iter f; + + if (uaddr) + return __copy_from_user(to, uaddr, size); + ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov, ARRAY_SIZE(vq->iotlb_iov), VHOST_ACCESS_RO); @@ -783,17 +824,12 @@ static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to, return ret; } -static void __user *__vhost_get_user(struct vhost_virtqueue *vq, - void __user *addr, unsigned size) +static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq, + void __user *addr, unsigned int size, + int type) { int ret; - /* This function should be called after iotlb - * prefetch, which means we're sure that vq - * could be access through iotlb. So -EAGAIN should - * not happen in this case. - */ - /* TODO: more fast path */ ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov, ARRAY_SIZE(vq->iotlb_iov), VHOST_ACCESS_RO); @@ -814,14 +850,32 @@ static void __user *__vhost_get_user(struct vhost_virtqueue *vq, return vq->iotlb_iov[0].iov_base; } -#define vhost_put_user(vq, x, ptr) \ +/* This function should be called after iotlb + * prefetch, which means we're sure that vq + * could be access through iotlb. So -EAGAIN should + * not happen in this case. + */ +static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, + void *addr, unsigned int size, + int type) +{ + void __user *uaddr = vhost_vq_meta_fetch(vq, + (u64)(uintptr_t)addr, size, type); + if (uaddr) + return uaddr; + + return __vhost_get_user_slow(vq, addr, size, type); +} + +#define vhost_put_user(vq, x, ptr) \ ({ \ int ret = -EFAULT; \ if (!vq->iotlb) { \ ret = __put_user(x, ptr); \ } else { \ __typeof__(ptr) to = \ - (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \ + (__typeof__(ptr)) __vhost_get_user(vq, ptr, \ + sizeof(*ptr), VHOST_ADDR_USED); \ if (to != NULL) \ ret = __put_user(x, to); \ else \ @@ -830,14 +884,16 @@ static void __user *__vhost_get_user(struct vhost_virtqueue *vq, ret; \ }) -#define vhost_get_user(vq, x, ptr) \ +#define vhost_get_user(vq, x, ptr, type) \ ({ \ int ret; \ if (!vq->iotlb) { \ ret = __get_user(x, ptr); \ } else { \ __typeof__(ptr) from = \ - (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \ + (__typeof__(ptr)) __vhost_get_user(vq, ptr, \ + sizeof(*ptr), \ + type); \ if (from != NULL) \ ret = __get_user(x, from); \ else \ @@ -846,6 +902,12 @@ static void __user *__vhost_get_user(struct vhost_virtqueue *vq, ret; \ }) +#define vhost_get_avail(vq, x, ptr) \ + vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL) + +#define vhost_get_used(vq, x, ptr) \ + vhost_get_user(vq, x, ptr, VHOST_ADDR_USED) + static void vhost_dev_lock_vqs(struct vhost_dev *d) { int i = 0; @@ -951,6 +1013,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, ret = -EFAULT; break; } + vhost_vq_meta_reset(dev); if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size, msg->iova + msg->size - 1, msg->uaddr, msg->perm)) { @@ -960,6 +1023,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, vhost_iotlb_notify_vq(dev, msg); break; case VHOST_IOTLB_INVALIDATE: + vhost_vq_meta_reset(dev); vhost_del_umem_range(dev->iotlb, msg->iova, msg->iova + msg->size - 1); break; @@ -1103,12 +1167,26 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, sizeof *used + num * sizeof *used->ring + s); } +static void vhost_vq_meta_update(struct vhost_virtqueue *vq, + const struct vhost_umem_node *node, + int type) +{ + int access = (type == VHOST_ADDR_USED) ? + VHOST_ACCESS_WO : VHOST_ACCESS_RO; + + if (likely(node->perm & access)) + vq->meta_iotlb[type] = node; +} + static int iotlb_access_ok(struct vhost_virtqueue *vq, - int access, u64 addr, u64 len) + int access, u64 addr, u64 len, int type) { const struct vhost_umem_node *node; struct vhost_umem *umem = vq->iotlb; - u64 s = 0, size; + u64 s = 0, size, orig_addr = addr; + + if (vhost_vq_meta_fetch(vq, addr, len, type)) + return true; while (len > s) { node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, @@ -1125,6 +1203,10 @@ static int iotlb_access_ok(struct vhost_virtqueue *vq, } size = node->size - addr + node->start; + + if (orig_addr == addr && size >= len) + vhost_vq_meta_update(vq, node, type); + s += size; addr += size; } @@ -1141,13 +1223,15 @@ int vq_iotlb_prefetch(struct vhost_virtqueue *vq) return 1; return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc, - num * sizeof *vq->desc) && + num * sizeof(*vq->desc), VHOST_ADDR_DESC) && iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail, sizeof *vq->avail + - num * sizeof *vq->avail->ring + s) && + num * sizeof(*vq->avail->ring) + s, + VHOST_ADDR_AVAIL) && iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used, sizeof *vq->used + - num * sizeof *vq->used->ring + s); + num * sizeof(*vq->used->ring) + s, + VHOST_ADDR_USED); } EXPORT_SYMBOL_GPL(vq_iotlb_prefetch); @@ -1728,7 +1812,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq) r = -EFAULT; goto err; } - r = vhost_get_user(vq, last_used_idx, &vq->used->idx); + r = vhost_get_used(vq, last_used_idx, &vq->used->idx); if (r) { vq_err(vq, "Can't access used idx at %p\n", &vq->used->idx); @@ -1932,7 +2016,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, last_avail_idx = vq->last_avail_idx; if (vq->avail_idx == vq->last_avail_idx) { - if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) { + if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) { vq_err(vq, "Failed to access avail idx at %p\n", &vq->avail->idx); return -EFAULT; @@ -1959,7 +2043,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, /* Grab the next descriptor number they're advertising, and increment * the index we've seen. */ - if (unlikely(vhost_get_user(vq, ring_head, + if (unlikely(vhost_get_avail(vq, ring_head, &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { vq_err(vq, "Failed to read head: idx %d address %p\n", last_avail_idx, @@ -2175,7 +2259,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) * with the barrier that the Guest executes when enabling * interrupts. */ smp_mb(); - if (vhost_get_user(vq, flags, &vq->avail->flags)) { + if (vhost_get_avail(vq, flags, &vq->avail->flags)) { vq_err(vq, "Failed to get flags"); return true; } @@ -2202,7 +2286,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) * interrupts. */ smp_mb(); - if (vhost_get_user(vq, event, vhost_used_event(vq))) { + if (vhost_get_avail(vq, event, vhost_used_event(vq))) { vq_err(vq, "Failed to get used event idx"); return true; } @@ -2246,7 +2330,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) __virtio16 avail_idx; int r; - r = vhost_get_user(vq, avail_idx, &vq->avail->idx); + r = vhost_get_avail(vq, avail_idx, &vq->avail->idx); if (r) return false; @@ -2281,7 +2365,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) /* They could have slipped one in as we were doing that: make * sure it's written, then check again. */ smp_mb(); - r = vhost_get_user(vq, avail_idx, &vq->avail->idx); + r = vhost_get_avail(vq, avail_idx, &vq->avail->idx); if (r) { vq_err(vq, "Failed to check avail idx at %p: %d\n", &vq->avail->idx, r); diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index a9cbbb148f460e52ef8a051770fd32f616696f14..f55671d53f28fed413afa54af957add8ed65c972 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -76,6 +76,13 @@ struct vhost_umem { int numem; }; +enum vhost_uaddr_type { + VHOST_ADDR_DESC = 0, + VHOST_ADDR_AVAIL = 1, + VHOST_ADDR_USED = 2, + VHOST_NUM_ADDRS = 3, +}; + /* The virtqueue structure describes a queue attached to a device. */ struct vhost_virtqueue { struct vhost_dev *dev; @@ -86,6 +93,7 @@ struct vhost_virtqueue { struct vring_desc __user *desc; struct vring_avail __user *avail; struct vring_used __user *used; + const struct vhost_umem_node *meta_iotlb[VHOST_NUM_ADDRS]; struct file *kick; struct file *call; struct file *error;