/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>
#include <linux/if_vlan.h>

#include <net/sock.h>

#include "vhost.h"

static int experimental_zcopytx = 1;
module_param(experimental_zcopytx, int, 0444);
MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
		                       " 1 - Enable; 0 - Disable");

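/* Usage example (assuming the module is built as vhost_net, as in
 * mainline): "modprobe vhost_net experimental_zcopytx=0" disables
 * zerocopy TX at load time; the parameter is read-only (0444) afterwards.
 */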
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

/* MAX number of TX used buffers for outstanding zerocopy */
#define VHOST_MAX_PEND 128
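/* Packets shorter than this many bytes are transmitted by copying rather
 * than by zerocopy, since the zerocopy setup cost is not worth it for
 * small buffers. */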
#define VHOST_GOODCOPY_LEN 256

/*
 * For transmit, used buffer len is unused; we override it to track buffer
 * status internally; used for zerocopy tx only.
 */
/* Lower device DMA failed */
#define VHOST_DMA_FAILED_LEN	3
/* Lower device DMA done */
#define VHOST_DMA_DONE_LEN	2
/* Lower device DMA in progress */
#define VHOST_DMA_IN_PROGRESS	1
/* Buffer unused */
#define VHOST_DMA_CLEAR_LEN	0

#define VHOST_DMA_IS_DONE(len) ((len) >= VHOST_DMA_DONE_LEN)

enum {
	VHOST_NET_VQ_RX = 0,
	VHOST_NET_VQ_TX = 1,
	VHOST_NET_VQ_MAX = 2,
};

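/* Per-virtqueue state; currently just wraps the generic vhost virtqueue. */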
struct vhost_net_virtqueue {
	struct vhost_virtqueue vq;
};

struct vhost_net {
	struct vhost_dev dev;
	struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
	struct vhost_poll poll[VHOST_NET_VQ_MAX];
	/* Number of TX recently submitted.
	 * Protected by tx vq lock. */
	unsigned tx_packets;
	/* Number of times zerocopy TX recently failed.
	 * Protected by tx vq lock. */
	unsigned tx_zcopy_err;
	/* Flush in progress. Protected by tx vq lock. */
	bool tx_flush;
};

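/* Account one transmitted packet.  Both counters are reset every 1024
 * packets so that vhost_net_tx_select_zcopy() only sees the recent
 * zerocopy error rate. */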
static void vhost_net_tx_packet(struct vhost_net *net)
{
	++net->tx_packets;
	if (net->tx_packets < 1024)
		return;
	net->tx_packets = 0;
	net->tx_zcopy_err = 0;
}

static void vhost_net_tx_err(struct vhost_net *net)
{
	++net->tx_zcopy_err;
}

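/* Decide whether new transmissions may use zerocopy: require that no TX
 * flush is in progress and that recent zerocopy failures stay below
 * roughly 1/64 of recently transmitted packets. */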
static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
{
	/* TX flush waits for outstanding DMAs to be done.
	 * Don't start new DMAs.
	 */
	return !net->tx_flush &&
		net->tx_packets / 64 >= net->tx_zcopy_err;
}

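/* Zerocopy TX is only attempted when the module parameter enables it and
 * the backing socket advertises SOCK_ZEROCOPY. */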
static bool vhost_sock_zcopy(struct socket *sock)
{
	return unlikely(experimental_zcopytx) &&
		sock_flag(sock->sk, SOCK_ZEROCOPY);
}

/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
			  size_t len, int iov_count)
{
	int seg = 0;
	size_t size;

	while (len && seg < iov_count) {
		size = min(from->iov_len, len);
		to->iov_base = from->iov_base;
		to->iov_len = size;
		from->iov_len -= size;
		from->iov_base += size;
		len -= size;
		++from;
		++to;
		++seg;
	}
	return seg;
}

/* Copy iovec entries for len bytes from iovec. */
static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
			   size_t len, int iovcount)
{
	int seg = 0;
	size_t size;

	while (len && seg < iovcount) {
		size = min(from->iov_len, len);
		to->iov_base = from->iov_base;
		to->iov_len = size;
		len -= size;
		++from;
		++to;
		++seg;
	}
}

/* The lower device may complete DMA out of order.  upend_idx tracks the
 * tail of submitted zerocopy buffers and done_idx tracks the head; once
 * the lower device has completed DMA contiguously from done_idx, we
 * signal the used index to the KVM guest.
 */
static int vhost_zerocopy_signal_used(struct vhost_net *net,
				      struct vhost_virtqueue *vq)
{
	int i;
	int j = 0;

	for (i = vq->done_idx; i != vq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
		if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
			vhost_net_tx_err(net);
		if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
			vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
			vhost_add_used_and_signal(vq->dev, vq,
						  vq->heads[i].id, 0);
			++j;
		} else
			break;
	}
	if (j)
		vq->done_idx = i;
	return j;
}

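/* Completion callback attached to zerocopy skbs: record the DMA status in
 * the matching vq->heads entry and kick the TX handler when the guest has
 * gone quiet or after every 16 completions. */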
static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
	struct vhost_ubuf_ref *ubufs = ubuf->ctx;
	struct vhost_virtqueue *vq = ubufs->vq;
	int cnt = atomic_read(&ubufs->kref.refcount);

	/*
	 * Trigger polling thread if guest stopped submitting new buffers:
	 * in this case, the refcount after decrement will eventually reach 1
	 * so here it is 2.
	 * We also trigger polling periodically after each 16 packets
	 * (the value 16 here is more or less arbitrary, it's tuned to trigger
	 * less than 10% of times).
	 */
	if (cnt <= 2 || !(cnt % 16))
		vhost_poll_queue(&vq->poll);
	/* Set len to mark this descriptor's buffers as done with DMA. */
	vq->heads[ubuf->desc].len = success ?
		VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
	vhost_ubuf_put(ubufs);
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
	struct vhost_virtqueue *vq = &net->vqs[VHOST_NET_VQ_TX].vq;
	unsigned out, in, s;
	int head;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_iov = vq->iov,
		.msg_flags = MSG_DONTWAIT,
	};
	size_t len, total_len = 0;
	int err;
	size_t hdr_size;
	struct socket *sock;
	struct vhost_ubuf_ref *uninitialized_var(ubufs);
	bool zcopy, zcopy_used;

	/* TODO: check that we are running from vhost_worker? */
	sock = rcu_dereference_check(vq->private_data, 1);
	if (!sock)
		return;

	mutex_lock(&vq->mutex);
	vhost_disable_notify(&net->dev, vq);

	hdr_size = vq->vhost_hlen;
	zcopy = vq->ubufs;

	for (;;) {
		/* Release DMA-completed buffers first */
		if (zcopy)
			vhost_zerocopy_signal_used(net, vq);

		head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
					 ARRAY_SIZE(vq->iov),
					 &out, &in,
					 NULL, NULL);
		/* On error, stop handling until the next kick. */
		if (unlikely(head < 0))
			break;
		/* Nothing new?  Wait for eventfd to tell us they refilled. */
		if (head == vq->num) {
			int num_pends;

			/* If more outstanding DMAs, queue the work.
			 * Handle upend_idx wrap around
			 */
			num_pends = likely(vq->upend_idx >= vq->done_idx) ?
				    (vq->upend_idx - vq->done_idx) :
				    (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
			if (unlikely(num_pends > VHOST_MAX_PEND))
				break;
			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				vhost_disable_notify(&net->dev, vq);
				continue;
			}
			break;
		}
		if (in) {
			vq_err(vq, "Unexpected descriptor format for TX: "
			       "out %d, in %d\n", out, in);
			break;
		}
		/* Skip header. TODO: support TSO. */
		s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
		msg.msg_iovlen = out;
		len = iov_length(vq->iov, out);
		/* Sanity check */
		if (!len) {
			vq_err(vq, "Unexpected header len for TX: "
			       "%zd expected %zd\n",
			       iov_length(vq->hdr, s), hdr_size);
			break;
		}
		zcopy_used = zcopy && (len >= VHOST_GOODCOPY_LEN ||
				       vq->upend_idx != vq->done_idx);

		/* use msg_control to pass vhost zerocopy ubuf info to skb */
		if (zcopy_used) {
			vq->heads[vq->upend_idx].id = head;
			if (!vhost_net_tx_select_zcopy(net) ||
			    len < VHOST_GOODCOPY_LEN) {
				/* copied data does not need to wait for DMA done */
				vq->heads[vq->upend_idx].len =
							VHOST_DMA_DONE_LEN;
				msg.msg_control = NULL;
				msg.msg_controllen = 0;
				ubufs = NULL;
			} else {
				struct ubuf_info *ubuf;
				ubuf = vq->ubuf_info + vq->upend_idx;

				vq->heads[vq->upend_idx].len =
					VHOST_DMA_IN_PROGRESS;
				ubuf->callback = vhost_zerocopy_callback;
				ubuf->ctx = vq->ubufs;
				ubuf->desc = vq->upend_idx;
				msg.msg_control = ubuf;
				msg.msg_controllen = sizeof(ubuf);
				ubufs = vq->ubufs;
				kref_get(&ubufs->kref);
			}
			vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
		}
		/* TODO: Check specific error and bomb out unless ENOBUFS? */
		err = sock->ops->sendmsg(NULL, sock, &msg, len);
		if (unlikely(err < 0)) {
			if (zcopy_used) {
				if (ubufs)
					vhost_ubuf_put(ubufs);
				vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
					UIO_MAXIOV;
			}
			vhost_discard_vq_desc(vq, 1);
			break;
		}
		if (err != len)
			pr_debug("Truncated TX packet: "
				 " len %d != %zd\n", err, len);
		if (!zcopy_used)
			vhost_add_used_and_signal(&net->dev, vq, head, 0);
		else
			vhost_zerocopy_signal_used(net, vq);
		total_len += len;
		vhost_net_tx_packet(net);
		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
			vhost_poll_queue(&vq->poll);
			break;
		}
	}

	mutex_unlock(&vq->mutex);
}

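/* Length of the first packet pending on the socket receive queue (0 if the
 * queue is empty), including VLAN_HLEN when a VLAN tag will be inserted. */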
static int peek_head_len(struct sock *sk)
{
	struct sk_buff *head;
	int len = 0;
	unsigned long flags;

	spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
	head = skb_peek(&sk->sk_receive_queue);
	if (likely(head)) {
		len = head->len;
		if (vlan_tx_tag_present(head))
			len += VLAN_HLEN;
	}

	spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
	return len;
}

/* This is a multi-buffer version of vhost_get_vq_desc that works if
 *	vq has read descriptors only.
 * @vq		- the relevant virtqueue
 * @datalen	- data length we'll be reading
 * @iovcount	- returned count of io vectors we fill
 * @log		- vhost log
 * @log_num	- log offset
 * @quota	- headcount quota, 1 for big buffer
 *	returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_virtqueue *vq,
		       struct vring_used_elem *heads,
		       int datalen,
		       unsigned *iovcount,
		       struct vhost_log *log,
		       unsigned *log_num,
		       unsigned int quota)
{
	unsigned int out, in;
	int seg = 0;
	int headcount = 0;
	unsigned d;
	int r, nlogs = 0;

	while (datalen > 0 && headcount < quota) {
		if (unlikely(seg >= UIO_MAXIOV)) {
			r = -ENOBUFS;
			goto err;
		}
		d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
				      ARRAY_SIZE(vq->iov) - seg, &out,
				      &in, log, log_num);
		if (d == vq->num) {
			r = 0;
			goto err;
		}
		if (unlikely(out || in <= 0)) {
			vq_err(vq, "unexpected descriptor format for RX: "
				"out %d, in %d\n", out, in);
			r = -EINVAL;
			goto err;
		}
		if (unlikely(log)) {
			nlogs += *log_num;
			log += *log_num;
		}
		heads[headcount].id = d;
		heads[headcount].len = iov_length(vq->iov + seg, in);
		datalen -= heads[headcount].len;
		++headcount;
		seg += in;
	}
	heads[headcount - 1].len += datalen;
	*iovcount = seg;
	if (unlikely(log))
		*log_num = nlogs;
	return headcount;
err:
	vhost_discard_vq_desc(vq, headcount);
	return r;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
{
	struct vhost_virtqueue *vq = &net->vqs[VHOST_NET_VQ_RX].vq;
	unsigned uninitialized_var(in), log;
	struct vhost_log *vq_log;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
		.msg_controllen = 0,
		.msg_iov = vq->iov,
		.msg_flags = MSG_DONTWAIT,
	};
	struct virtio_net_hdr_mrg_rxbuf hdr = {
		.hdr.flags = 0,
		.hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
	};
	size_t total_len = 0;
	int err, mergeable;
	s16 headcount;
	size_t vhost_hlen, sock_hlen;
	size_t vhost_len, sock_len;
	/* TODO: check that we are running from vhost_worker? */
	struct socket *sock = rcu_dereference_check(vq->private_data, 1);

	if (!sock)
		return;

	mutex_lock(&vq->mutex);
	vhost_disable_notify(&net->dev, vq);
	vhost_hlen = vq->vhost_hlen;
	sock_hlen = vq->sock_hlen;

	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
		vq->log : NULL;
	mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);

	while ((sock_len = peek_head_len(sock->sk))) {
		sock_len += sock_hlen;
		vhost_len = sock_len + vhost_hlen;
		headcount = get_rx_bufs(vq, vq->heads, vhost_len,
					&in, vq_log, &log,
					likely(mergeable) ? UIO_MAXIOV : 1);
		/* On error, stop handling until the next kick. */
		if (unlikely(headcount < 0))
			break;
		/* OK, now we need to know about added descriptors. */
		if (!headcount) {
			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				/* They have slipped one in as we were
				 * doing that: check again. */
				vhost_disable_notify(&net->dev, vq);
				continue;
			}
			/* Nothing new?  Wait for eventfd to tell us
			 * they refilled. */
			break;
		}
		/* We don't need to be notified again. */
		if (unlikely((vhost_hlen)))
			/* Skip header. TODO: support TSO. */
			move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
		else
			/* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
			 * needed because recvmsg can modify msg_iov. */
			copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
		msg.msg_iovlen = in;
		err = sock->ops->recvmsg(NULL, sock, &msg,
					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
		/* Userspace might have consumed the packet meanwhile:
		 * it's not supposed to do this usually, but might be hard
		 * to prevent. Discard data we got (if any) and keep going. */
		if (unlikely(err != sock_len)) {
			pr_debug("Discarded rx packet: "
				 " len %d, expected %zd\n", err, sock_len);
			vhost_discard_vq_desc(vq, headcount);
			continue;
		}
		if (unlikely(vhost_hlen) &&
		    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
				      vhost_hlen)) {
			vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
			       vq->iov->iov_base);
			break;
		}
		/* TODO: Should check and handle checksum. */
		if (likely(mergeable) &&
		    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
				      offsetof(typeof(hdr), num_buffers),
				      sizeof hdr.num_buffers)) {
			vq_err(vq, "Failed num_buffers write");
			vhost_discard_vq_desc(vq, headcount);
			break;
		}
		vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
					    headcount);
		if (unlikely(vq_log))
			vhost_log_write(vq, vq_log, log, vhost_len);
		total_len += vhost_len;
		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
			vhost_poll_queue(&vq->poll);
			break;
		}
	}

	mutex_unlock(&vq->mutex);
}

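/* vhost work callbacks: the *_kick handlers run when the guest kicks a
 * virtqueue, the *_net handlers run when the backing socket becomes
 * readable/writable; all of them dispatch to handle_tx()/handle_rx(). */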
static void handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

	handle_tx(net);
}

static void handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

	handle_rx(net);
}

static void handle_tx_net(struct vhost_work *work)
{
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_TX].work);
	handle_tx(net);
}

static void handle_rx_net(struct vhost_work *work)
{
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_RX].work);
	handle_rx(net);
}

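/* Open of the vhost-net character device: allocate the vhost_net instance,
 * register the TX/RX kick handlers with the vhost core and initialise the
 * socket pollers used to restart work when the backend becomes ready. */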
static int vhost_net_open(struct inode *inode, struct file *f)
{
	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
	struct vhost_dev *dev;
	struct vhost_virtqueue **vqs;
	int r;

	if (!n)
		return -ENOMEM;
	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		kfree(n);
		return -ENOMEM;
	}

	dev = &n->dev;
	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
	n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
	n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
	r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
	if (r < 0) {
		kfree(n);
		kfree(vqs);
		return r;
	}

	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);

	f->private_data = n;

	return 0;
}

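/* Stop/start polling the backend socket of a virtqueue; only meaningful
 * while a backend (vq->private_data) is attached. */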
static void vhost_net_disable_vq(struct vhost_net *n,
				 struct vhost_virtqueue *vq)
{
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
	if (!vq->private_data)
		return;
	vhost_poll_stop(poll);
}

static int vhost_net_enable_vq(struct vhost_net *n,
				struct vhost_virtqueue *vq)
{
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
	struct socket *sock;

	sock = rcu_dereference_protected(vq->private_data,
					 lockdep_is_held(&vq->mutex));
	if (!sock)
		return 0;

	return vhost_poll_start(poll, sock->file);
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
					struct vhost_virtqueue *vq)
{
	struct socket *sock;

	mutex_lock(&vq->mutex);
	sock = rcu_dereference_protected(vq->private_data,
					 lockdep_is_held(&vq->mutex));
	vhost_net_disable_vq(n, vq);
	rcu_assign_pointer(vq->private_data, NULL);
	mutex_unlock(&vq->mutex);
	return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
			   struct socket **rx_sock)
{
	*tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
	vhost_poll_flush(n->poll + index);
	vhost_poll_flush(&n->vqs[index].vq.poll);
}

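/* Flush pending work on both virtqueues; for TX this also waits for all
 * outstanding zerocopy DMAs to finish before returning. */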
static void vhost_net_flush(struct vhost_net *n)
{
	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
	if (n->vqs[VHOST_NET_VQ_TX].vq.ubufs) {
		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
		n->tx_flush = true;
		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
		/* Wait for all lower device DMAs done. */
		vhost_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].vq.ubufs);
		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
		n->tx_flush = false;
		kref_init(&n->vqs[VHOST_NET_VQ_TX].vq.ubufs->kref);
		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
	}
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
	struct vhost_net *n = f->private_data;
	struct socket *tx_sock;
	struct socket *rx_sock;

	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	vhost_dev_stop(&n->dev);
	vhost_dev_cleanup(&n->dev, false);
	if (tx_sock)
		fput(tx_sock->file);
	if (rx_sock)
		fput(rx_sock->file);
	/* We do an extra flush before freeing memory,
	 * since jobs can re-queue themselves. */
	vhost_net_flush(n);
	kfree(n->dev.vqs);
	kfree(n);
	return 0;
}

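/* Resolve an fd to a raw AF_PACKET socket, or return an ERR_PTR if the fd
 * does not refer to a suitable socket. */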
static struct socket *get_raw_socket(int fd)
{
	struct {
		struct sockaddr_ll sa;
		char  buf[MAX_ADDR_LEN];
	} uaddr;
	int uaddr_len = sizeof uaddr, r;
	struct socket *sock = sockfd_lookup(fd, &r);

	if (!sock)
		return ERR_PTR(-ENOTSOCK);

	/* Parameter checking */
	if (sock->sk->sk_type != SOCK_RAW) {
		r = -ESOCKTNOSUPPORT;
		goto err;
	}

	r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
			       &uaddr_len, 0);
	if (r)
		goto err;

	if (uaddr.sa.sll_family != AF_PACKET) {
		r = -EPFNOSUPPORT;
		goto err;
	}
	return sock;
err:
	fput(sock->file);
	return ERR_PTR(r);
}

static struct socket *get_tap_socket(int fd)
{
	struct file *file = fget(fd);
	struct socket *sock;

	if (!file)
		return ERR_PTR(-EBADF);
	sock = tun_get_socket(file);
	if (!IS_ERR(sock))
		return sock;
	sock = macvtap_get_socket(file);
	if (IS_ERR(sock))
		fput(file);
	return sock;
}

static struct socket *get_socket(int fd)
{
	struct socket *sock;

	/* special case to disable backend */
	if (fd == -1)
		return NULL;
	sock = get_raw_socket(fd);
	if (!IS_ERR(sock))
		return sock;
	sock = get_tap_socket(fd);
	if (!IS_ERR(sock))
		return sock;
	return ERR_PTR(-ENOTSOCK);
}

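/* VHOST_NET_SET_BACKEND: attach the socket referred to by fd to virtqueue
 * 'index' (fd == -1 detaches the current backend), allocating zerocopy
 * state when the new socket supports it. */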
static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
	struct socket *sock, *oldsock;
	struct vhost_virtqueue *vq;
	struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
	int r;

	mutex_lock(&n->dev.mutex);
	r = vhost_dev_check_owner(&n->dev);
	if (r)
		goto err;

	if (index >= VHOST_NET_VQ_MAX) {
		r = -ENOBUFS;
		goto err;
	}
	vq = &n->vqs[index].vq;
	mutex_lock(&vq->mutex);

	/* Verify that ring has been setup correctly. */
	if (!vhost_vq_access_ok(vq)) {
		r = -EFAULT;
		goto err_vq;
	}
	sock = get_socket(fd);
	if (IS_ERR(sock)) {
		r = PTR_ERR(sock);
		goto err_vq;
	}

	/* start polling new socket */
	oldsock = rcu_dereference_protected(vq->private_data,
					    lockdep_is_held(&vq->mutex));
	if (sock != oldsock) {
		ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock));
		if (IS_ERR(ubufs)) {
			r = PTR_ERR(ubufs);
			goto err_ubufs;
		}

		vhost_net_disable_vq(n, vq);
		rcu_assign_pointer(vq->private_data, sock);
		r = vhost_init_used(vq);
		if (r)
			goto err_used;
		r = vhost_net_enable_vq(n, vq);
		if (r)
			goto err_used;

		oldubufs = vq->ubufs;
		vq->ubufs = ubufs;

		n->tx_packets = 0;
		n->tx_zcopy_err = 0;
		n->tx_flush = false;
	}

	mutex_unlock(&vq->mutex);

	if (oldubufs) {
		vhost_ubuf_put_and_wait(oldubufs);
		mutex_lock(&vq->mutex);
		vhost_zerocopy_signal_used(n, vq);
		mutex_unlock(&vq->mutex);
	}

	if (oldsock) {
		vhost_net_flush_vq(n, index);
		fput(oldsock->file);
	}

	mutex_unlock(&n->dev.mutex);
	return 0;

err_used:
	rcu_assign_pointer(vq->private_data, oldsock);
	vhost_net_enable_vq(n, vq);
	if (ubufs)
		vhost_ubuf_put_and_wait(ubufs);
err_ubufs:
	fput(sock->file);
err_vq:
	mutex_unlock(&vq->mutex);
err:
	mutex_unlock(&n->dev.mutex);
	return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
	struct socket *tx_sock = NULL;
	struct socket *rx_sock = NULL;
	long err;

	mutex_lock(&n->dev.mutex);
	err = vhost_dev_check_owner(&n->dev);
	if (err)
		goto done;
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	err = vhost_dev_reset_owner(&n->dev);
done:
	mutex_unlock(&n->dev.mutex);
	if (tx_sock)
		fput(tx_sock->file);
	if (rx_sock)
		fput(rx_sock->file);
	return err;
}

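/* Update the acked feature bits and propagate the resulting vnet header
 * layout (who supplies the header and how large it is) to both vqs. */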
static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
	size_t vhost_hlen, sock_hlen, hdr_len;
	int i;

	hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
			sizeof(struct virtio_net_hdr);
	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
		/* vhost provides vnet_hdr */
		vhost_hlen = hdr_len;
		sock_hlen = 0;
	} else {
		/* socket provides vnet_hdr */
		vhost_hlen = 0;
		sock_hlen = hdr_len;
	}
	mutex_lock(&n->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&n->dev)) {
		mutex_unlock(&n->dev.mutex);
		return -EFAULT;
	}
	n->dev.acked_features = features;
	smp_wmb();
	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
		mutex_lock(&n->vqs[i].vq.mutex);
		n->vqs[i].vq.vhost_hlen = vhost_hlen;
		n->vqs[i].vq.sock_hlen = sock_hlen;
		mutex_unlock(&n->vqs[i].vq.mutex);
	}
	vhost_net_flush(n);
	mutex_unlock(&n->dev.mutex);
	return 0;
}

static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
			    unsigned long arg)
{
	struct vhost_net *n = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 __user *featurep = argp;
	struct vhost_vring_file backend;
	u64 features;
	int r;

	switch (ioctl) {
	case VHOST_NET_SET_BACKEND:
		if (copy_from_user(&backend, argp, sizeof backend))
			return -EFAULT;
		return vhost_net_set_backend(n, backend.index, backend.fd);
	case VHOST_GET_FEATURES:
		features = VHOST_NET_FEATURES;
		if (copy_to_user(featurep, &features, sizeof features))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, featurep, sizeof features))
			return -EFAULT;
		if (features & ~VHOST_NET_FEATURES)
			return -EOPNOTSUPP;
		return vhost_net_set_features(n, features);
	case VHOST_RESET_OWNER:
		return vhost_net_reset_owner(n);
	default:
		mutex_lock(&n->dev.mutex);
		r = vhost_dev_ioctl(&n->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&n->dev, ioctl, argp);
		else
			vhost_net_flush(n);
		mutex_unlock(&n->dev.mutex);
		return r;
	}
}

#ifdef CONFIG_COMPAT
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
				   unsigned long arg)
{
	return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_net_fops = {
	.owner          = THIS_MODULE,
	.release        = vhost_net_release,
	.unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
	.compat_ioctl   = vhost_net_compat_ioctl,
#endif
	.open           = vhost_net_open,
	.llseek		= noop_llseek,
};

static struct miscdevice vhost_net_misc = {
	.minor = VHOST_NET_MINOR,
	.name = "vhost-net",
	.fops = &vhost_net_fops,
};

static int vhost_net_init(void)
{
	if (experimental_zcopytx)
		vhost_enable_zcopy(VHOST_NET_VQ_TX);
	return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);

static void vhost_net_exit(void)
{
	misc_deregister(&vhost_net_misc);
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
MODULE_ALIAS("devname:vhost-net");