vhost-user.c 25.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
/*
 * vhost-user
 *
 * Copyright (c) 2013 Virtual Open Systems Sarl.
 *
 * This work is licensed under the terms of the GNU GPL, version 2 or later.
 * See the COPYING file in the top-level directory.
 *
 */

P
Peter Maydell 已提交
11
#include "qemu/osdep.h"
12
#include "qapi/error.h"
13 14
#include "hw/virtio/vhost.h"
#include "hw/virtio/vhost-backend.h"
15
#include "hw/virtio/virtio-net.h"
16 17 18 19 20 21 22 23 24 25 26
#include "sysemu/char.h"
#include "sysemu/kvm.h"
#include "qemu/error-report.h"
#include "qemu/sockets.h"

#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <linux/vhost.h>

/* Number of slots in VhostUserMemory.regions and in the per-message fd arrays. */
#define VHOST_MEMORY_MAX_NREGIONS    8
/* Bit index in the device feature word that signals protocol-feature support. */
#define VHOST_USER_F_PROTOCOL_FEATURES 30
28

29 30 31 32
/*
 * Bit indexes of the vhost-user protocol features this file knows about.
 * They are tested against dev->protocol_features with virtio_has_feature().
 * VHOST_USER_PROTOCOL_F_MAX is a count, not a feature.
 */
enum VhostUserProtocolFeature {
    VHOST_USER_PROTOCOL_F_MQ = 0,
    VHOST_USER_PROTOCOL_F_LOG_SHMFD = 1,
    VHOST_USER_PROTOCOL_F_RARP = 2,
    VHOST_USER_PROTOCOL_F_REPLY_ACK = 3,
    VHOST_USER_PROTOCOL_F_NET_MTU = 4,
    VHOST_USER_PROTOCOL_F_SLAVE_REQ = 5,

    VHOST_USER_PROTOCOL_F_MAX
};

/* Mask with one bit set for every feature listed above. */
#define VHOST_USER_PROTOCOL_FEATURE_MASK ((1 << VHOST_USER_PROTOCOL_F_MAX) - 1)
41 42 43 44 45 46

/*
 * Request codes placed in VhostUserMsg.request for messages written to the
 * backend chardev. The numeric values are spelled out because they travel
 * on the wire. VHOST_USER_MAX is a count, not a request.
 */
typedef enum VhostUserRequest {
    VHOST_USER_NONE = 0,
    VHOST_USER_GET_FEATURES = 1,
    VHOST_USER_SET_FEATURES = 2,
    VHOST_USER_SET_OWNER = 3,
    VHOST_USER_RESET_OWNER = 4,
    VHOST_USER_SET_MEM_TABLE = 5,
    VHOST_USER_SET_LOG_BASE = 6,
    VHOST_USER_SET_LOG_FD = 7,
    VHOST_USER_SET_VRING_NUM = 8,
    VHOST_USER_SET_VRING_ADDR = 9,
    VHOST_USER_SET_VRING_BASE = 10,
    VHOST_USER_GET_VRING_BASE = 11,
    VHOST_USER_SET_VRING_KICK = 12,
    VHOST_USER_SET_VRING_CALL = 13,
    VHOST_USER_SET_VRING_ERR = 14,
    VHOST_USER_GET_PROTOCOL_FEATURES = 15,
    VHOST_USER_SET_PROTOCOL_FEATURES = 16,
    VHOST_USER_GET_QUEUE_NUM = 17,
    VHOST_USER_SET_VRING_ENABLE = 18,
    VHOST_USER_SEND_RARP = 19,
    VHOST_USER_NET_SET_MTU = 20,
    VHOST_USER_SET_SLAVE_REQ_FD = 21,
    VHOST_USER_IOTLB_MSG = 22,
    VHOST_USER_MAX
} VhostUserRequest;

69 70
/*
 * Request codes for messages arriving on the slave channel (slave_fd) and
 * dispatched by slave_read(). VHOST_USER_SLAVE_MAX is a count.
 */
typedef enum VhostUserSlaveRequest {
    VHOST_USER_SLAVE_NONE = 0,
    VHOST_USER_SLAVE_IOTLB_MSG = 1,
    VHOST_USER_SLAVE_MAX
}  VhostUserSlaveRequest;

75 76 77 78
/*
 * One entry of the memory table sent by vhost_user_set_mem_table().
 * mmap_offset is filled from the offset returned by
 * memory_region_from_host() for the region's userspace address.
 */
typedef struct VhostUserMemoryRegion {
    uint64_t guest_phys_addr;
    uint64_t memory_size;
    uint64_t userspace_addr;
    uint64_t mmap_offset;
} VhostUserMemoryRegion;

/*
 * Payload of VHOST_USER_SET_MEM_TABLE: a count followed by up to
 * VHOST_MEMORY_MAX_NREGIONS region descriptors. Only the first nregions
 * entries are included in the message size computed by the sender.
 */
typedef struct VhostUserMemory {
    uint32_t nregions;
    uint32_t padding;
    VhostUserMemoryRegion regions[VHOST_MEMORY_MAX_NREGIONS];
} VhostUserMemory;

88 89 90 91 92
/*
 * Payload of VHOST_USER_SET_LOG_BASE. vhost_user_set_log_base() sets
 * mmap_size to the byte size of the log and mmap_offset to 0.
 */
typedef struct VhostUserLog {
    uint64_t mmap_size;
    uint64_t mmap_offset;
} VhostUserLog;

93 94 95 96 97
typedef struct VhostUserMsg {
    VhostUserRequest request;

#define VHOST_USER_VERSION_MASK     (0x3)
#define VHOST_USER_REPLY_MASK       (0x1<<2)
98
#define VHOST_USER_NEED_REPLY_MASK  (0x1 << 3)
99 100 101 102 103 104 105 106 107
    uint32_t flags;
    uint32_t size; /* the following payload size */
    union {
#define VHOST_USER_VRING_IDX_MASK   (0xff)
#define VHOST_USER_VRING_NOFD_MASK  (0x1<<8)
        uint64_t u64;
        struct vhost_vring_state state;
        struct vhost_vring_addr addr;
        VhostUserMemory memory;
108
        VhostUserLog log;
109
        struct vhost_iotlb_msg iotlb;
110
    } payload;
111 112 113 114 115 116 117 118 119 120 121 122
} QEMU_PACKED VhostUserMsg;

/* Dummy instance that exists only so the sizeof expressions below can name fields. */
static VhostUserMsg m __attribute__ ((unused));
/* Size in bytes of the fixed message header: request + flags + size. */
#define VHOST_USER_HDR_SIZE (sizeof(m.request) \
                            + sizeof(m.flags) \
                            + sizeof(m.size))

/* Largest payload that fits in a VhostUserMsg; used to validate msg.size. */
#define VHOST_USER_PAYLOAD_SIZE (sizeof(m) - VHOST_USER_HDR_SIZE)

/* The version of the protocol we support */
#define VHOST_USER_VERSION    (0x1)

123 124
/* Per-device backend state, stored in vhost_dev.opaque by vhost_user_init(). */
struct vhost_user {
    CharBackend *chr;   /* chardev used for master -> backend requests */
    int slave_fd;       /* our end of the slave channel, or -1 when not set up */
};

128 129 130 131 132 133 134
/* True only when KVM is in use and KVM reports eventfd support. */
static bool ioeventfd_enabled(void)
{
    if (!kvm_enabled()) {
        return false;
    }

    return kvm_eventfds_enabled();
}

/*
 * Read one message from the backend chardev into @msg.
 *
 * The fixed-size header is read first and validated: flags must equal
 * VHOST_USER_REPLY_MASK | VHOST_USER_VERSION exactly, and the advertised
 * payload size must not exceed VHOST_USER_PAYLOAD_SIZE. A non-empty payload
 * is then read into the bytes following the header.
 *
 * Returns 0 on success, -1 on a short read or failed validation.
 */
static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
{
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
    uint8_t *p = (uint8_t *) msg;
    int r, size = VHOST_USER_HDR_SIZE;

    r = qemu_chr_fe_read_all(chr, p, size);
    if (r != size) {
        error_report("Failed to read msg header. Read %d instead of %d."
                     " Original request %d.", r, size, msg->request);
        return -1;
    }

    /* validate received flags */
    if (msg->flags != (VHOST_USER_REPLY_MASK | VHOST_USER_VERSION)) {
        error_report("Failed to read msg header."
                " Flags 0x%x instead of 0x%x.", msg->flags,
                VHOST_USER_REPLY_MASK | VHOST_USER_VERSION);
        return -1;
    }

    /* validate message size is sane */
    if (msg->size > VHOST_USER_PAYLOAD_SIZE) {
        error_report("Failed to read msg header."
                " Size %d exceeds the maximum %zu.", msg->size,
                VHOST_USER_PAYLOAD_SIZE);
        return -1;
    }

    if (msg->size) {
        p += VHOST_USER_HDR_SIZE;
        size = msg->size;
        r = qemu_chr_fe_read_all(chr, p, size);
        if (r != size) {
            error_report("Failed to read msg payload."
                         " Read %d instead of %d.", r, msg->size);
            return -1;
        }
    }

    return 0;
}

180
static int process_message_reply(struct vhost_dev *dev,
181
                                 const VhostUserMsg *msg)
182
{
183
    VhostUserMsg msg_reply;
184

185
    if ((msg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
186 187 188 189
        return 0;
    }

    if (vhost_user_read(dev, &msg_reply) < 0) {
190 191 192
        return -1;
    }

193
    if (msg_reply.request != msg->request) {
194 195
        error_report("Received unexpected msg type."
                     "Expected %d received %d",
196
                     msg->request, msg_reply.request);
197 198 199
        return -1;
    }

200
    return msg_reply.payload.u64 ? -1 : 0;
201 202
}

203 204 205 206
/*
 * Tell whether @request is one of the requests that callers send only for
 * the device whose vq_index is 0 (see vhost_user_write() and
 * vhost_user_get_u64()).
 */
static bool vhost_user_one_time_request(VhostUserRequest request)
{
    switch (request) {
    case VHOST_USER_SET_OWNER:
    case VHOST_USER_RESET_OWNER:
    case VHOST_USER_SET_MEM_TABLE:
    case VHOST_USER_GET_QUEUE_NUM:
    case VHOST_USER_NET_SET_MTU:
        return true;
    default:
        return false;
    }
}

/* most non-init callers ignore the error */
218 219 220
static int vhost_user_write(struct vhost_dev *dev, VhostUserMsg *msg,
                            int *fds, int fd_num)
{
221 222
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
223
    int ret, size = VHOST_USER_HDR_SIZE + msg->size;
224

225 226 227 228 229 230
    /*
     * For non-vring specific requests, like VHOST_USER_SET_MEM_TABLE,
     * we just need send it once in the first time. For later such
     * request, we just ignore it.
     */
    if (vhost_user_one_time_request(msg->request) && dev->vq_index != 0) {
231
        msg->flags &= ~VHOST_USER_NEED_REPLY_MASK;
232 233 234
        return 0;
    }

235
    if (qemu_chr_fe_set_msgfds(chr, fds, fd_num) < 0) {
236
        error_report("Failed to set msg fds.");
237 238
        return -1;
    }
239

240 241 242 243 244 245 246 247
    ret = qemu_chr_fe_write_all(chr, (const uint8_t *) msg, size);
    if (ret != size) {
        error_report("Failed to write msg."
                     " Wrote %d instead of %d.", ret, size);
        return -1;
    }

    return 0;
248 249
}

250 251
static int vhost_user_set_log_base(struct vhost_dev *dev, uint64_t base,
                                   struct vhost_log *log)
252
{
253 254 255 256 257 258 259
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    size_t fd_num = 0;
    bool shmfd = virtio_has_feature(dev->protocol_features,
                                    VHOST_USER_PROTOCOL_F_LOG_SHMFD);
    VhostUserMsg msg = {
        .request = VHOST_USER_SET_LOG_BASE,
        .flags = VHOST_USER_VERSION,
M
Michael S. Tsirkin 已提交
260
        .payload.log.mmap_size = log->size * sizeof(*(log->log)),
261 262
        .payload.log.mmap_offset = 0,
        .size = sizeof(msg.payload.log),
263 264 265 266 267 268
    };

    if (shmfd && log->fd != -1) {
        fds[fd_num++] = log->fd;
    }

269 270 271
    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }
272 273 274 275

    if (shmfd) {
        msg.size = 0;
        if (vhost_user_read(dev, &msg) < 0) {
276
            return -1;
277 278 279 280 281 282 283 284
        }

        if (msg.request != VHOST_USER_SET_LOG_BASE) {
            error_report("Received unexpected msg type. "
                         "Expected %d received %d",
                         VHOST_USER_SET_LOG_BASE, msg.request);
            return -1;
        }
285
    }
286 287

    return 0;
288 289
}

290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343
/*
 * Send the device's memory table to the backend.
 *
 * Every region of dev->mem whose MemoryRegion has a file descriptor greater
 * than 0 becomes one entry in the message, and that descriptor is passed
 * alongside it. Regions without such a descriptor are skipped. If no region
 * qualifies, an error is reported and nothing is sent. Note that the table
 * is built from dev->mem; @mem is not consulted.
 *
 * With VHOST_USER_PROTOCOL_F_REPLY_ACK negotiated, the backend's
 * acknowledgement is awaited and its status returned.
 *
 * Returns 0 on success, -1 on failure.
 */
static int vhost_user_set_mem_table(struct vhost_dev *dev,
                                    struct vhost_memory *mem)
{
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    int i, fd;
    size_t fd_num = 0;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    VhostUserMsg msg = {
        .request = VHOST_USER_SET_MEM_TABLE,
        .flags = VHOST_USER_VERSION,
    };

    if (reply_supported) {
        msg.flags |= VHOST_USER_NEED_REPLY_MASK;
    }

    for (i = 0; i < dev->mem->nregions; ++i) {
        struct vhost_memory_region *reg = dev->mem->regions + i;
        ram_addr_t offset;
        MemoryRegion *mr;

        assert((uintptr_t)reg->userspace_addr == reg->userspace_addr);
        mr = memory_region_from_host((void *)(uintptr_t)reg->userspace_addr,
                                     &offset);
        fd = memory_region_get_fd(mr);
        if (fd > 0) {
            /*
             * Check the slot index before touching regions[] and fds[]; the
             * original asserted only after regions[fd_num] had been written.
             */
            assert(fd_num < VHOST_MEMORY_MAX_NREGIONS);
            msg.payload.memory.regions[fd_num].userspace_addr = reg->userspace_addr;
            msg.payload.memory.regions[fd_num].memory_size  = reg->memory_size;
            msg.payload.memory.regions[fd_num].guest_phys_addr = reg->guest_phys_addr;
            msg.payload.memory.regions[fd_num].mmap_offset = offset;
            fds[fd_num++] = fd;
        }
    }

    msg.payload.memory.nregions = fd_num;

    if (!fd_num) {
        error_report("Failed initializing vhost-user memory map, "
                     "consider using -object memory-backend-file share=on");
        return -1;
    }

    /* Only the header fields and the populated entries go on the wire. */
    msg.size = sizeof(msg.payload.memory.nregions);
    msg.size += sizeof(msg.payload.memory.padding);
    msg.size += fd_num * sizeof(VhostUserMemoryRegion);

    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }

    if (reply_supported) {
        return process_message_reply(dev, &msg);
    }

    return 0;
}

350 351 352 353 354 355
/*
 * Send VHOST_USER_SET_VRING_ADDR with a copy of @addr as payload.
 * Returns 0 on success, -1 if the message could not be written.
 */
static int vhost_user_set_vring_addr(struct vhost_dev *dev,
                                     struct vhost_vring_addr *addr)
{
    VhostUserMsg msg = {
        .request = VHOST_USER_SET_VRING_ADDR,
        .flags = VHOST_USER_VERSION,
        .payload.addr = *addr,
        .size = sizeof(msg.payload.addr),
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    return 0;
}
366

367 368 369 370 371 372
/*
 * Not supported by this backend: reports an error and always fails with -1.
 * Neither argument is examined.
 */
static int vhost_user_set_vring_endian(struct vhost_dev *dev,
                                       struct vhost_vring_state *ring)
{
    const int rc = -1;

    error_report("vhost-user trying to send unhandled ioctl");

    return rc;
}
373

374 375 376 377 378 379 380
/*
 * Send @request with a copy of @ring as the vhost_vring_state payload.
 * Shared by the SET_VRING_NUM / SET_VRING_BASE / SET_VRING_ENABLE senders.
 * Returns 0 on success, -1 if the message could not be written.
 */
static int vhost_set_vring(struct vhost_dev *dev,
                           unsigned long int request,
                           struct vhost_vring_state *ring)
{
    VhostUserMsg msg = {
        .request = request,
        .flags = VHOST_USER_VERSION,
        .payload.state = *ring,
        .size = sizeof(msg.payload.state),
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    return 0;
}

/* Forward @ring to the backend as a VHOST_USER_SET_VRING_NUM request. */
static int vhost_user_set_vring_num(struct vhost_dev *dev,
                                    struct vhost_vring_state *ring)
{
    int rc = vhost_set_vring(dev, VHOST_USER_SET_VRING_NUM, ring);

    return rc;
}

/* Forward @ring to the backend as a VHOST_USER_SET_VRING_BASE request. */
static int vhost_user_set_vring_base(struct vhost_dev *dev,
                                     struct vhost_vring_state *ring)
{
    int rc = vhost_set_vring(dev, VHOST_USER_SET_VRING_BASE, ring);

    return rc;
}

static int vhost_user_set_vring_enable(struct vhost_dev *dev, int enable)
{
406
    int i;
407

408
    if (!virtio_has_feature(dev->features, VHOST_USER_F_PROTOCOL_FEATURES)) {
409 410 411
        return -1;
    }

412 413 414 415 416 417 418 419
    for (i = 0; i < dev->nvqs; ++i) {
        struct vhost_vring_state state = {
            .index = dev->vq_index + i,
            .num   = enable,
        };

        vhost_set_vring(dev, VHOST_USER_SET_VRING_ENABLE, &state);
    }
420

421 422
    return 0;
}
423 424 425 426 427 428 429

static int vhost_user_get_vring_base(struct vhost_dev *dev,
                                     struct vhost_vring_state *ring)
{
    VhostUserMsg msg = {
        .request = VHOST_USER_GET_VRING_BASE,
        .flags = VHOST_USER_VERSION,
430
        .payload.state = *ring,
431
        .size = sizeof(msg.payload.state),
432 433
    };

434 435 436
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
437 438

    if (vhost_user_read(dev, &msg) < 0) {
439
        return -1;
440 441
    }

442 443 444 445 446
    if (msg.request != VHOST_USER_GET_VRING_BASE) {
        error_report("Received unexpected msg type. Expected %d received %d",
                     VHOST_USER_GET_VRING_BASE, msg.request);
        return -1;
    }
447

448
    if (msg.size != sizeof(msg.payload.state)) {
449 450
        error_report("Received bad msg size.");
        return -1;
451 452
    }

453
    *ring = msg.payload.state;
454

455 456 457
    return 0;
}

458 459 460
/*
 * Send a vring file request (@request) for the ring identified by
 * file->index.
 *
 * The u64 payload holds the ring index masked with
 * VHOST_USER_VRING_IDX_MASK. When ioeventfd is usable and file->fd is
 * greater than 0 the descriptor is passed with the message; otherwise
 * VHOST_USER_VRING_NOFD_MASK is set in the payload and no fd is sent.
 *
 * Returns 0 on success, -1 if the message could not be written.
 */
static int vhost_set_vring_file(struct vhost_dev *dev,
                                VhostUserRequest request,
                                struct vhost_vring_file *file)
{
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    size_t fd_num = 0;
    VhostUserMsg msg = {
        .request = request,
        .flags = VHOST_USER_VERSION,
        .payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK,
        .size = sizeof(msg.payload.u64),
    };

    if (ioeventfd_enabled() && file->fd > 0) {
        fds[fd_num++] = file->fd;
    } else {
        msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK;
    }

    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }

    return 0;
}
483

484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500
static int vhost_user_set_vring_kick(struct vhost_dev *dev,
                                     struct vhost_vring_file *file)
{
    return vhost_set_vring_file(dev, VHOST_USER_SET_VRING_KICK, file);
}

static int vhost_user_set_vring_call(struct vhost_dev *dev,
                                     struct vhost_vring_file *file)
{
    return vhost_set_vring_file(dev, VHOST_USER_SET_VRING_CALL, file);
}

/*
 * Send @request with @u64 as its 8-byte payload. No reply is read.
 * Returns 0 on success, -1 if the message could not be written.
 */
static int vhost_user_set_u64(struct vhost_dev *dev, int request, uint64_t u64)
{
    VhostUserMsg msg = {
        .request = request,
        .flags = VHOST_USER_VERSION,
        .payload.u64 = u64,
        .size = sizeof(msg.payload.u64),
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    return 0;
}

/* Send @features to the backend as a VHOST_USER_SET_FEATURES request. */
static int vhost_user_set_features(struct vhost_dev *dev,
                                   uint64_t features)
{
    int rc = vhost_user_set_u64(dev, VHOST_USER_SET_FEATURES, features);

    return rc;
}

/* Send @features as a VHOST_USER_SET_PROTOCOL_FEATURES request. */
static int vhost_user_set_protocol_features(struct vhost_dev *dev,
                                            uint64_t features)
{
    int rc = vhost_user_set_u64(dev, VHOST_USER_SET_PROTOCOL_FEATURES,
                                features);

    return rc;
}

static int vhost_user_get_u64(struct vhost_dev *dev, int request, uint64_t *u64)
{
    VhostUserMsg msg = {
        .request = request,
        .flags = VHOST_USER_VERSION,
    };

    if (vhost_user_one_time_request(request) && dev->vq_index != 0) {
        return 0;
533
    }
534

535 536 537
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
538 539

    if (vhost_user_read(dev, &msg) < 0) {
540
        return -1;
541 542 543 544 545 546 547 548
    }

    if (msg.request != request) {
        error_report("Received unexpected msg type. Expected %d received %d",
                     request, msg.request);
        return -1;
    }

549
    if (msg.size != sizeof(msg.payload.u64)) {
550 551 552 553
        error_report("Received bad msg size.");
        return -1;
    }

554
    *u64 = msg.payload.u64;
555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570

    return 0;
}

/* Fetch the backend's feature bits into @features via VHOST_USER_GET_FEATURES. */
static int vhost_user_get_features(struct vhost_dev *dev, uint64_t *features)
{
    int rc = vhost_user_get_u64(dev, VHOST_USER_GET_FEATURES, features);

    return rc;
}

/*
 * Send a payload-less VHOST_USER_SET_OWNER request.
 * Returns 0 on success, -1 if the message could not be written.
 */
static int vhost_user_set_owner(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
        .request = VHOST_USER_SET_OWNER,
        .flags = VHOST_USER_VERSION,
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    return 0;
}

static int vhost_user_reset_device(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
581
        .request = VHOST_USER_RESET_OWNER,
582 583 584
        .flags = VHOST_USER_VERSION,
    };

585 586 587
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
588

589 590 591
    return 0;
}

592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620
static void slave_read(void *opaque)
{
    struct vhost_dev *dev = opaque;
    struct vhost_user *u = dev->opaque;
    VhostUserMsg msg = { 0, };
    int size, ret = 0;

    /* Read header */
    size = read(u->slave_fd, &msg, VHOST_USER_HDR_SIZE);
    if (size != VHOST_USER_HDR_SIZE) {
        error_report("Failed to read from slave.");
        goto err;
    }

    if (msg.size > VHOST_USER_PAYLOAD_SIZE) {
        error_report("Failed to read msg header."
                " Size %d exceeds the maximum %zu.", msg.size,
                VHOST_USER_PAYLOAD_SIZE);
        goto err;
    }

    /* Read payload */
    size = read(u->slave_fd, &msg.payload, msg.size);
    if (size != msg.size) {
        error_report("Failed to read payload from slave.");
        goto err;
    }

    switch (msg.request) {
621 622 623
    case VHOST_USER_SLAVE_IOTLB_MSG:
        ret = vhost_backend_handle_iotlb_msg(dev, &msg.payload.iotlb);
        break;
624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703
    default:
        error_report("Received unexpected msg type.");
        ret = -EINVAL;
    }

    /*
     * REPLY_ACK feature handling. Other reply types has to be managed
     * directly in their request handlers.
     */
    if (msg.flags & VHOST_USER_NEED_REPLY_MASK) {
        msg.flags &= ~VHOST_USER_NEED_REPLY_MASK;
        msg.flags |= VHOST_USER_REPLY_MASK;

        msg.payload.u64 = !!ret;
        msg.size = sizeof(msg.payload.u64);

        size = write(u->slave_fd, &msg, VHOST_USER_HDR_SIZE + msg.size);
        if (size != VHOST_USER_HDR_SIZE + msg.size) {
            error_report("Failed to send msg reply to slave.");
            goto err;
        }
    }

    return;

err:
    qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
    close(u->slave_fd);
    u->slave_fd = -1;
    return;
}

/*
 * Create the slave channel and hand one end of it to the backend.
 *
 * Does nothing (returns 0) unless VHOST_USER_PROTOCOL_F_SLAVE_REQ was
 * negotiated. Otherwise a UNIX stream socketpair is created: one end is
 * kept in u->slave_fd with slave_read() as its handler, the other is passed
 * in a VHOST_USER_SET_SLAVE_REQ_FD message and then closed locally. When
 * REPLY_ACK was negotiated the backend's acknowledgement is awaited.
 *
 * On failure after the socketpair was created, the local end is torn down
 * again and slave_fd reset to -1. Returns 0 on success, non-zero on failure.
 */
static int vhost_setup_slave_channel(struct vhost_dev *dev)
{
    struct vhost_user *u = dev->opaque;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);
    VhostUserMsg msg = {
        .request = VHOST_USER_SET_SLAVE_REQ_FD,
        .flags = VHOST_USER_VERSION,
    };
    int sv[2];
    int ret;

    if (!virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
        return 0;
    }

    if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) == -1) {
        error_report("socketpair() failed");
        return -1;
    }

    /* sv[0] stays with us, sv[1] goes to the backend. */
    u->slave_fd = sv[0];
    qemu_set_fd_handler(u->slave_fd, slave_read, NULL, dev);

    if (reply_supported) {
        msg.flags |= VHOST_USER_NEED_REPLY_MASK;
    }

    ret = vhost_user_write(dev, &msg, &sv[1], 1);
    if (!ret && reply_supported) {
        ret = process_message_reply(dev, &msg);
    }

    /* Our copy of the backend's end is no longer needed either way. */
    close(sv[1]);

    if (ret) {
        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
        close(u->slave_fd);
        u->slave_fd = -1;
    }

    return ret;
}

704 705
static int vhost_user_init(struct vhost_dev *dev, void *opaque)
{
706
    uint64_t features, protocol_features;
707
    struct vhost_user *u;
708 709
    int err;

710 711
    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

712 713
    u = g_new0(struct vhost_user, 1);
    u->chr = opaque;
714
    u->slave_fd = -1;
715
    dev->opaque = u;
716

717
    err = vhost_user_get_features(dev, &features);
718 719 720 721 722 723 724
    if (err < 0) {
        return err;
    }

    if (virtio_has_feature(features, VHOST_USER_F_PROTOCOL_FEATURES)) {
        dev->backend_features |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;

725
        err = vhost_user_get_u64(dev, VHOST_USER_GET_PROTOCOL_FEATURES,
726
                                 &protocol_features);
727 728 729 730
        if (err < 0) {
            return err;
        }

731 732
        dev->protocol_features =
            protocol_features & VHOST_USER_PROTOCOL_FEATURE_MASK;
733
        err = vhost_user_set_protocol_features(dev, dev->protocol_features);
734 735 736
        if (err < 0) {
            return err;
        }
737 738 739

        /* query the max queues we support if backend supports Multiple Queue */
        if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_MQ)) {
740 741
            err = vhost_user_get_u64(dev, VHOST_USER_GET_QUEUE_NUM,
                                     &dev->max_queues);
742 743 744 745
            if (err < 0) {
                return err;
            }
        }
746 747 748 749 750 751 752 753 754 755

        if (virtio_has_feature(features, VIRTIO_F_IOMMU_PLATFORM) &&
                !(virtio_has_feature(dev->protocol_features,
                    VHOST_USER_PROTOCOL_F_SLAVE_REQ) &&
                 virtio_has_feature(dev->protocol_features,
                    VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
            error_report("IOMMU support requires reply-ack and "
                         "slave-req protocol features.");
            return -1;
        }
756 757
    }

758 759 760 761 762 763 764 765
    if (dev->migration_blocker == NULL &&
        !virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_LOG_SHMFD)) {
        error_setg(&dev->migration_blocker,
                   "Migration disabled: vhost-user backend lacks "
                   "VHOST_USER_PROTOCOL_F_LOG_SHMFD feature.");
    }

766 767 768 769 770
    err = vhost_setup_slave_channel(dev);
    if (err < 0) {
        return err;
    }

771 772 773 774 775
    return 0;
}

static int vhost_user_cleanup(struct vhost_dev *dev)
{
776 777
    struct vhost_user *u;

778 779
    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

780
    u = dev->opaque;
781 782 783 784
    if (u->slave_fd >= 0) {
        close(u->slave_fd);
        u->slave_fd = -1;
    }
785
    g_free(u);
786 787 788 789 790
    dev->opaque = 0;

    return 0;
}

791 792 793 794 795 796 797
/*
 * Return the backend's index for virtqueue @idx. For vhost-user the index
 * is passed through unchanged; @idx must lie within this device's range
 * [vq_index, vq_index + nvqs), which is enforced by assertion.
 */
static int vhost_user_get_vq_index(struct vhost_dev *dev, int idx)
{
    assert(idx >= dev->vq_index && idx < dev->vq_index + dev->nvqs);

    return idx;
}

798 799 800 801 802
/*
 * Report how many memory slots this backend accepts: the fixed capacity of
 * the memory table message. @dev is not consulted.
 */
static int vhost_user_memslots_limit(struct vhost_dev *dev)
{
    const int limit = VHOST_MEMORY_MAX_NREGIONS;

    return limit;
}

803 804 805 806 807 808 809 810
/*
 * Tell whether the backend negotiated VHOST_USER_PROTOCOL_F_LOG_SHMFD.
 */
static bool vhost_user_requires_shm_log(struct vhost_dev *dev)
{
    bool has_shmfd;

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    has_shmfd = virtio_has_feature(dev->protocol_features,
                                   VHOST_USER_PROTOCOL_F_LOG_SHMFD);
    return has_shmfd;
}

811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826
static int vhost_user_migration_done(struct vhost_dev *dev, char* mac_addr)
{
    VhostUserMsg msg = { 0 };

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    /* If guest supports GUEST_ANNOUNCE do nothing */
    if (virtio_has_feature(dev->acked_features, VIRTIO_NET_F_GUEST_ANNOUNCE)) {
        return 0;
    }

    /* if backend supports VHOST_USER_PROTOCOL_F_RARP ask it to send the RARP */
    if (virtio_has_feature(dev->protocol_features,
                           VHOST_USER_PROTOCOL_F_RARP)) {
        msg.request = VHOST_USER_SEND_RARP;
        msg.flags = VHOST_USER_VERSION;
827
        memcpy((char *)&msg.payload.u64, mac_addr, 6);
828
        msg.size = sizeof(msg.payload.u64);
829

830
        return vhost_user_write(dev, &msg, NULL, 0);
831 832 833 834
    }
    return -1;
}

835 836 837 838
static bool vhost_user_can_merge(struct vhost_dev *dev,
                                 uint64_t start1, uint64_t size1,
                                 uint64_t start2, uint64_t size2)
{
839
    ram_addr_t offset;
840 841 842
    int mfd, rfd;
    MemoryRegion *mr;

843
    mr = memory_region_from_host((void *)(uintptr_t)start1, &offset);
844
    mfd = memory_region_get_fd(mr);
845

846
    mr = memory_region_from_host((void *)(uintptr_t)start2, &offset);
847
    rfd = memory_region_get_fd(mr);
848 849 850 851

    return mfd == rfd;
}

852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875
static int vhost_user_net_set_mtu(struct vhost_dev *dev, uint16_t mtu)
{
    VhostUserMsg msg;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    if (!(dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_NET_MTU))) {
        return 0;
    }

    msg.request = VHOST_USER_NET_SET_MTU;
    msg.payload.u64 = mtu;
    msg.size = sizeof(msg.payload.u64);
    msg.flags = VHOST_USER_VERSION;
    if (reply_supported) {
        msg.flags |= VHOST_USER_NEED_REPLY_MASK;
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    /* If reply_ack supported, slave has to ack specified MTU is valid */
    if (reply_supported) {
876
        return process_message_reply(dev, &msg);
877 878 879 880 881
    }

    return 0;
}

882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904
/*
 * Send @imsg to the backend as VHOST_USER_IOTLB_MSG, always requesting an
 * acknowledgement, and wait for it.
 *
 * Returns -EFAULT if the message could not be written, otherwise the
 * result of process_message_reply().
 */
static int vhost_user_send_device_iotlb_msg(struct vhost_dev *dev,
                                            struct vhost_iotlb_msg *imsg)
{
    VhostUserMsg msg = {
        .request = VHOST_USER_IOTLB_MSG,
        .size = sizeof(msg.payload.iotlb),
        .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
        .payload.iotlb = *imsg,
    };
    int wr = vhost_user_write(dev, &msg, NULL, 0);

    if (wr >= 0) {
        return process_message_reply(dev, &msg);
    }

    return -EFAULT;
}


/*
 * IOTLB callback enable/disable hook. Intentionally empty for vhost-user:
 * both arguments are ignored.
 */
static void vhost_user_set_iotlb_callback(struct vhost_dev *dev, int enabled)
{
    /* No-op as the receive channel is not dedicated to IOTLB messages. */
}

905 906 907
const VhostOps user_ops = {
        .backend_type = VHOST_BACKEND_TYPE_USER,
        .vhost_backend_init = vhost_user_init,
908
        .vhost_backend_cleanup = vhost_user_cleanup,
909
        .vhost_backend_memslots_limit = vhost_user_memslots_limit,
910 911 912 913 914 915 916 917 918 919 920 921 922 923 924
        .vhost_set_log_base = vhost_user_set_log_base,
        .vhost_set_mem_table = vhost_user_set_mem_table,
        .vhost_set_vring_addr = vhost_user_set_vring_addr,
        .vhost_set_vring_endian = vhost_user_set_vring_endian,
        .vhost_set_vring_num = vhost_user_set_vring_num,
        .vhost_set_vring_base = vhost_user_set_vring_base,
        .vhost_get_vring_base = vhost_user_get_vring_base,
        .vhost_set_vring_kick = vhost_user_set_vring_kick,
        .vhost_set_vring_call = vhost_user_set_vring_call,
        .vhost_set_features = vhost_user_set_features,
        .vhost_get_features = vhost_user_get_features,
        .vhost_set_owner = vhost_user_set_owner,
        .vhost_reset_device = vhost_user_reset_device,
        .vhost_get_vq_index = vhost_user_get_vq_index,
        .vhost_set_vring_enable = vhost_user_set_vring_enable,
925
        .vhost_requires_shm_log = vhost_user_requires_shm_log,
926
        .vhost_migration_done = vhost_user_migration_done,
927
        .vhost_backend_can_merge = vhost_user_can_merge,
928
        .vhost_net_set_mtu = vhost_user_net_set_mtu,
929 930
        .vhost_set_iotlb_callback = vhost_user_set_iotlb_callback,
        .vhost_send_device_iotlb_msg = vhost_user_send_device_iotlb_msg,
931
};