/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>
#include <linux/if_vlan.h>

#include <net/sock.h>

#include "vhost.h"

static int experimental_zcopytx = 1;
module_param(experimental_zcopytx, int, 0444);
MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
		                       " 1 - Enable; 0 - Disable");

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

/* MAX number of TX used buffers for outstanding zerocopy */
#define VHOST_MAX_PEND 128
#define VHOST_GOODCOPY_LEN 256

/*
 * For transmit, used buffer len is unused; we override it to track buffer
 * status internally; used for zerocopy tx only.
 */
/* Lower device DMA failed */
#define VHOST_DMA_FAILED_LEN	3
/* Lower device DMA done */
#define VHOST_DMA_DONE_LEN	2
/* Lower device DMA in progress */
#define VHOST_DMA_IN_PROGRESS	1
/* Buffer unused */
#define VHOST_DMA_CLEAR_LEN	0

#define VHOST_DMA_IS_DONE(len) ((len) >= VHOST_DMA_DONE_LEN)

enum {
	VHOST_NET_VQ_RX = 0,
	VHOST_NET_VQ_TX = 1,
	VHOST_NET_VQ_MAX = 2,
};

enum vhost_net_poll_state {
	VHOST_NET_POLL_DISABLED = 0,
	VHOST_NET_POLL_STARTED = 1,
	VHOST_NET_POLL_STOPPED = 2,
};

struct vhost_net {
	struct vhost_dev dev;
	struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
	struct vhost_poll poll[VHOST_NET_VQ_MAX];
	/* Tells us whether we are polling a socket for TX.
	 * We only do this when socket buffer fills up.
	 * Protected by tx vq lock. */
	enum vhost_net_poll_state tx_poll_state;
	/* Number of TX recently submitted.
	 * Protected by tx vq lock. */
	unsigned tx_packets;
	/* Number of times zerocopy TX recently failed.
	 * Protected by tx vq lock. */
	unsigned tx_zcopy_err;
	/* Flush in progress. Protected by tx vq lock. */
	bool tx_flush;
};

static void vhost_net_tx_packet(struct vhost_net *net)
{
	++net->tx_packets;
	if (net->tx_packets < 1024)
		return;
	net->tx_packets = 0;
	net->tx_zcopy_err = 0;
}

static void vhost_net_tx_err(struct vhost_net *net)
{
	++net->tx_zcopy_err;
}

static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
{
	/* TX flush waits for outstanding DMAs to be done.
	 * Don't start new DMAs.
	 */
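	/* Keep zerocopy selected only while recent zerocopy failures stay at or
	 * below roughly 1/64 of recently sent packets (both counters are reset
	 * every 1024 packets by vhost_net_tx_packet). */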
	return !net->tx_flush &&
		net->tx_packets / 64 >= net->tx_zcopy_err;
}

static bool vhost_sock_zcopy(struct socket *sock)
{
	return unlikely(experimental_zcopytx) &&
		sock_flag(sock->sk, SOCK_ZEROCOPY);
}

/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
			  size_t len, int iov_count)
{
	int seg = 0;
	size_t size;

	while (len && seg < iov_count) {
		size = min(from->iov_len, len);
		to->iov_base = from->iov_base;
		to->iov_len = size;
		from->iov_len -= size;
		from->iov_base += size;
		len -= size;
		++from;
		++to;
		++seg;
	}
	return seg;
}
/* Copy iovec entries for len bytes from iovec. */
static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
			   size_t len, int iovcount)
{
	int seg = 0;
	size_t size;

	while (len && seg < iovcount) {
		size = min(from->iov_len, len);
		to->iov_base = from->iov_base;
		to->iov_len = size;
		len -= size;
		++from;
		++to;
		++seg;
	}
}

/* Caller must have TX VQ lock */
static void tx_poll_stop(struct vhost_net *net)
{
	if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
		return;
	vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
	net->tx_poll_state = VHOST_NET_POLL_STOPPED;
}

/* Caller must have TX VQ lock */
static void tx_poll_start(struct vhost_net *net, struct socket *sock)
{
	if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
		return;
	vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
	net->tx_poll_state = VHOST_NET_POLL_STARTED;
}

/* Lower device DMA completions may arrive out of order. upend_idx tracks the
 * tail of submitted used entries, done_idx tracks the head. Once the lower
 * device has completed DMAs contiguously from done_idx, we signal those used
 * entries to the guest.
 */
static int vhost_zerocopy_signal_used(struct vhost_net *net,
				      struct vhost_virtqueue *vq)
{
	int i;
	int j = 0;

	for (i = vq->done_idx; i != vq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
		if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
			vhost_net_tx_err(net);
		if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
			vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
			vhost_add_used_and_signal(vq->dev, vq,
						  vq->heads[i].id, 0);
			++j;
		} else
			break;
	}
	if (j)
		vq->done_idx = i;
	return j;
}

static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
	struct vhost_ubuf_ref *ubufs = ubuf->ctx;
	struct vhost_virtqueue *vq = ubufs->vq;
	int cnt = atomic_read(&ubufs->kref.refcount);

	/*
	 * Trigger polling thread if guest stopped submitting new buffers:
	 * in this case, the refcount after decrement will eventually reach 1
	 * so here it is 2.
	 * We also trigger polling periodically after each 16 packets
	 * (the value 16 here is more or less arbitrary, it's tuned to trigger
	 * less than 10% of times).
	 */
	if (cnt <= 2 || !(cnt % 16))
		vhost_poll_queue(&vq->poll);
	/* Set len to mark this descriptor's buffers as DMA done */
	vq->heads[ubuf->desc].len = success ?
		VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
	vhost_ubuf_put(ubufs);
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
	unsigned out, in, s;
	int head;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_iov = vq->iov,
		.msg_flags = MSG_DONTWAIT,
	};
	size_t len, total_len = 0;
	int err, wmem;
	size_t hdr_size;
	struct socket *sock;
	struct vhost_ubuf_ref *uninitialized_var(ubufs);
	bool zcopy, zcopy_used;

	/* TODO: check that we are running from vhost_worker? */
	sock = rcu_dereference_check(vq->private_data, 1);
	if (!sock)
		return;

	wmem = atomic_read(&sock->sk->sk_wmem_alloc);
	if (wmem >= sock->sk->sk_sndbuf) {
		mutex_lock(&vq->mutex);
		tx_poll_start(net, sock);
		mutex_unlock(&vq->mutex);
		return;
	}

	mutex_lock(&vq->mutex);
	vhost_disable_notify(&net->dev, vq);

	if (wmem < sock->sk->sk_sndbuf / 2)
		tx_poll_stop(net);
	hdr_size = vq->vhost_hlen;
	zcopy = vq->ubufs;

	for (;;) {
		/* Release DMAs done buffers first */
		if (zcopy)
			vhost_zerocopy_signal_used(net, vq);

		head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
					 ARRAY_SIZE(vq->iov),
					 &out, &in,
					 NULL, NULL);
		/* On error, stop handling until the next kick. */
		if (unlikely(head < 0))
			break;
		/* Nothing new?  Wait for eventfd to tell us they refilled. */
		if (head == vq->num) {
			int num_pends;

			wmem = atomic_read(&sock->sk->sk_wmem_alloc);
			if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
				tx_poll_start(net, sock);
				set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
				break;
			}
			/* If more outstanding DMAs, queue the work.
			 * Handle upend_idx wrap around
			 */
			num_pends = likely(vq->upend_idx >= vq->done_idx) ?
				    (vq->upend_idx - vq->done_idx) :
				    (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
			if (unlikely(num_pends > VHOST_MAX_PEND)) {
				tx_poll_start(net, sock);
				set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
				break;
			}
			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				vhost_disable_notify(&net->dev, vq);
				continue;
			}
			break;
		}
		if (in) {
			vq_err(vq, "Unexpected descriptor format for TX: "
			       "out %d, int %d\n", out, in);
			break;
		}
		/* Skip header. TODO: support TSO. */
		s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
		msg.msg_iovlen = out;
		len = iov_length(vq->iov, out);
		/* Sanity check */
		if (!len) {
			vq_err(vq, "Unexpected header len for TX: "
			       "%zd expected %zd\n",
			       iov_length(vq->hdr, s), hdr_size);
			break;
		}
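		/* Take the zerocopy bookkeeping path for large packets, or
		 * whenever earlier zerocopy sends are still in flight so that
		 * used-buffer signaling stays in order. */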
		zcopy_used = zcopy && (len >= VHOST_GOODCOPY_LEN ||
				       vq->upend_idx != vq->done_idx);

		/* use msg_control to pass vhost zerocopy ubuf info to skb */
		if (zcopy_used) {
			vq->heads[vq->upend_idx].id = head;
			if (!vhost_net_tx_select_zcopy(net) ||
			    len < VHOST_GOODCOPY_LEN) {
				/* Plain copy: no need to wait for DMA completion */
				vq->heads[vq->upend_idx].len =
							VHOST_DMA_DONE_LEN;
				msg.msg_control = NULL;
				msg.msg_controllen = 0;
				ubufs = NULL;
			} else {
				struct ubuf_info *ubuf = &vq->ubuf_info[head];

				vq->heads[vq->upend_idx].len =
					VHOST_DMA_IN_PROGRESS;
				ubuf->callback = vhost_zerocopy_callback;
				ubuf->ctx = vq->ubufs;
				ubuf->desc = vq->upend_idx;
				msg.msg_control = ubuf;
				msg.msg_controllen = sizeof(ubuf);
				ubufs = vq->ubufs;
				kref_get(&ubufs->kref);
			}
			vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
		}
		/* TODO: Check specific error and bomb out unless ENOBUFS? */
		err = sock->ops->sendmsg(NULL, sock, &msg, len);
		if (unlikely(err < 0)) {
			if (zcopy_used) {
				if (ubufs)
					vhost_ubuf_put(ubufs);
				vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
					UIO_MAXIOV;
			}
			vhost_discard_vq_desc(vq, 1);
			if (err == -EAGAIN || err == -ENOBUFS)
				tx_poll_start(net, sock);
			break;
		}
		if (err != len)
			pr_debug("Truncated TX packet: "
				 " len %d != %zd\n", err, len);
		if (!zcopy_used)
			vhost_add_used_and_signal(&net->dev, vq, head, 0);
		else
			vhost_zerocopy_signal_used(net, vq);
		total_len += len;
		vhost_net_tx_packet(net);
		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
			vhost_poll_queue(&vq->poll);
			break;
		}
	}

	mutex_unlock(&vq->mutex);
}

static int peek_head_len(struct sock *sk)
{
	struct sk_buff *head;
	int len = 0;
	unsigned long flags;

	spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
	head = skb_peek(&sk->sk_receive_queue);
	if (likely(head)) {
		len = head->len;
		if (vlan_tx_tag_present(head))
			len += VLAN_HLEN;
	}

	spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
	return len;
}

/* This is a multi-buffer version of vhost_get_vq_desc, that works if
 *	vq has read descriptors only.
 * @vq		- the relevant virtqueue
 * @datalen	- data length we'll be reading
 * @iovcount	- returned count of io vectors we fill
 * @log		- vhost log
 * @log_num	- log offset
 * @quota       - headcount quota, 1 for big buffer
 *	returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_virtqueue *vq,
		       struct vring_used_elem *heads,
		       int datalen,
		       unsigned *iovcount,
		       struct vhost_log *log,
		       unsigned *log_num,
		       unsigned int quota)
{
	unsigned int out, in;
	int seg = 0;
	int headcount = 0;
	unsigned d;
	int r, nlogs = 0;

	while (datalen > 0 && headcount < quota) {
		if (unlikely(seg >= UIO_MAXIOV)) {
			r = -ENOBUFS;
			goto err;
		}
		d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
				      ARRAY_SIZE(vq->iov) - seg, &out,
				      &in, log, log_num);
		if (d == vq->num) {
			r = 0;
			goto err;
		}
		if (unlikely(out || in <= 0)) {
			vq_err(vq, "unexpected descriptor format for RX: "
				"out %d, in %d\n", out, in);
			r = -EINVAL;
			goto err;
		}
		if (unlikely(log)) {
			nlogs += *log_num;
			log += *log_num;
		}
		heads[headcount].id = d;
		heads[headcount].len = iov_length(vq->iov + seg, in);
		datalen -= heads[headcount].len;
		++headcount;
		seg += in;
	}
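	/* Adjust the last buffer's reported length so that the lengths of all
	 * heads add up to exactly the datalen we were asked for. */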
	heads[headcount - 1].len += datalen;
	*iovcount = seg;
	if (unlikely(log))
		*log_num = nlogs;
	return headcount;
err:
	vhost_discard_vq_desc(vq, headcount);
	return r;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
{
	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
	unsigned uninitialized_var(in), log;
	struct vhost_log *vq_log;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
		.msg_controllen = 0,
		.msg_iov = vq->iov,
		.msg_flags = MSG_DONTWAIT,
	};
	struct virtio_net_hdr_mrg_rxbuf hdr = {
		.hdr.flags = 0,
		.hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
	};
	size_t total_len = 0;
	int err, mergeable;
	s16 headcount;
	size_t vhost_hlen, sock_hlen;
	size_t vhost_len, sock_len;
	/* TODO: check that we are running from vhost_worker? */
	struct socket *sock = rcu_dereference_check(vq->private_data, 1);

	if (!sock)
		return;

	mutex_lock(&vq->mutex);
	vhost_disable_notify(&net->dev, vq);
	vhost_hlen = vq->vhost_hlen;
	sock_hlen = vq->sock_hlen;

	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
		vq->log : NULL;
	mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);

	while ((sock_len = peek_head_len(sock->sk))) {
		sock_len += sock_hlen;
		vhost_len = sock_len + vhost_hlen;
		headcount = get_rx_bufs(vq, vq->heads, vhost_len,
					&in, vq_log, &log,
					likely(mergeable) ? UIO_MAXIOV : 1);
		/* On error, stop handling until the next kick. */
		if (unlikely(headcount < 0))
			break;
		/* OK, now we need to know about added descriptors. */
		if (!headcount) {
			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				/* They have slipped one in as we were
				 * doing that: check again. */
				vhost_disable_notify(&net->dev, vq);
				continue;
			}
			/* Nothing new?  Wait for eventfd to tell us
			 * they refilled. */
			break;
		}
		/* We don't need to be notified again. */
		if (unlikely((vhost_hlen)))
			/* Skip header. TODO: support TSO. */
			move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
		else
			/* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
			 * needed because recvmsg can modify msg_iov. */
			copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
		msg.msg_iovlen = in;
		err = sock->ops->recvmsg(NULL, sock, &msg,
					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
		/* Userspace might have consumed the packet meanwhile:
		 * it's not supposed to do this usually, but might be hard
		 * to prevent. Discard data we got (if any) and keep going. */
		if (unlikely(err != sock_len)) {
			pr_debug("Discarded rx packet: "
				 " len %d, expected %zd\n", err, sock_len);
			vhost_discard_vq_desc(vq, headcount);
			continue;
		}
		if (unlikely(vhost_hlen) &&
		    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
				      vhost_hlen)) {
			vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
			       vq->iov->iov_base);
			break;
		}
		/* TODO: Should check and handle checksum. */
		if (likely(mergeable) &&
		    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
				      offsetof(typeof(hdr), num_buffers),
				      sizeof hdr.num_buffers)) {
			vq_err(vq, "Failed num_buffers write");
			vhost_discard_vq_desc(vq, headcount);
			break;
		}
		vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
					    headcount);
		if (unlikely(vq_log))
			vhost_log_write(vq, vq_log, log, vhost_len);
		total_len += vhost_len;
		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
			vhost_poll_queue(&vq->poll);
			break;
		}
	}

	mutex_unlock(&vq->mutex);
}

static void handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

	handle_tx(net);
}

static void handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

	handle_rx(net);
}

static void handle_tx_net(struct vhost_work *work)
{
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_TX].work);
	handle_tx(net);
}

static void handle_rx_net(struct vhost_work *work)
{
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_RX].work);
	handle_rx(net);
}

static int vhost_net_open(struct inode *inode, struct file *f)
{
	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
	struct vhost_dev *dev;
	int r;

	if (!n)
		return -ENOMEM;

	dev = &n->dev;
	n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
	n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
	r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
	if (r < 0) {
		kfree(n);
		return r;
	}

	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
	n->tx_poll_state = VHOST_NET_POLL_DISABLED;

	f->private_data = n;

	return 0;
}

static void vhost_net_disable_vq(struct vhost_net *n,
				 struct vhost_virtqueue *vq)
{
	if (!vq->private_data)
		return;
	if (vq == n->vqs + VHOST_NET_VQ_TX) {
		tx_poll_stop(n);
		n->tx_poll_state = VHOST_NET_POLL_DISABLED;
	} else
		vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
}

static void vhost_net_enable_vq(struct vhost_net *n,
				struct vhost_virtqueue *vq)
{
	struct socket *sock;

	sock = rcu_dereference_protected(vq->private_data,
					 lockdep_is_held(&vq->mutex));
	if (!sock)
		return;
	if (vq == n->vqs + VHOST_NET_VQ_TX) {
		n->tx_poll_state = VHOST_NET_POLL_STOPPED;
		tx_poll_start(n, sock);
	} else
		vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
					struct vhost_virtqueue *vq)
{
	struct socket *sock;

	mutex_lock(&vq->mutex);
	sock = rcu_dereference_protected(vq->private_data,
					 lockdep_is_held(&vq->mutex));
	vhost_net_disable_vq(n, vq);
	rcu_assign_pointer(vq->private_data, NULL);
	mutex_unlock(&vq->mutex);
	return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
			   struct socket **rx_sock)
{
	*tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
	*rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
	vhost_poll_flush(n->poll + index);
	vhost_poll_flush(&n->dev.vqs[index].poll);
}

static void vhost_net_flush(struct vhost_net *n)
{
	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
	if (n->dev.vqs[VHOST_NET_VQ_TX].ubufs) {
		mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
		n->tx_flush = true;
		mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
		/* Wait for all lower device DMAs done. */
		vhost_ubuf_put_and_wait(n->dev.vqs[VHOST_NET_VQ_TX].ubufs);
		mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
		n->tx_flush = false;
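		/* Re-arm the ubufs refcount that put_and_wait dropped to zero,
		 * so the zerocopy state can be reused after the flush. */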
		kref_init(&n->dev.vqs[VHOST_NET_VQ_TX].ubufs->kref);
		mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
	}
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
	struct vhost_net *n = f->private_data;
	struct socket *tx_sock;
	struct socket *rx_sock;

	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	vhost_dev_stop(&n->dev);
	vhost_dev_cleanup(&n->dev, false);
	if (tx_sock)
		fput(tx_sock->file);
	if (rx_sock)
		fput(rx_sock->file);
	/* We do an extra flush before freeing memory,
	 * since jobs can re-queue themselves. */
	vhost_net_flush(n);
	kfree(n);
	return 0;
}

static struct socket *get_raw_socket(int fd)
{
	struct {
		struct sockaddr_ll sa;
		char  buf[MAX_ADDR_LEN];
	} uaddr;
	int uaddr_len = sizeof uaddr, r;
	struct socket *sock = sockfd_lookup(fd, &r);

	if (!sock)
		return ERR_PTR(-ENOTSOCK);

	/* Parameter checking */
	if (sock->sk->sk_type != SOCK_RAW) {
		r = -ESOCKTNOSUPPORT;
		goto err;
	}

	r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
			       &uaddr_len, 0);
	if (r)
		goto err;

	if (uaddr.sa.sll_family != AF_PACKET) {
		r = -EPFNOSUPPORT;
		goto err;
	}
	return sock;
err:
	fput(sock->file);
	return ERR_PTR(r);
}

static struct socket *get_tap_socket(int fd)
{
	struct file *file = fget(fd);
	struct socket *sock;

	if (!file)
		return ERR_PTR(-EBADF);
	sock = tun_get_socket(file);
	if (!IS_ERR(sock))
		return sock;
	sock = macvtap_get_socket(file);
	if (IS_ERR(sock))
		fput(file);
	return sock;
}

static struct socket *get_socket(int fd)
{
	struct socket *sock;

	/* special case to disable backend */
	if (fd == -1)
		return NULL;
	sock = get_raw_socket(fd);
	if (!IS_ERR(sock))
		return sock;
	sock = get_tap_socket(fd);
	if (!IS_ERR(sock))
		return sock;
	return ERR_PTR(-ENOTSOCK);
}

static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
	struct socket *sock, *oldsock;
	struct vhost_virtqueue *vq;
	struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
	int r;

	mutex_lock(&n->dev.mutex);
	r = vhost_dev_check_owner(&n->dev);
	if (r)
		goto err;

	if (index >= VHOST_NET_VQ_MAX) {
		r = -ENOBUFS;
		goto err;
	}
	vq = n->vqs + index;
	mutex_lock(&vq->mutex);

	/* Verify that ring has been setup correctly. */
	if (!vhost_vq_access_ok(vq)) {
		r = -EFAULT;
		goto err_vq;
	}
	sock = get_socket(fd);
	if (IS_ERR(sock)) {
		r = PTR_ERR(sock);
		goto err_vq;
	}

	/* start polling new socket */
	oldsock = rcu_dereference_protected(vq->private_data,
					    lockdep_is_held(&vq->mutex));
	if (sock != oldsock) {
		ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock));
		if (IS_ERR(ubufs)) {
			r = PTR_ERR(ubufs);
			goto err_ubufs;
		}
		oldubufs = vq->ubufs;
		vq->ubufs = ubufs;
		vhost_net_disable_vq(n, vq);
		rcu_assign_pointer(vq->private_data, sock);
		vhost_net_enable_vq(n, vq);

		r = vhost_init_used(vq);
		if (r)
			goto err_vq;

		n->tx_packets = 0;
		n->tx_zcopy_err = 0;
		n->tx_flush = false;
	}

	mutex_unlock(&vq->mutex);

	if (oldubufs) {
		vhost_ubuf_put_and_wait(oldubufs);
		mutex_lock(&vq->mutex);
		vhost_zerocopy_signal_used(n, vq);
		mutex_unlock(&vq->mutex);
	}

	if (oldsock) {
		vhost_net_flush_vq(n, index);
		fput(oldsock->file);
	}

	mutex_unlock(&n->dev.mutex);
	return 0;

err_ubufs:
	fput(sock->file);
err_vq:
	mutex_unlock(&vq->mutex);
err:
	mutex_unlock(&n->dev.mutex);
	return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
	struct socket *tx_sock = NULL;
	struct socket *rx_sock = NULL;
	long err;

	mutex_lock(&n->dev.mutex);
	err = vhost_dev_check_owner(&n->dev);
	if (err)
		goto done;
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	err = vhost_dev_reset_owner(&n->dev);
done:
	mutex_unlock(&n->dev.mutex);
	if (tx_sock)
		fput(tx_sock->file);
	if (rx_sock)
		fput(rx_sock->file);
	return err;
}

static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
	size_t vhost_hlen, sock_hlen, hdr_len;
	int i;

	hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
			sizeof(struct virtio_net_hdr);
	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
		/* vhost provides vnet_hdr */
		vhost_hlen = hdr_len;
		sock_hlen = 0;
	} else {
		/* socket provides vnet_hdr */
		vhost_hlen = 0;
		sock_hlen = hdr_len;
	}
	mutex_lock(&n->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&n->dev)) {
		mutex_unlock(&n->dev.mutex);
		return -EFAULT;
	}
	n->dev.acked_features = features;
	smp_wmb();
	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
		mutex_lock(&n->vqs[i].mutex);
		n->vqs[i].vhost_hlen = vhost_hlen;
		n->vqs[i].sock_hlen = sock_hlen;
		mutex_unlock(&n->vqs[i].mutex);
	}
	vhost_net_flush(n);
	mutex_unlock(&n->dev.mutex);
	return 0;
}

static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
			    unsigned long arg)
{
	struct vhost_net *n = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 __user *featurep = argp;
	struct vhost_vring_file backend;
	u64 features;
	int r;

	switch (ioctl) {
	case VHOST_NET_SET_BACKEND:
		if (copy_from_user(&backend, argp, sizeof backend))
			return -EFAULT;
		return vhost_net_set_backend(n, backend.index, backend.fd);
	case VHOST_GET_FEATURES:
		features = VHOST_NET_FEATURES;
		if (copy_to_user(featurep, &features, sizeof features))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, featurep, sizeof features))
			return -EFAULT;
		if (features & ~VHOST_NET_FEATURES)
			return -EOPNOTSUPP;
		return vhost_net_set_features(n, features);
	case VHOST_RESET_OWNER:
		return vhost_net_reset_owner(n);
	default:
		mutex_lock(&n->dev.mutex);
		r = vhost_dev_ioctl(&n->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&n->dev, ioctl, argp);
		else
			vhost_net_flush(n);
		mutex_unlock(&n->dev.mutex);
		return r;
	}
}

#ifdef CONFIG_COMPAT
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
				   unsigned long arg)
{
	return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_net_fops = {
	.owner          = THIS_MODULE,
	.release        = vhost_net_release,
	.unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
	.compat_ioctl   = vhost_net_compat_ioctl,
#endif
	.open           = vhost_net_open,
	.llseek		= noop_llseek,
};

static struct miscdevice vhost_net_misc = {
	.minor = VHOST_NET_MINOR,
	.name = "vhost-net",
	.fops = &vhost_net_fops,
};

static int vhost_net_init(void)
{
	if (experimental_zcopytx)
		vhost_enable_zcopy(VHOST_NET_VQ_TX);
	return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);

static void vhost_net_exit(void)
{
	misc_deregister(&vhost_net_misc);
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
MODULE_ALIAS("devname:vhost-net");