diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index f0fd52cdfadc05b6fcd57b9efff0eddee4173d2a..70ac60437d174411f891f20f8b477b20b74e8e8e 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -703,6 +703,10 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) vhost_net_disable_vq(n, vq); rcu_assign_pointer(vq->private_data, sock); vhost_net_enable_vq(n, vq); + + r = vhost_init_used(vq); + if (r) + goto err_vq; } mutex_unlock(&vq->mutex); diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c index 734e1d74ad805a1547ed867dd8363f09bcd9e2f8..fc9a1d75281f33d57be305b3c724195c578334de 100644 --- a/drivers/vhost/test.c +++ b/drivers/vhost/test.c @@ -195,8 +195,13 @@ static long vhost_test_run(struct vhost_test *n, int test) lockdep_is_held(&vq->mutex)); rcu_assign_pointer(vq->private_data, priv); + r = vhost_init_used(&n->vqs[index]); + mutex_unlock(&vq->mutex); + if (r) + goto err; + if (oldpriv) { vhost_test_flush_vq(n, index); } diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 5ef2f62becf4b94c69082b4dc1b1a2f6eacd84ac..9a108038fe527393f6ac2b04a2104916b1518290 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -629,15 +629,17 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) return 0; } -static int init_used(struct vhost_virtqueue *vq, - struct vring_used __user *used) +int vhost_init_used(struct vhost_virtqueue *vq) { - int r = put_user(vq->used_flags, &used->flags); + int r; + if (!vq->private_data) + return 0; + r = put_user(vq->used_flags, &vq->used->flags); if (r) return r; vq->signalled_used_valid = false; - return get_user(vq->last_used_idx, &used->idx); + return get_user(vq->last_used_idx, &vq->used->idx); } static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp) @@ -752,10 +754,6 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp) } } - r = init_used(vq, (struct vring_used __user *)(unsigned long) - a.used_user_addr); - if (r) - break; vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG)); vq->desc = (void __user *)(unsigned long)a.desc_user_addr; vq->avail = (void __user *)(unsigned long)a.avail_user_addr; diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 1544b782529b5803730ac6e447ac698caa7c7239..14c9abf0d80025fd863460c1bd935d5b49f88c26 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -174,6 +174,7 @@ int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, struct vhost_log *log, unsigned int *log_num); void vhost_discard_vq_desc(struct vhost_virtqueue *, int n); +int vhost_init_used(struct vhost_virtqueue *); int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads, unsigned count);