/*
 * Hyper-V transport for vsock
 *
 * Hyper-V Sockets supplies a byte-stream based communication mechanism
 * between the host and the VM. This driver implements the necessary
 * support in the VM by introducing the new vsock transport.
 *
 * Copyright (c) 2017, Microsoft Corporation.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 */
#include <linux/module.h>
#include <linux/vmalloc.h>
#include <linux/hyperv.h>
#include <net/sock.h>
#include <net/af_vsock.h>

/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some
 * stricter requirements on the hv_sock ring buffer size of six 4K pages. Newer
 * hosts don't have this limitation; but, keep the defaults the same for compat.
 */
#define PAGE_SIZE_4K		4096
#define RINGBUFFER_HVS_RCV_SIZE (PAGE_SIZE_4K * 6)
#define RINGBUFFER_HVS_SND_SIZE (PAGE_SIZE_4K * 6)

/* Upper bound for the ring size requested from newer hosts; the hv_sock ring
 * buffer is a physically contiguous allocation, so it is kept bounded.
 */
#define RINGBUFFER_HVS_MAX_SIZE (PAGE_SIZE_4K * 64)

/* The MTU is 16KB per the host side's design */
#define HVS_MTU_SIZE		(1024 * 16)

/* How long to wait for graceful shutdown of a connection */
#define HVS_CLOSE_TIMEOUT (8 * HZ)

41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
struct vmpipe_proto_header {
	u32 pkt_type;
	u32 data_size;
};

/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
 * data from the ringbuffer into the userspace buffer.
 */
struct hvs_recv_buf {
	/* The header before the payload data */
	struct vmpipe_proto_header hdr;

	/* The payload */
	u8 data[HVS_MTU_SIZE];
};

/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use
58 59 60
 * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the
 * guest and the host processing as one VMBUS packet is the smallest processing
 * unit.
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
 *
 * Note: the buffer can be eliminated in the future when we add new VMBus
 * ringbuffer APIs that allow us to directly copy data from userspace buffer
 * to VMBus ringbuffer.
 */
#define HVS_SEND_BUF_SIZE (PAGE_SIZE_4K - sizeof(struct vmpipe_proto_header))

struct hvs_send_buf {
	/* The header before the payload data */
	struct vmpipe_proto_header hdr;

	/* The payload */
	u8 data[HVS_SEND_BUF_SIZE];
};

#define HVS_HEADER_LEN	(sizeof(struct vmpacket_descriptor) + \
			 sizeof(struct vmpipe_proto_header))

/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and
 * __hv_pkt_iter_next().
 */
#define VMBUS_PKT_TRAILER_SIZE	(sizeof(u64))

#define HVS_PKT_LEN(payload_len)	(HVS_HEADER_LEN + \
					 ALIGN((payload_len), 8) + \
					 VMBUS_PKT_TRAILER_SIZE)

/* Overlay of a service GUID: the first 32 bits carry the AF_VSOCK port
 * (hvs_connect() writes svm_port after copying srv_id_template), and the
 * remaining bytes are left as in the template.
 */
union hvs_service_id {
	uuid_le	srv_id;

	struct {
		unsigned int svm_port;
		unsigned char b[sizeof(uuid_le) - sizeof(unsigned int)];
	};
};

/* Per-socket state (accessed via vsk->trans) */
struct hvsock {
	/* Back-pointer to the owning vsock socket */
	struct vsock_sock *vsk;

	/* Service GUIDs of the VM-side and host-side endpoints */
	uuid_le vm_srv_id;
	uuid_le host_srv_id;

	/* VMBus channel; NULL (kzalloc) until hvs_open_connection() sets it */
	struct vmbus_channel *chan;
	/* Inbound packet currently being consumed, or NULL if none */
	struct vmpacket_descriptor *recv_desc;

	/* The length of the payload not delivered to userland yet */
	u32 recv_data_len;
	/* The offset of the payload */
	u32 recv_data_off;

	/* Have we sent the zero-length packet (FIN)? */
	bool fin_sent;
};

/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
 * when we write apps to connect to the host, we can only use VMADDR_CID_ANY
 * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we
 * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY
 * as the local cid.
 *
 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-
 * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with
 * the below sockaddr:
 *
 * struct SOCKADDR_HV
 * {
 *    ADDRESS_FAMILY Family;
 *    USHORT Reserved;
 *    GUID VmId;
 *    GUID ServiceId;
 * };
 * Note: VmID is not used by Linux VM and actually it isn't transmitted via
 * VMBus, because here it's obvious the host and the VM can easily identify
 * each other. Though the VmID is useful on the host, especially in the case
 * of Windows container, Linux VM doesn't need it at all.
 *
 * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit
 * the available GUID space of SOCKADDR_HV so that we can create a mapping
 * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing
 * Hyper-V Sockets apps on the host and in Linux VM is:
 *
 ****************************************************************************
 * The only valid Service GUIDs, from the perspectives of both the host and *
 * Linux VM, that can be connected by the other end, must conform to this   *
 * format: <port>-facb-11e6-bd58-64006a7986d3, and the "port" must be in    *
 * this range [0, 0x7FFFFFFF].                                              *
 ****************************************************************************
 *
 * When we write apps on the host to connect(), the GUID ServiceID is used.
 * When we write apps in Linux VM to connect(), we only need to specify the
 * port and the driver will form the GUID and use that to request the host.
 *
 * From the perspective of Linux VM:
 * 1. the local ephemeral port (i.e. the local auto-bound port when we call
 * connect() without explicit bind()) is generated by __vsock_bind_stream(),
 * and the range is [1024, 0xFFFFFFFF).
 * 2. the remote ephemeral port (i.e. the auto-generated remote port for
 * a connect request initiated by the host's connect()) is generated by
 * hvs_remote_addr_init() and the range is [0x80000000, 0xFFFFFFFF).
 */

/* Listening ports on either side are limited to [0, MAX_LISTEN_PORT]; ports
 * from MIN_HOST_EPHEMERAL_PORT upward are handed out by
 * hvs_remote_addr_init() as host-side ephemeral ports.
 */
#define MAX_LISTEN_PORT			((u32)0x7FFFFFFF)
#define MAX_VM_LISTEN_PORT		MAX_LISTEN_PORT
#define MAX_HOST_LISTEN_PORT		MAX_LISTEN_PORT
#define MIN_HOST_EPHEMERAL_PORT		(MAX_HOST_LISTEN_PORT + 1)

/* 00000000-facb-11e6-bd58-64006a7986d3: the port is stored in the first
 * 32 bits; all other bytes must match this template.
 */
static const uuid_le srv_id_template =
	UUID_LE(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
		0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);

/* A service GUID is usable by hv_sock only when every byte after the
 * leading 32-bit port field matches srv_id_template.
 */
static bool is_valid_srv_id(const uuid_le *id)
{
	const size_t port_bytes = 4;

	return memcmp(id->b + port_bytes, srv_id_template.b + port_bytes,
		      sizeof(uuid_le) - port_bytes) == 0;
}

/* Extract the AF_VSOCK port stored in the first 32 bits of a service GUID. */
static unsigned int get_port_by_srv_id(const uuid_le *svr_id)
{
	const unsigned int *port_field = (const unsigned int *)svr_id;

	return *port_field;
}

/* Initialise @addr from the port encoded in @svr_id; the CID is not used by
 * hv_sock, so VMADDR_CID_ANY is stored.
 */
static void hvs_addr_init(struct sockaddr_vm *addr, const uuid_le *svr_id)
{
	vsock_addr_init(addr, VMADDR_CID_ANY, get_port_by_srv_id(svr_id));
}

/* Choose a host-side ephemeral port for a host-initiated connection to
 * @local. Ports are scanned upward from the last one handed out, restarting
 * at MIN_HOST_EPHEMERAL_PORT on wrap-around or on VMADDR_PORT_ANY, until no
 * connected socket already uses <remote, local>.
 *
 * NOTE(review): host_ephemeral_port is a function-static counter with no
 * lock visible here; callers hold only the listening socket's lock.
 */
static void hvs_remote_addr_init(struct sockaddr_vm *remote,
				 struct sockaddr_vm *local)
{
	static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
	struct sock *busy;

	vsock_addr_init(remote, VMADDR_CID_ANY, VMADDR_PORT_ANY);

	for (;;) {
		/* Restart the ephemeral range after wrap-around */
		if (host_ephemeral_port == VMADDR_PORT_ANY ||
		    host_ephemeral_port < MIN_HOST_EPHEMERAL_PORT)
			host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;

		remote->svm_port = host_ephemeral_port;
		host_ephemeral_port++;

		busy = vsock_find_connected_socket(remote, local);
		if (!busy)
			break;	/* this port is free */

		/* Drop the reference vsock_find_connected_socket() took */
		sock_put(busy);
	}
}

static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
{
	set_channel_pending_send_size(chan,
				      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));

	/* See hvs_stream_has_space(): we must make sure the host has seen
	 * the new pending send size, before we can re-check the writable
	 * bytes.
	 */
	virt_mb();
}

static void hvs_clear_channel_pending_send_size(struct vmbus_channel *chan)
{
	set_channel_pending_send_size(chan, 0);

	/* Ditto */
	virt_mb();
}

/* True when the inbound ring holds at least one packet's worth of bytes;
 * a zero-length payload (the FIN) counts as readable.
 */
static bool hvs_channel_readable(struct vmbus_channel *chan)
{
	return hv_get_bytes_to_read(&chan->inbound) >= HVS_PKT_LEN(0);
}

/* Classify what the inbound ring holds:
 *   1 - at least one payload byte (the exact count is not needed: see
 *       vsock_stream_recvmsg() -> vsock_stream_has_data()),
 *   0 - exactly one zero-size packet, i.e. the FIN,
 *  -1 - neither payload nor FIN.
 */
static int hvs_channel_readable_payload(struct vmbus_channel *chan)
{
	u32 avail = hv_get_bytes_to_read(&chan->inbound);

	if (avail < HVS_PKT_LEN(0))
		return -1;

	return avail == HVS_PKT_LEN(0) ? 0 : 1;
}

/* Payload bytes that can be queued right now, rounded down to a multiple
 * of 8. Returns 0 when only the reserved space is left.
 */
static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
{
	u32 space = hv_get_bytes_to_write(&chan->outbound);
	size_t reserved = HVS_PKT_LEN(1) + HVS_PKT_LEN(0);

	/* The ring must never become 100% full, and room for one
	 * zero-length-payload packet (the FIN) is always held back: see
	 * hv_ringbuffer_write() and hvs_shutdown().
	 */
	if (space <= reserved)
		return 0;

	return round_down(space - reserved, 8);
}

/* Fill in the pipe header of @send_buf and post the header plus @to_write
 * payload bytes as one in-band VMBus packet. Returns vmbus_sendpacket()'s
 * result.
 */
static int hvs_send_data(struct vmbus_channel *chan,
			 struct hvs_send_buf *send_buf, size_t to_write)
{
	struct vmpipe_proto_header *hdr = &send_buf->hdr;

	hdr->pkt_type = 1;
	hdr->data_size = to_write;

	return vmbus_sendpacket(chan, hdr, sizeof(*hdr) + to_write,
				0, VM_PKT_DATA_INBAND, 0);
}

static void hvs_channel_cb(void *ctx)
{
	struct sock *sk = (struct sock *)ctx;
	struct vsock_sock *vsk = vsock_sk(sk);
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;

	if (hvs_channel_readable(chan))
		sk->sk_data_ready(sk);

	/* See hvs_stream_has_space(): when we reach here, the writable bytes
	 * may be already less than HVS_PKT_LEN(HVS_SEND_BUF_SIZE).
	 */
	if (hv_get_bytes_to_write(&chan->outbound) > 0)
		sk->sk_write_space(sk);
}

312 313
static void hvs_do_close_lock_held(struct vsock_sock *vsk,
				   bool cancel_timeout)
314
{
315
	struct sock *sk = sk_vsock(vsk);
316

317
	sock_set_flag(sk, SOCK_DONE);
318 319 320
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
321
	sk->sk_state_change(sk);
322 323 324 325
	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;
		vsock_remove_sock(vsk);
326

327 328 329 330 331 332 333 334 335 336 337
		/* Release the reference taken while scheduling the timeout */
		sock_put(sk);
	}
}

/* Rescind callback installed by hvs_open_connection() through
 * vmbus_set_chn_rescind_callback(): close the socket stored as the
 * channel's per-channel state, cancelling any pending close timeout.
 */
static void hvs_close_connection(struct vmbus_channel *chan)
{
	struct sock *sk = get_per_channel_state(chan);

	lock_sock(sk);
	hvs_do_close_lock_held(vsock_sk(sk), true);
	release_sock(sk);
}

static void hvs_open_connection(struct vmbus_channel *chan)
{
	uuid_le *if_instance, *if_type;
	unsigned char conn_from_host;

	struct sockaddr_vm addr;
	struct sock *sk, *new = NULL;
348 349 350 351
	struct vsock_sock *vnew = NULL;
	struct hvsock *hvs = NULL;
	struct hvsock *hvs_new = NULL;
	int rcvbuf;
352
	int ret;
353
	int sndbuf;
354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370

	if_type = &chan->offermsg.offer.if_type;
	if_instance = &chan->offermsg.offer.if_instance;
	conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];

	/* The host or the VM should only listen on a port in
	 * [0, MAX_LISTEN_PORT]
	 */
	if (!is_valid_srv_id(if_type) ||
	    get_port_by_srv_id(if_type) > MAX_LISTEN_PORT)
		return;

	hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
	sk = vsock_find_bound_socket(&addr);
	if (!sk)
		return;

371
	lock_sock(sk);
372 373
	if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
	    (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
374 375 376 377 378 379 380 381 382 383 384
		goto out;

	if (conn_from_host) {
		if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
			goto out;

		new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
				     sk->sk_type, 0);
		if (!new)
			goto out;

385
		new->sk_state = TCP_SYN_SENT;
386 387 388 389 390 391 392 393 394
		vnew = vsock_sk(new);
		hvs_new = vnew->trans;
		hvs_new->chan = chan;
	} else {
		hvs = vsock_sk(sk)->trans;
		hvs->chan = chan;
	}

	set_channel_read_mode(chan, HV_CALL_DIRECT);
395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422

	/* Use the socket buffer sizes as hints for the VMBUS ring size. For
	 * server side sockets, 'sk' is the parent socket and thus, this will
	 * allow the child sockets to inherit the size from the parent. Keep
	 * the mins to the default value and align to page size as per VMBUS
	 * requirements.
	 * For the max, the socket core library will limit the socket buffer
	 * size that can be set by the user, but, since currently, the hv_sock
	 * VMBUS ring buffer is physically contiguous allocation, restrict it
	 * further.
	 * Older versions of hv_sock host side code cannot handle bigger VMBUS
	 * ring buffer size. Use the version number to limit the change to newer
	 * versions.
	 */
	if (vmbus_proto_version < VERSION_WIN10_V5) {
		sndbuf = RINGBUFFER_HVS_SND_SIZE;
		rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
	} else {
		sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
		sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
		sndbuf = ALIGN(sndbuf, PAGE_SIZE);
		rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
		rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
		rcvbuf = ALIGN(rcvbuf, PAGE_SIZE);
	}

	ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
			 conn_from_host ? new : sk);
423 424 425 426 427 428 429 430 431 432 433 434 435 436
	if (ret != 0) {
		if (conn_from_host) {
			hvs_new->chan = NULL;
			sock_put(new);
		} else {
			hvs->chan = NULL;
		}
		goto out;
	}

	set_per_channel_state(chan, conn_from_host ? new : sk);
	vmbus_set_chn_rescind_callback(chan, hvs_close_connection);

	if (conn_from_host) {
437
		new->sk_state = TCP_ESTABLISHED;
438 439 440 441 442 443 444 445 446 447 448 449
		sk->sk_ack_backlog++;

		hvs_addr_init(&vnew->local_addr, if_type);
		hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);

		hvs_new->vm_srv_id = *if_type;
		hvs_new->host_srv_id = *if_instance;

		vsock_insert_connected(vnew);

		vsock_enqueue_accept(sk, new);
	} else {
450
		sk->sk_state = TCP_ESTABLISHED;
451 452 453 454 455 456 457 458 459 460
		sk->sk_socket->state = SS_CONNECTED;

		vsock_insert_connected(vsock_sk(sk));
	}

	sk->sk_state_change(sk);

out:
	/* Release refcnt obtained when we called vsock_find_bound_socket() */
	sock_put(sk);
461 462

	release_sock(sk);
463 464 465 466 467 468 469 470 471 472
}

/* hv_sock does not use the local CID; always report VMADDR_CID_ANY. */
static u32 hvs_get_local_cid(void)
{
	const u32 cid = VMADDR_CID_ANY;

	return cid;
}

static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
{
	struct hvsock *hvs;
473
	struct sock *sk = sk_vsock(vsk);
474 475 476 477 478 479 480

	hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
	if (!hvs)
		return -ENOMEM;

	vsk->trans = hvs;
	hvs->vsk = vsk;
481 482
	sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE;
	sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501
	return 0;
}

static int hvs_connect(struct vsock_sock *vsk)
{
	union hvs_service_id vm, host;
	struct hvsock *h = vsk->trans;

	vm.srv_id = srv_id_template;
	vm.svm_port = vsk->local_addr.svm_port;
	h->vm_srv_id = vm.srv_id;

	host.srv_id = srv_id_template;
	host.svm_port = vsk->remote_addr.svm_port;
	h->host_srv_id = host.srv_id;

	return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
}

502 503 504 505 506 507 508 509 510 511 512 513
static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
{
	struct vmpipe_proto_header hdr;

	if (hvs->fin_sent || !hvs->chan)
		return;

	/* It can't fail: see hvs_channel_writable_bytes(). */
	(void)hvs_send_data(hvs->chan, (struct hvs_send_buf *)&hdr, 0);
	hvs->fin_sent = true;
}

514 515 516 517 518 519 520 521
static int hvs_shutdown(struct vsock_sock *vsk, int mode)
{
	struct sock *sk = sk_vsock(vsk);

	if (!(mode & SEND_SHUTDOWN))
		return 0;

	lock_sock(sk);
522 523 524 525
	hvs_shutdown_lock_held(vsk->trans, mode);
	release_sock(sk);
	return 0;
}
526

527 528 529 530 531
static void hvs_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);
532

533 534 535 536
	sock_hold(sk);
	lock_sock(sk);
	if (!sock_flag(sk, SOCK_DONE))
		hvs_do_close_lock_held(vsk, false);
537

538
	vsk->close_work_scheduled = false;
539
	release_sock(sk);
540
	sock_put(sk);
541 542
}

543 544
/* Returns true, if it is safe to remove socket; false otherwise */
static bool hvs_close_lock_held(struct vsock_sock *vsk)
545
{
546
	struct sock *sk = sk_vsock(vsk);
547

548 549 550
	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;
551

552 553
	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);
554

555 556
	if (sock_flag(sk, SOCK_DONE))
		return true;
557

558 559 560 561 562 563 564
	/* This reference will be dropped by the delayed close routine */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
	return false;
}
565

566 567 568 569 570 571 572 573 574 575
static void hvs_release(struct vsock_sock *vsk)
{
	struct sock *sk = sk_vsock(vsk);
	bool remove_sock;

	lock_sock(sk);
	remove_sock = hvs_close_lock_held(vsk);
	release_sock(sk);
	if (remove_sock)
		vsock_remove_sock(vsk);
576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677
}

static void hvs_destruct(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;

	if (chan)
		vmbus_hvsock_device_unregister(chan);

	kfree(hvs);
}

/* hv_sock is stream-only: every datagram hook refuses the operation. */
static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
{
	const int err = -EOPNOTSUPP;

	return err;
}

static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
			     size_t len, int flags)
{
	const int err = -EOPNOTSUPP;

	return err;
}

static int hvs_dgram_enqueue(struct vsock_sock *vsk,
			     struct sockaddr_vm *remote, struct msghdr *msg,
			     size_t dgram_len)
{
	const int err = -EOPNOTSUPP;

	return err;
}

/* No datagram destination is ever allowed. */
static bool hvs_dgram_allow(u32 cid, u32 port)
{
	const bool allowed = false;

	return allowed;
}

static int hvs_update_recv_data(struct hvsock *hvs)
{
	struct hvs_recv_buf *recv_buf;
	u32 payload_len;

	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
	payload_len = recv_buf->hdr.data_size;

	if (payload_len > HVS_MTU_SIZE)
		return -EIO;

	if (payload_len == 0)
		hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;

	hvs->recv_data_len = payload_len;
	hvs->recv_data_off = 0;

	return 0;
}

/* Copy up to @len bytes of the current packet's payload into @msg.
 *
 * If no packet is being consumed, the first packet is fetched from the
 * inbound ring; -ENOBUFS is returned if none is available. Once a packet's
 * payload is fully delivered, the iterator advances and the next packet's
 * header is loaded. MSG_PEEK is not supported.
 *
 * Returns the number of bytes copied or a negative error code.
 */
static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
				  size_t len, int flags)
{
	struct hvsock *hvs = vsk->trans;
	bool need_refill = !hvs->recv_desc;
	struct hvs_recv_buf *recv_buf;
	u32 to_read;
	int ret;

	if (flags & MSG_PEEK)
		return -EOPNOTSUPP;

	if (need_refill) {
		hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
		/* The ring may not hold a complete packet; never pass a NULL
		 * descriptor on to hvs_update_recv_data().
		 */
		if (!hvs->recv_desc)
			return -ENOBUFS;
		ret = hvs_update_recv_data(hvs);
		if (ret)
			return ret;
	}

	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
	to_read = min_t(u32, len, hvs->recv_data_len);
	ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
	if (ret != 0)
		return ret;

	hvs->recv_data_len -= to_read;
	if (hvs->recv_data_len == 0) {
		hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
		if (hvs->recv_desc) {
			ret = hvs_update_recv_data(hvs);
			if (ret)
				return ret;
		}
	} else {
		hvs->recv_data_off += to_read;
	}

	return to_read;
}

/* Queue up to @len bytes from @msg to the host. Data is sent as packets of
 * at most HVS_SEND_BUF_SIZE bytes until everything is sent or the outbound
 * ring has no usable space left.
 *
 * Returns the number of bytes queued if any were; otherwise 0 (ring full) or
 * the negative error from the allocation, memcpy_from_msg() or
 * hvs_send_data().
 */
static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
				  size_t len)
{
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;
	struct hvs_send_buf *send_buf;
	ssize_t to_write, max_writable;
	ssize_t ret = 0;
	ssize_t bytes_written = 0;

	BUILD_BUG_ON(sizeof(*send_buf) != PAGE_SIZE_4K);

	send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
	if (!send_buf)
		return -ENOMEM;

	/* Reader(s) could be draining data from the channel as we write.
	 * Maximize bandwidth, by iterating until the channel is found to be
	 * full.
	 */
	while (len) {
		max_writable = hvs_channel_writable_bytes(chan);
		if (!max_writable)
			break;
		to_write = min_t(ssize_t, len, max_writable);
		to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
		/* memcpy_from_msg is safe for loop as it advances the offsets
		 * within the message iterator.
		 */
		ret = memcpy_from_msg(send_buf->data, msg, to_write);
		if (ret < 0)
			goto out;

		ret = hvs_send_data(chan, send_buf, to_write);
		if (ret < 0)
			goto out;

		bytes_written += to_write;
		len -= to_write;
	}
out:
	/* If any data has been sent, return that */
	if (bytes_written)
		ret = bytes_written;
	kfree(send_buf);
	return ret;
}

/* Returns 1 while payload is pending (left over from the current packet or
 * still in the ring), else 0. Finding exactly the FIN in the ring also marks
 * the peer's SEND_SHUTDOWN.
 */
static s64 hvs_stream_has_data(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;
	int state;

	if (hvs->recv_data_len > 0)
		return 1;

	state = hvs_channel_readable_payload(hvs->chan);
	if (state == 0)
		vsk->peer_shutdown |= SEND_SHUTDOWN;

	return state > 0 ? 1 : 0;
}

/* Report the writable payload bytes. When there are none, register a pending
 * send size (so hvs_channel_cb() will wake us) and look once more to close
 * the race with the host freeing space. The pending size is cleared whenever
 * space is found.
 */
static s64 hvs_stream_has_space(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;
	s64 space = hvs_channel_writable_bytes(chan);

	if (space <= 0) {
		hvs_set_channel_pending_send_size(chan);

		space = hvs_channel_writable_bytes(chan);
		if (space <= 0)
			return space;
	}

	hvs_clear_channel_pending_send_size(chan);
	return space;
}

/* Receive high-water mark: one byte more than the largest single payload. */
static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return 1 + HVS_MTU_SIZE;
}

static bool hvs_stream_is_active(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;

	return hvs->chan != NULL;
}

/* Only connections to the host, on a port the host may listen on, are
 * allowed.
 */
static bool hvs_stream_allow(u32 cid, u32 port)
{
	/* The host's port range [MIN_HOST_EPHEMERAL_PORT, 0xFFFFFFFF) is
	 * reserved as ephemeral ports, which are used as the host's ports
	 * when the host initiates connections.
	 *
	 * Perform this check in the guest so an immediate error is produced
	 * instead of a timeout.
	 */
	return port <= MAX_HOST_LISTEN_PORT && cid == VMADDR_CID_HOST;
}

/* poll() readability: true when a whole packet (data or FIN) is queued;
 * @target is not used.
 */
static int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target,
			      bool *readable)
{
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;

	*readable = hvs_channel_readable(chan);
	return 0;
}

/* poll() writability: true when hvs_stream_has_space() reports room;
 * @target is not used.
 */
static int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target,
			       bool *writable)
{
	s64 space = hvs_stream_has_space(vsk);

	*writable = space > 0;
	return 0;
}

/* The vsock core's receive/send notification hooks need no work for hv_sock;
 * each simply reports success.
 */
static int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
				struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
				     struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
				       struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
					ssize_t copied, bool data_read,
					struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static int hvs_notify_send_init(struct vsock_sock *vsk,
				struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static int hvs_notify_send_pre_block(struct vsock_sock *vsk,
				     struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
				       struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
					struct vsock_transport_send_notify_data *d)
{
	return 0;
}

/* Buffer-size socket options are not supported by hv_sock: the setters
 * ignore the requested value and the getters return -ENOPROTOOPT.
 *
 * NOTE(review): the getters return a negative errno through a u64, so a
 * caller sees a huge positive value - confirm how the vsock core treats it.
 */
static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
}

static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
}

static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
}

static u64 hvs_get_buffer_size(struct vsock_sock *vsk)
{
	u64 unsupported = -ENOPROTOOPT;

	return unsupported;
}

static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk)
{
	u64 unsupported = -ENOPROTOOPT;

	return unsupported;
}

static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk)
{
	u64 unsupported = -ENOPROTOOPT;

	return unsupported;
}

/* Operations table handed to the vsock core by hvs_init() through
 * vsock_core_init(). Datagram hooks all refuse; stream hooks map onto the
 * VMBus ring buffers.
 */
static struct vsock_transport hvs_transport = {
	.get_local_cid            = hvs_get_local_cid,

	.init                     = hvs_sock_init,
	.destruct                 = hvs_destruct,
	.release                  = hvs_release,
	.connect                  = hvs_connect,
	.shutdown                 = hvs_shutdown,

	.dgram_bind               = hvs_dgram_bind,
	.dgram_dequeue            = hvs_dgram_dequeue,
	.dgram_enqueue            = hvs_dgram_enqueue,
	.dgram_allow              = hvs_dgram_allow,

	.stream_dequeue           = hvs_stream_dequeue,
	.stream_enqueue           = hvs_stream_enqueue,
	.stream_has_data          = hvs_stream_has_data,
	.stream_has_space         = hvs_stream_has_space,
	.stream_rcvhiwat          = hvs_stream_rcvhiwat,
	.stream_is_active         = hvs_stream_is_active,
	.stream_allow             = hvs_stream_allow,

	.notify_poll_in           = hvs_notify_poll_in,
	.notify_poll_out          = hvs_notify_poll_out,
	.notify_recv_init         = hvs_notify_recv_init,
	.notify_recv_pre_block    = hvs_notify_recv_pre_block,
	.notify_recv_pre_dequeue  = hvs_notify_recv_pre_dequeue,
	.notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
	.notify_send_init         = hvs_notify_send_init,
	.notify_send_pre_block    = hvs_notify_send_pre_block,
	.notify_send_pre_enqueue  = hvs_notify_send_pre_enqueue,
	.notify_send_post_enqueue = hvs_notify_send_post_enqueue,

	.set_buffer_size          = hvs_set_buffer_size,
	.set_min_buffer_size      = hvs_set_min_buffer_size,
	.set_max_buffer_size      = hvs_set_max_buffer_size,
	.get_buffer_size          = hvs_get_buffer_size,
	.get_min_buffer_size      = hvs_get_min_buffer_size,
	.get_max_buffer_size      = hvs_get_max_buffer_size,
};

/* VMBus probe: try to bind the offered channel to a local socket. */
static int hvs_probe(struct hv_device *hdev,
		     const struct hv_vmbus_device_id *dev_id)
{
	hvs_open_connection(hdev->channel);

	/* Always return success to suppress the unnecessary error message
	 * in vmbus_probe(): on error the host will rescind the device in
	 * 30 seconds and we can do cleanup at that time in
	 * vmbus_onoffer_rescind().
	 */
	return 0;
}

/* VMBus remove: close the device's channel. Always returns 0. */
static int hvs_remove(struct hv_device *hdev)
{
	vmbus_close(hdev->channel);
	return 0;
}

/* This isn't really used. See vmbus_match() and vmbus_probe() */
static const struct hv_vmbus_device_id id_table[] = {
	{},
};

/* VMBus driver for hv_sock channels, registered by hvs_init(): probe tries
 * to open the offered connection, remove closes the channel.
 */
static struct hv_driver hvs_drv = {
	.name		= "hv_sock",
	.hvsock		= true,
	.id_table	= id_table,
	.probe		= hvs_probe,
	.remove		= hvs_remove,
};

/* Module init: requires at least the VERSION_WIN10 VMBus protocol, then
 * registers the VMBus driver and the vsock transport, undoing the driver
 * registration if the transport cannot be registered.
 */
static int __init hvs_init(void)
{
	int err;

	if (vmbus_proto_version < VERSION_WIN10)
		return -ENODEV;

	err = vmbus_driver_register(&hvs_drv);
	if (err)
		return err;

	err = vsock_core_init(&hvs_transport);
	if (err)
		goto unregister_driver;

	return 0;

unregister_driver:
	vmbus_driver_unregister(&hvs_drv);
	return err;
}

/* Module exit: tear down in the reverse order of hvs_init(). */
static void __exit hvs_exit(void)
{
	vsock_core_exit();
	vmbus_driver_unregister(&hvs_drv);
}

module_init(hvs_init);
module_exit(hvs_exit);

MODULE_DESCRIPTION("Hyper-V Sockets");
MODULE_VERSION("1.0.0");
MODULE_LICENSE("GPL");
MODULE_ALIAS_NETPROTO(PF_VSOCK);