vhost-user.c 49.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
/*
 * vhost-user
 *
 * Copyright (c) 2013 Virtual Open Systems Sarl.
 *
 * This work is licensed under the terms of the GNU GPL, version 2 or later.
 * See the COPYING file in the top-level directory.
 *
 */

P
Peter Maydell 已提交
11
#include "qemu/osdep.h"
12
#include "qapi/error.h"
13 14
#include "hw/virtio/vhost.h"
#include "hw/virtio/vhost-backend.h"
15
#include "hw/virtio/virtio-net.h"
16
#include "chardev/char-fe.h"
17 18 19
#include "sysemu/kvm.h"
#include "qemu/error-report.h"
#include "qemu/sockets.h"
20
#include "sysemu/cryptodev.h"
21 22
#include "migration/migration.h"
#include "migration/postcopy-ram.h"
23
#include "trace.h"
24 25 26 27 28

#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <linux/vhost.h>
29
#include <linux/userfaultfd.h>
30 31

#define VHOST_MEMORY_MAX_NREGIONS    8
32
#define VHOST_USER_F_PROTOCOL_FEATURES 30
33

34 35 36 37 38
/*
 * Maximum size of virtio device config space
 */
#define VHOST_USER_MAX_CONFIG_SIZE 256

39 40 41 42
/*
 * Optional vhost-user protocol features.
 *
 * Each enumerator is a bit index (not a mask): the code tests them with
 * virtio_has_feature() against dev->protocol_features.
 */
enum VhostUserProtocolFeature {
    VHOST_USER_PROTOCOL_F_MQ = 0,
    VHOST_USER_PROTOCOL_F_LOG_SHMFD = 1,
    VHOST_USER_PROTOCOL_F_RARP = 2,
    VHOST_USER_PROTOCOL_F_REPLY_ACK = 3,
    VHOST_USER_PROTOCOL_F_NET_MTU = 4,
    VHOST_USER_PROTOCOL_F_SLAVE_REQ = 5,
    VHOST_USER_PROTOCOL_F_CROSS_ENDIAN = 6,
    VHOST_USER_PROTOCOL_F_CRYPTO_SESSION = 7,
    VHOST_USER_PROTOCOL_F_PAGEFAULT = 8,
    VHOST_USER_PROTOCOL_F_CONFIG = 9,
    /* One past the highest known feature bit; keep this entry last. */
    VHOST_USER_PROTOCOL_F_MAX
};

#define VHOST_USER_PROTOCOL_FEATURE_MASK ((1 << VHOST_USER_PROTOCOL_F_MAX) - 1)
54 55 56 57 58 59

/*
 * Request codes carried in VhostUserHeader.request for messages that
 * QEMU sends to the backend.  The numeric values are written to the
 * socket as-is, so they must not be renumbered.
 */
typedef enum VhostUserRequest {
    VHOST_USER_NONE = 0,
    VHOST_USER_GET_FEATURES = 1,
    VHOST_USER_SET_FEATURES = 2,
    VHOST_USER_SET_OWNER = 3,
    VHOST_USER_RESET_OWNER = 4,
    VHOST_USER_SET_MEM_TABLE = 5,
    VHOST_USER_SET_LOG_BASE = 6,
    VHOST_USER_SET_LOG_FD = 7,
    VHOST_USER_SET_VRING_NUM = 8,
    VHOST_USER_SET_VRING_ADDR = 9,
    VHOST_USER_SET_VRING_BASE = 10,
    VHOST_USER_GET_VRING_BASE = 11,
    VHOST_USER_SET_VRING_KICK = 12,
    VHOST_USER_SET_VRING_CALL = 13,
    VHOST_USER_SET_VRING_ERR = 14,
    VHOST_USER_GET_PROTOCOL_FEATURES = 15,
    VHOST_USER_SET_PROTOCOL_FEATURES = 16,
    VHOST_USER_GET_QUEUE_NUM = 17,
    VHOST_USER_SET_VRING_ENABLE = 18,
    VHOST_USER_SEND_RARP = 19,
    VHOST_USER_NET_SET_MTU = 20,
    VHOST_USER_SET_SLAVE_REQ_FD = 21,
    VHOST_USER_IOTLB_MSG = 22,
    VHOST_USER_SET_VRING_ENDIAN = 23,
    VHOST_USER_GET_CONFIG = 24,
    VHOST_USER_SET_CONFIG = 25,
    VHOST_USER_CREATE_CRYPTO_SESSION = 26,
    VHOST_USER_CLOSE_CRYPTO_SESSION = 27,
    VHOST_USER_POSTCOPY_ADVISE  = 28,
    VHOST_USER_POSTCOPY_LISTEN  = 29,
    VHOST_USER_POSTCOPY_END     = 30,
    /* One past the highest known request; keep this entry last. */
    VHOST_USER_MAX
} VhostUserRequest;

90 91
/*
 * Request codes for messages travelling in the other direction, i.e.
 * received by QEMU on the slave channel (dispatched in slave_read()).
 */
typedef enum VhostUserSlaveRequest {
    VHOST_USER_SLAVE_NONE = 0,
    VHOST_USER_SLAVE_IOTLB_MSG = 1,
    VHOST_USER_SLAVE_CONFIG_CHANGE_MSG = 2,
    /* One past the highest known request; keep this entry last. */
    VHOST_USER_SLAVE_MAX
}  VhostUserSlaveRequest;

97 98 99 100
/*
 * One entry of the memory table sent with VHOST_USER_SET_MEM_TABLE.
 * The set_mem_table code fills one entry per fd-backed region and sends
 * the region's fd alongside it, at the same position in the fd array.
 */
typedef struct VhostUserMemoryRegion {
    uint64_t guest_phys_addr;
    uint64_t memory_size;
    uint64_t userspace_addr;
    /* Offset of the region inside the file referred to by its fd. */
    uint64_t mmap_offset;
} VhostUserMemoryRegion;

/*
 * Payload of VHOST_USER_SET_MEM_TABLE: a count followed by up to
 * VHOST_MEMORY_MAX_NREGIONS region descriptors.  Senders only transmit
 * the first 'nregions' entries (hdr.size is computed accordingly).
 */
typedef struct VhostUserMemory {
    uint32_t nregions;
    uint32_t padding;
    VhostUserMemoryRegion regions[VHOST_MEMORY_MAX_NREGIONS];
} VhostUserMemory;

110 111 112 113 114
/*
 * Payload of VHOST_USER_SET_LOG_BASE: size of the dirty log in bytes and
 * the offset at which it starts (the sender in this file always uses 0).
 */
typedef struct VhostUserLog {
    uint64_t mmap_size;
    uint64_t mmap_offset;
} VhostUserLog;

115 116 117 118 119 120 121
/*
 * Payload used for the device config space messages.  The three leading
 * 32-bit fields form the fixed part (see VHOST_USER_CONFIG_HDR_SIZE);
 * 'region' holds up to VHOST_USER_MAX_CONFIG_SIZE bytes of config data.
 */
typedef struct VhostUserConfig {
    uint32_t offset;
    uint32_t size;
    uint32_t flags;
    uint8_t region[VHOST_USER_MAX_CONFIG_SIZE];
} VhostUserConfig;

122 123 124 125 126 127 128 129 130 131 132
#define VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN    512
#define VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN  64

/*
 * Payload of the crypto session messages: the session parameters plus
 * fixed-size buffers for the cipher and authentication keys.
 */
typedef struct VhostUserCryptoSession {
    /* session id for success, -1 on errors */
    int64_t session_id;
    CryptoDevBackendSymSessionInfo session_setup_data;
    uint8_t key[VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN];
    uint8_t auth_key[VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN];
} VhostUserCryptoSession;

133 134 135 136 137
static VhostUserConfig c __attribute__ ((unused));
#define VHOST_USER_CONFIG_HDR_SIZE (sizeof(c.offset) \
                                   + sizeof(c.size) \
                                   + sizeof(c.flags))

138
typedef struct {
139 140 141 142
    VhostUserRequest request;

#define VHOST_USER_VERSION_MASK     (0x3)
#define VHOST_USER_REPLY_MASK       (0x1<<2)
143
#define VHOST_USER_NEED_REPLY_MASK  (0x1 << 3)
144 145
    uint32_t flags;
    uint32_t size; /* the following payload size */
146 147 148
} QEMU_PACKED VhostUserHeader;

typedef union {
149 150 151 152 153 154
#define VHOST_USER_VRING_IDX_MASK   (0xff)
#define VHOST_USER_VRING_NOFD_MASK  (0x1<<8)
        uint64_t u64;
        struct vhost_vring_state state;
        struct vhost_vring_addr addr;
        VhostUserMemory memory;
155
        VhostUserLog log;
156
        struct vhost_iotlb_msg iotlb;
157
        VhostUserConfig config;
158
        VhostUserCryptoSession session;
159 160 161 162 163
} VhostUserPayload;

/*
 * A complete message as it appears on the socket: header immediately
 * followed by the payload.  The struct is packed, so 'payload' is not
 * guaranteed to be naturally aligned; access its members through the
 * struct rather than through pointers to them.
 */
typedef struct VhostUserMsg {
    VhostUserHeader hdr;
    VhostUserPayload payload;
} QEMU_PACKED VhostUserMsg;

static VhostUserMsg m __attribute__ ((unused));
167
#define VHOST_USER_HDR_SIZE (sizeof(VhostUserHeader))
168

169
#define VHOST_USER_PAYLOAD_SIZE (sizeof(VhostUserPayload))
170 171 172 173

/* The version of the protocol we support */
#define VHOST_USER_VERSION    (0x1)

174
struct vhost_user {
175
    struct vhost_dev *dev;
176
    CharBackend *chr;
177
    int slave_fd;
178
    NotifierWithReturn postcopy_notifier;
179
    struct PostCopyFD  postcopy_fd;
180
    uint64_t           postcopy_client_bases[VHOST_MEMORY_MAX_NREGIONS];
181 182 183 184 185 186 187 188 189
    /* Length of the region_rb and region_rb_offset arrays */
    size_t             region_rb_len;
    /* RAMBlock associated with a given region */
    RAMBlock         **region_rb;
    /* The offset from the start of the RAMBlock to the start of the
     * vhost region.
     */
    ram_addr_t        *region_rb_offset;

190 191
    /* True once we've entered postcopy_listen */
    bool               postcopy_listen;
192 193
};

194 195 196 197 198 199 200
/* Report whether KVM is in use and provides eventfd support. */
static bool ioeventfd_enabled(void)
{
    if (!kvm_enabled()) {
        return false;
    }

    return kvm_eventfds_enabled();
}

/*
 * Read one reply from the backend into @msg.
 *
 * The header is read first and validated (flags must be exactly
 * REPLY | VERSION, size must fit the payload union); the payload, if any,
 * is then read directly behind the header.
 *
 * Returns 0 on success, -1 on a short read or an invalid header.
 */
static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
{
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
    uint8_t *buf = (uint8_t *) msg;
    int want = VHOST_USER_HDR_SIZE;
    int got;

    got = qemu_chr_fe_read_all(chr, buf, want);
    if (got != want) {
        error_report("Failed to read msg header. Read %d instead of %d."
                     " Original request %d.", got, want, msg->hdr.request);
        return -1;
    }

    /* Only a reply of the protocol version we speak is acceptable */
    if (msg->hdr.flags != (VHOST_USER_REPLY_MASK | VHOST_USER_VERSION)) {
        error_report("Failed to read msg header."
                " Flags 0x%x instead of 0x%x.", msg->hdr.flags,
                VHOST_USER_REPLY_MASK | VHOST_USER_VERSION);
        return -1;
    }

    /* Refuse sizes that would overflow the payload union */
    if (msg->hdr.size > VHOST_USER_PAYLOAD_SIZE) {
        error_report("Failed to read msg header."
                " Size %d exceeds the maximum %zu.", msg->hdr.size,
                VHOST_USER_PAYLOAD_SIZE);
        return -1;
    }

    if (msg->hdr.size == 0) {
        return 0;
    }

    want = msg->hdr.size;
    got = qemu_chr_fe_read_all(chr, buf + VHOST_USER_HDR_SIZE, want);
    if (got != want) {
        error_report("Failed to read msg payload."
                     " Read %d instead of %d.", got, msg->hdr.size);
        return -1;
    }

    return 0;
}

246
static int process_message_reply(struct vhost_dev *dev,
247
                                 const VhostUserMsg *msg)
248
{
249
    VhostUserMsg msg_reply;
250

251
    if ((msg->hdr.flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
252 253 254 255
        return 0;
    }

    if (vhost_user_read(dev, &msg_reply) < 0) {
256 257 258
        return -1;
    }

259
    if (msg_reply.hdr.request != msg->hdr.request) {
260 261
        error_report("Received unexpected msg type."
                     "Expected %d received %d",
262
                     msg->hdr.request, msg_reply.hdr.request);
263 264 265
        return -1;
    }

266
    return msg_reply.payload.u64 ? -1 : 0;
267 268
}

269 270 271 272
/*
 * Tell whether @request is one of the requests that vhost_user_write()
 * and vhost_user_get_u64() only send for the device whose vq_index is 0.
 */
static bool vhost_user_one_time_request(VhostUserRequest request)
{
    return request == VHOST_USER_SET_OWNER ||
           request == VHOST_USER_RESET_OWNER ||
           request == VHOST_USER_SET_MEM_TABLE ||
           request == VHOST_USER_GET_QUEUE_NUM ||
           request == VHOST_USER_NET_SET_MTU;
}

/* most non-init callers ignore the error */
284 285 286
static int vhost_user_write(struct vhost_dev *dev, VhostUserMsg *msg,
                            int *fds, int fd_num)
{
287 288
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
289
    int ret, size = VHOST_USER_HDR_SIZE + msg->hdr.size;
290

291 292 293 294 295
    /*
     * For non-vring specific requests, like VHOST_USER_SET_MEM_TABLE,
     * we just need send it once in the first time. For later such
     * request, we just ignore it.
     */
296 297
    if (vhost_user_one_time_request(msg->hdr.request) && dev->vq_index != 0) {
        msg->hdr.flags &= ~VHOST_USER_NEED_REPLY_MASK;
298 299 300
        return 0;
    }

301
    if (qemu_chr_fe_set_msgfds(chr, fds, fd_num) < 0) {
302
        error_report("Failed to set msg fds.");
303 304
        return -1;
    }
305

306 307 308 309 310 311 312 313
    ret = qemu_chr_fe_write_all(chr, (const uint8_t *) msg, size);
    if (ret != size) {
        error_report("Failed to write msg."
                     " Wrote %d instead of %d.", ret, size);
        return -1;
    }

    return 0;
314 315
}

316 317
static int vhost_user_set_log_base(struct vhost_dev *dev, uint64_t base,
                                   struct vhost_log *log)
318
{
319 320 321 322 323
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    size_t fd_num = 0;
    bool shmfd = virtio_has_feature(dev->protocol_features,
                                    VHOST_USER_PROTOCOL_F_LOG_SHMFD);
    VhostUserMsg msg = {
324 325
        .hdr.request = VHOST_USER_SET_LOG_BASE,
        .hdr.flags = VHOST_USER_VERSION,
M
Michael S. Tsirkin 已提交
326
        .payload.log.mmap_size = log->size * sizeof(*(log->log)),
327
        .payload.log.mmap_offset = 0,
328
        .hdr.size = sizeof(msg.payload.log),
329 330 331 332 333 334
    };

    if (shmfd && log->fd != -1) {
        fds[fd_num++] = log->fd;
    }

335 336 337
    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }
338 339

    if (shmfd) {
340
        msg.hdr.size = 0;
341
        if (vhost_user_read(dev, &msg) < 0) {
342
            return -1;
343 344
        }

345
        if (msg.hdr.request != VHOST_USER_SET_LOG_BASE) {
346 347
            error_report("Received unexpected msg type. "
                         "Expected %d received %d",
348
                         VHOST_USER_SET_LOG_BASE, msg.hdr.request);
349 350
            return -1;
        }
351
    }
352 353

    return 0;
354 355
}

356 357 358
static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev,
                                             struct vhost_memory *mem)
{
359
    struct vhost_user *u = dev->opaque;
360 361 362 363 364
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    int i, fd;
    size_t fd_num = 0;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);
365 366 367
    VhostUserMsg msg_reply;
    int region_i, msg_i;

368 369 370 371 372 373 374 375 376
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_SET_MEM_TABLE,
        .hdr.flags = VHOST_USER_VERSION,
    };

    if (reply_supported) {
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
    }

377 378 379 380 381 382 383 384 385 386 387
    if (u->region_rb_len < dev->mem->nregions) {
        u->region_rb = g_renew(RAMBlock*, u->region_rb, dev->mem->nregions);
        u->region_rb_offset = g_renew(ram_addr_t, u->region_rb_offset,
                                      dev->mem->nregions);
        memset(&(u->region_rb[u->region_rb_len]), '\0',
               sizeof(RAMBlock *) * (dev->mem->nregions - u->region_rb_len));
        memset(&(u->region_rb_offset[u->region_rb_len]), '\0',
               sizeof(ram_addr_t) * (dev->mem->nregions - u->region_rb_len));
        u->region_rb_len = dev->mem->nregions;
    }

388 389 390 391 392 393 394 395 396 397
    for (i = 0; i < dev->mem->nregions; ++i) {
        struct vhost_memory_region *reg = dev->mem->regions + i;
        ram_addr_t offset;
        MemoryRegion *mr;

        assert((uintptr_t)reg->userspace_addr == reg->userspace_addr);
        mr = memory_region_from_host((void *)(uintptr_t)reg->userspace_addr,
                                     &offset);
        fd = memory_region_get_fd(mr);
        if (fd > 0) {
398 399 400 401 402 403
            trace_vhost_user_set_mem_table_withfd(fd_num, mr->name,
                                                  reg->memory_size,
                                                  reg->guest_phys_addr,
                                                  reg->userspace_addr, offset);
            u->region_rb_offset[i] = offset;
            u->region_rb[i] = mr->ram_block;
404 405 406 407 408 409 410 411
            msg.payload.memory.regions[fd_num].userspace_addr =
                reg->userspace_addr;
            msg.payload.memory.regions[fd_num].memory_size  = reg->memory_size;
            msg.payload.memory.regions[fd_num].guest_phys_addr =
                reg->guest_phys_addr;
            msg.payload.memory.regions[fd_num].mmap_offset = offset;
            assert(fd_num < VHOST_MEMORY_MAX_NREGIONS);
            fds[fd_num++] = fd;
412 413 414
        } else {
            u->region_rb_offset[i] = 0;
            u->region_rb[i] = NULL;
415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
        }
    }

    msg.payload.memory.nregions = fd_num;

    if (!fd_num) {
        error_report("Failed initializing vhost-user memory map, "
                     "consider using -object memory-backend-file share=on");
        return -1;
    }

    msg.hdr.size = sizeof(msg.payload.memory.nregions);
    msg.hdr.size += sizeof(msg.payload.memory.padding);
    msg.hdr.size += fd_num * sizeof(VhostUserMemoryRegion);

    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }

434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
    if (vhost_user_read(dev, &msg_reply) < 0) {
        return -1;
    }

    if (msg_reply.hdr.request != VHOST_USER_SET_MEM_TABLE) {
        error_report("%s: Received unexpected msg type."
                     "Expected %d received %d", __func__,
                     VHOST_USER_SET_MEM_TABLE, msg_reply.hdr.request);
        return -1;
    }
    /* We're using the same structure, just reusing one of the
     * fields, so it should be the same size.
     */
    if (msg_reply.hdr.size != msg.hdr.size) {
        error_report("%s: Unexpected size for postcopy reply "
                     "%d vs %d", __func__, msg_reply.hdr.size, msg.hdr.size);
        return -1;
    }

    memset(u->postcopy_client_bases, 0,
           sizeof(uint64_t) * VHOST_MEMORY_MAX_NREGIONS);

    /* They're in the same order as the regions that were sent
     * but some of the regions were skipped (above) if they
     * didn't have fd's
    */
    for (msg_i = 0, region_i = 0;
         region_i < dev->mem->nregions;
        region_i++) {
        if (msg_i < fd_num &&
            msg_reply.payload.memory.regions[msg_i].guest_phys_addr ==
            dev->mem->regions[region_i].guest_phys_addr) {
            u->postcopy_client_bases[region_i] =
                msg_reply.payload.memory.regions[msg_i].userspace_addr;
            trace_vhost_user_set_mem_table_postcopy(
                msg_reply.payload.memory.regions[msg_i].userspace_addr,
                msg.payload.memory.regions[msg_i].userspace_addr,
                msg_i, region_i);
            msg_i++;
        }
    }
    if (msg_i != fd_num) {
        error_report("%s: postcopy reply not fully consumed "
                     "%d vs %zd",
                     __func__, msg_i, fd_num);
        return -1;
    }
    /* Now we've registered this with the postcopy code, we ack to the client,
     * because now we're in the position to be able to deal with any faults
     * it generates.
     */
    /* TODO: Use this for failure cases as well with a bad value */
    msg.hdr.size = sizeof(msg.payload.u64);
    msg.payload.u64 = 0; /* OK */
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

492 493 494 495 496 497 498
    if (reply_supported) {
        return process_message_reply(dev, &msg);
    }

    return 0;
}

499 500 501
static int vhost_user_set_mem_table(struct vhost_dev *dev,
                                    struct vhost_memory *mem)
{
502
    struct vhost_user *u = dev->opaque;
503 504 505
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    int i, fd;
    size_t fd_num = 0;
506
    bool do_postcopy = u->postcopy_listen && u->postcopy_fd.handler;
507
    bool reply_supported = virtio_has_feature(dev->protocol_features,
508 509
                                          VHOST_USER_PROTOCOL_F_REPLY_ACK) &&
                                          !do_postcopy;
510

511 512 513 514 515 516 517
    if (do_postcopy) {
        /* Postcopy has enough differences that it's best done in it's own
         * version
         */
        return vhost_user_set_mem_table_postcopy(dev, mem);
    }

518
    VhostUserMsg msg = {
519 520
        .hdr.request = VHOST_USER_SET_MEM_TABLE,
        .hdr.flags = VHOST_USER_VERSION,
521 522 523
    };

    if (reply_supported) {
524
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
525 526 527 528 529 530 531 532 533 534 535 536
    }

    for (i = 0; i < dev->mem->nregions; ++i) {
        struct vhost_memory_region *reg = dev->mem->regions + i;
        ram_addr_t offset;
        MemoryRegion *mr;

        assert((uintptr_t)reg->userspace_addr == reg->userspace_addr);
        mr = memory_region_from_host((void *)(uintptr_t)reg->userspace_addr,
                                     &offset);
        fd = memory_region_get_fd(mr);
        if (fd > 0) {
537 538 539 540
            if (fd_num == VHOST_MEMORY_MAX_NREGIONS) {
                error_report("Failed preparing vhost-user memory table msg");
                return -1;
            }
541 542
            msg.payload.memory.regions[fd_num].userspace_addr =
                reg->userspace_addr;
543
            msg.payload.memory.regions[fd_num].memory_size  = reg->memory_size;
544 545
            msg.payload.memory.regions[fd_num].guest_phys_addr =
                reg->guest_phys_addr;
546 547 548 549 550 551 552 553 554 555 556 557 558
            msg.payload.memory.regions[fd_num].mmap_offset = offset;
            fds[fd_num++] = fd;
        }
    }

    msg.payload.memory.nregions = fd_num;

    if (!fd_num) {
        error_report("Failed initializing vhost-user memory map, "
                     "consider using -object memory-backend-file share=on");
        return -1;
    }

559 560 561
    msg.hdr.size = sizeof(msg.payload.memory.nregions);
    msg.hdr.size += sizeof(msg.payload.memory.padding);
    msg.hdr.size += fd_num * sizeof(VhostUserMemoryRegion);
562 563 564 565 566 567

    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }

    if (reply_supported) {
568
        return process_message_reply(dev, &msg);
569 570 571 572 573
    }

    return 0;
}

574 575 576 577
static int vhost_user_set_vring_addr(struct vhost_dev *dev,
                                     struct vhost_vring_addr *addr)
{
    VhostUserMsg msg = {
578 579
        .hdr.request = VHOST_USER_SET_VRING_ADDR,
        .hdr.flags = VHOST_USER_VERSION,
580
        .payload.addr = *addr,
581
        .hdr.size = sizeof(msg.payload.addr),
582
    };
583

584 585 586
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
587

588 589
    return 0;
}
590

591 592 593
static int vhost_user_set_vring_endian(struct vhost_dev *dev,
                                       struct vhost_vring_state *ring)
{
594 595 596
    bool cross_endian = virtio_has_feature(dev->protocol_features,
                                           VHOST_USER_PROTOCOL_F_CROSS_ENDIAN);
    VhostUserMsg msg = {
597 598
        .hdr.request = VHOST_USER_SET_VRING_ENDIAN,
        .hdr.flags = VHOST_USER_VERSION,
599
        .payload.state = *ring,
600
        .hdr.size = sizeof(msg.payload.state),
601 602 603 604 605 606 607 608 609 610 611 612
    };

    if (!cross_endian) {
        error_report("vhost-user trying to send unhandled ioctl");
        return -1;
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    return 0;
613
}
614

615 616 617 618 619
static int vhost_set_vring(struct vhost_dev *dev,
                           unsigned long int request,
                           struct vhost_vring_state *ring)
{
    VhostUserMsg msg = {
620 621
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
622
        .payload.state = *ring,
623
        .hdr.size = sizeof(msg.payload.state),
624 625
    };

626 627 628
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646

    return 0;
}

/* Forward @ring to the backend as a VHOST_USER_SET_VRING_NUM request. */
static int vhost_user_set_vring_num(struct vhost_dev *dev,
                                    struct vhost_vring_state *ring)
{
    const unsigned long int request = VHOST_USER_SET_VRING_NUM;

    return vhost_set_vring(dev, request, ring);
}

/* Forward @ring to the backend as a VHOST_USER_SET_VRING_BASE request. */
static int vhost_user_set_vring_base(struct vhost_dev *dev,
                                     struct vhost_vring_state *ring)
{
    const unsigned long int request = VHOST_USER_SET_VRING_BASE;

    return vhost_set_vring(dev, request, ring);
}

static int vhost_user_set_vring_enable(struct vhost_dev *dev, int enable)
{
647
    int i;
648

649
    if (!virtio_has_feature(dev->features, VHOST_USER_F_PROTOCOL_FEATURES)) {
650 651 652
        return -1;
    }

653 654 655 656 657 658 659 660
    for (i = 0; i < dev->nvqs; ++i) {
        struct vhost_vring_state state = {
            .index = dev->vq_index + i,
            .num   = enable,
        };

        vhost_set_vring(dev, VHOST_USER_SET_VRING_ENABLE, &state);
    }
661

662 663
    return 0;
}
664 665 666 667 668

static int vhost_user_get_vring_base(struct vhost_dev *dev,
                                     struct vhost_vring_state *ring)
{
    VhostUserMsg msg = {
669 670
        .hdr.request = VHOST_USER_GET_VRING_BASE,
        .hdr.flags = VHOST_USER_VERSION,
671
        .payload.state = *ring,
672
        .hdr.size = sizeof(msg.payload.state),
673 674
    };

675 676 677
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
678 679

    if (vhost_user_read(dev, &msg) < 0) {
680
        return -1;
681 682
    }

683
    if (msg.hdr.request != VHOST_USER_GET_VRING_BASE) {
684
        error_report("Received unexpected msg type. Expected %d received %d",
685
                     VHOST_USER_GET_VRING_BASE, msg.hdr.request);
686 687
        return -1;
    }
688

689
    if (msg.hdr.size != sizeof(msg.payload.state)) {
690 691
        error_report("Received bad msg size.");
        return -1;
692 693
    }

694
    *ring = msg.payload.state;
695

696 697 698
    return 0;
}

699 700 701
/*
 * Common helper for the kick/call requests: send the ring index of @file
 * under request code @request.
 *
 * The fd is attached only when ioeventfd is available and file->fd > 0;
 * otherwise VHOST_USER_VRING_NOFD_MASK is set in the payload.
 *
 * Returns 0 on success, -1 if the message could not be written.
 */
static int vhost_set_vring_file(struct vhost_dev *dev,
                                VhostUserRequest request,
                                struct vhost_vring_file *file)
{
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    size_t fd_num = 0;
    VhostUserMsg msg = {
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
        .payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK,
        .hdr.size = sizeof(msg.payload.u64),
    };
    const bool attach_fd = ioeventfd_enabled() && file->fd > 0;

    if (!attach_fd) {
        msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK;
    } else {
        fds[fd_num] = file->fd;
        fd_num++;
    }

    return vhost_user_write(dev, &msg, fds, fd_num) < 0 ? -1 : 0;
}
724

725 726 727 728 729 730 731 732 733 734 735 736 737 738 739
static int vhost_user_set_vring_kick(struct vhost_dev *dev,
                                     struct vhost_vring_file *file)
{
    return vhost_set_vring_file(dev, VHOST_USER_SET_VRING_KICK, file);
}

static int vhost_user_set_vring_call(struct vhost_dev *dev,
                                     struct vhost_vring_file *file)
{
    return vhost_set_vring_file(dev, VHOST_USER_SET_VRING_CALL, file);
}

static int vhost_user_set_u64(struct vhost_dev *dev, int request, uint64_t u64)
{
    VhostUserMsg msg = {
740 741
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
742
        .payload.u64 = u64,
743
        .hdr.size = sizeof(msg.payload.u64),
744 745
    };

746 747 748
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767

    return 0;
}

/* Send @features to the backend as a VHOST_USER_SET_FEATURES request. */
static int vhost_user_set_features(struct vhost_dev *dev,
                                   uint64_t features)
{
    const int request = VHOST_USER_SET_FEATURES;

    return vhost_user_set_u64(dev, request, features);
}

/*
 * Send @features to the backend as a VHOST_USER_SET_PROTOCOL_FEATURES
 * request.
 */
static int vhost_user_set_protocol_features(struct vhost_dev *dev,
                                            uint64_t features)
{
    const int request = VHOST_USER_SET_PROTOCOL_FEATURES;

    return vhost_user_set_u64(dev, request, features);
}

static int vhost_user_get_u64(struct vhost_dev *dev, int request, uint64_t *u64)
{
    VhostUserMsg msg = {
768 769
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
770 771 772 773
    };

    if (vhost_user_one_time_request(request) && dev->vq_index != 0) {
        return 0;
774
    }
775

776 777 778
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
779 780

    if (vhost_user_read(dev, &msg) < 0) {
781
        return -1;
782 783
    }

784
    if (msg.hdr.request != request) {
785
        error_report("Received unexpected msg type. Expected %d received %d",
786
                     request, msg.hdr.request);
787 788 789
        return -1;
    }

790
    if (msg.hdr.size != sizeof(msg.payload.u64)) {
791 792 793 794
        error_report("Received bad msg size.");
        return -1;
    }

795
    *u64 = msg.payload.u64;
796 797 798 799 800 801 802 803 804 805 806 807

    return 0;
}

/* Query the backend with VHOST_USER_GET_FEATURES; result in *@features. */
static int vhost_user_get_features(struct vhost_dev *dev, uint64_t *features)
{
    const int request = VHOST_USER_GET_FEATURES;

    return vhost_user_get_u64(dev, request, features);
}

static int vhost_user_set_owner(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
808 809
        .hdr.request = VHOST_USER_SET_OWNER,
        .hdr.flags = VHOST_USER_VERSION,
810 811
    };

812 813 814
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
815 816 817 818 819 820 821

    return 0;
}

static int vhost_user_reset_device(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
822 823
        .hdr.request = VHOST_USER_RESET_OWNER,
        .hdr.flags = VHOST_USER_VERSION,
824 825
    };

826 827 828
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
829

830 831 832
    return 0;
}

833 834 835 836 837 838 839 840 841 842 843 844 845 846 847
/*
 * Handle a config-change notification from the backend by invoking the
 * device's vhost_dev_config_notifier callback.
 *
 * Returns the callback's result, or -1 when the device has no config_ops
 * or no notifier registered.
 */
static int vhost_user_slave_handle_config_change(struct vhost_dev *dev)
{
    if (!dev->config_ops) {
        return -1;
    }

    if (!dev->config_ops->vhost_dev_config_notifier) {
        return -1;
    }

    return dev->config_ops->vhost_dev_config_notifier(dev);
}

848 849 850 851
static void slave_read(void *opaque)
{
    struct vhost_dev *dev = opaque;
    struct vhost_user *u = dev->opaque;
852 853
    VhostUserHeader hdr = { 0, };
    VhostUserPayload payload = { 0, };
854
    int size, ret = 0;
855 856 857 858 859 860 861 862 863 864 865 866
    struct iovec iov;
    struct msghdr msgh;
    int fd = -1;
    char control[CMSG_SPACE(sizeof(fd))];
    struct cmsghdr *cmsg;
    size_t fdsize;

    memset(&msgh, 0, sizeof(msgh));
    msgh.msg_iov = &iov;
    msgh.msg_iovlen = 1;
    msgh.msg_control = control;
    msgh.msg_controllen = sizeof(control);
867 868

    /* Read header */
869 870 871 872
    iov.iov_base = &hdr;
    iov.iov_len = VHOST_USER_HDR_SIZE;

    size = recvmsg(u->slave_fd, &msgh, 0);
873 874 875 876 877
    if (size != VHOST_USER_HDR_SIZE) {
        error_report("Failed to read from slave.");
        goto err;
    }

878 879 880 881 882 883 884 885 886 887 888 889 890 891 892
    if (msgh.msg_flags & MSG_CTRUNC) {
        error_report("Truncated message.");
        goto err;
    }

    for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != NULL;
         cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
            if (cmsg->cmsg_level == SOL_SOCKET &&
                cmsg->cmsg_type == SCM_RIGHTS) {
                    fdsize = cmsg->cmsg_len - CMSG_LEN(0);
                    memcpy(&fd, CMSG_DATA(cmsg), fdsize);
                    break;
            }
    }

893
    if (hdr.size > VHOST_USER_PAYLOAD_SIZE) {
894
        error_report("Failed to read msg header."
895
                " Size %d exceeds the maximum %zu.", hdr.size,
896 897 898 899 900
                VHOST_USER_PAYLOAD_SIZE);
        goto err;
    }

    /* Read payload */
901 902
    size = read(u->slave_fd, &payload, hdr.size);
    if (size != hdr.size) {
903 904 905 906
        error_report("Failed to read payload from slave.");
        goto err;
    }

907
    switch (hdr.request) {
908
    case VHOST_USER_SLAVE_IOTLB_MSG:
909
        ret = vhost_backend_handle_iotlb_msg(dev, &payload.iotlb);
910
        break;
911 912 913
    case VHOST_USER_SLAVE_CONFIG_CHANGE_MSG :
        ret = vhost_user_slave_handle_config_change(dev);
        break;
914 915
    default:
        error_report("Received unexpected msg type.");
916 917 918
        if (fd != -1) {
            close(fd);
        }
919 920 921
        ret = -EINVAL;
    }

922 923 924
    /* Message handlers need to make sure that fd will be consumed. */
    fd = -1;

925 926 927 928
    /*
     * REPLY_ACK feature handling. Other reply types has to be managed
     * directly in their request handlers.
     */
929 930
    if (hdr.flags & VHOST_USER_NEED_REPLY_MASK) {
        struct iovec iovec[2];
931 932


933 934 935 936 937 938 939 940 941 942 943 944 945
        hdr.flags &= ~VHOST_USER_NEED_REPLY_MASK;
        hdr.flags |= VHOST_USER_REPLY_MASK;

        payload.u64 = !!ret;
        hdr.size = sizeof(payload.u64);

        iovec[0].iov_base = &hdr;
        iovec[0].iov_len = VHOST_USER_HDR_SIZE;
        iovec[1].iov_base = &payload;
        iovec[1].iov_len = hdr.size;

        size = writev(u->slave_fd, iovec, ARRAY_SIZE(iovec));
        if (size != VHOST_USER_HDR_SIZE + hdr.size) {
946 947 948 949 950 951 952 953 954 955 956
            error_report("Failed to send msg reply to slave.");
            goto err;
        }
    }

    return;

err:
    qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
    close(u->slave_fd);
    u->slave_fd = -1;
957 958 959
    if (fd != -1) {
        close(fd);
    }
960 961 962 963 964 965
    return;
}

static int vhost_setup_slave_channel(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
966 967
        .hdr.request = VHOST_USER_SET_SLAVE_REQ_FD,
        .hdr.flags = VHOST_USER_VERSION,
968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987
    };
    struct vhost_user *u = dev->opaque;
    int sv[2], ret = 0;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    if (!virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
        return 0;
    }

    if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) == -1) {
        error_report("socketpair() failed");
        return -1;
    }

    u->slave_fd = sv[0];
    qemu_set_fd_handler(u->slave_fd, slave_read, NULL, dev);

    if (reply_supported) {
988
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010
    }

    ret = vhost_user_write(dev, &msg, &sv[1], 1);
    if (ret) {
        goto out;
    }

    if (reply_supported) {
        ret = process_message_reply(dev, &msg);
    }

out:
    close(sv[1]);
    if (ret) {
        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
        close(u->slave_fd);
        u->slave_fd = -1;
    }

    return ret;
}

1011 1012 1013 1014 1015 1016 1017 1018
/*
 * Called back from the postcopy fault thread when a fault is received on our
 * ufd.
 * TODO: This is Linux specific
 */
static int vhost_user_postcopy_fault_handler(struct PostCopyFD *pcfd,
                                             void *ufd)
{
1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047
    struct vhost_dev *dev = pcfd->data;
    struct vhost_user *u = dev->opaque;
    struct uffd_msg *msg = ufd;
    uint64_t faultaddr = msg->arg.pagefault.address;
    RAMBlock *rb = NULL;
    uint64_t rb_offset;
    int i;

    trace_vhost_user_postcopy_fault_handler(pcfd->idstr, faultaddr,
                                            dev->mem->nregions);
    for (i = 0; i < MIN(dev->mem->nregions, u->region_rb_len); i++) {
        trace_vhost_user_postcopy_fault_handler_loop(i,
                u->postcopy_client_bases[i], dev->mem->regions[i].memory_size);
        if (faultaddr >= u->postcopy_client_bases[i]) {
            /* Ofset of the fault address in the vhost region */
            uint64_t region_offset = faultaddr - u->postcopy_client_bases[i];
            if (region_offset < dev->mem->regions[i].memory_size) {
                rb_offset = region_offset + u->region_rb_offset[i];
                trace_vhost_user_postcopy_fault_handler_found(i,
                        region_offset, rb_offset);
                rb = u->region_rb[i];
                return postcopy_request_shared_page(pcfd, rb, faultaddr,
                                                    rb_offset);
            }
        }
    }
    error_report("%s: Failed to find region for fault %" PRIx64,
                 __func__, faultaddr);
    return -1;
1048 1049
}

1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078
/*
 * Waker callback installed in the PostCopyFD.
 *
 * @pcfd:   the PostCopyFD we registered; pcfd->data is our vhost_dev.
 * @rb:     RAMBlock the page belongs to.
 * @offset: offset of the page within @rb.
 *
 * Looks for a region backed by @rb that covers @offset, converts the offset
 * to an address in the client's address space and passes it to
 * postcopy_wake_shared(), returning that call's result.  Returns 0 when
 * dev->opaque is NULL or when no region matches.
 */
static int vhost_user_postcopy_waker(struct PostCopyFD *pcfd, RAMBlock *rb,
                                     uint64_t offset)
{
    struct vhost_dev *dev = pcfd->data;
    struct vhost_user *u = dev->opaque;
    int i;

    trace_vhost_user_postcopy_waker(qemu_ram_get_idstr(rb), offset);

    if (!u) {
        return 0;
    }

    /* Translate the offset into an address in the clients address space */
    for (i = 0; i < MIN(dev->mem->nregions, u->region_rb_len); i++) {
        uint64_t client_addr;

        if (u->region_rb[i] != rb) {
            continue;
        }
        if (offset < u->region_rb_offset[i]) {
            continue;
        }
        if (offset >= (u->region_rb_offset[i] +
                       dev->mem->regions[i].memory_size)) {
            continue;
        }

        client_addr = (offset - u->region_rb_offset[i]) +
                      u->postcopy_client_bases[i];
        trace_vhost_user_postcopy_waker_found(client_addr);
        return postcopy_wake_shared(pcfd, client_addr, rb);
    }

    trace_vhost_user_postcopy_waker_nomatch(qemu_ram_get_idstr(rb), offset);
    return 0;
}

1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117
/*
 * Called at the start of an inbound postcopy on reception of the
 * 'advise' command.
 *
 * Sends VHOST_USER_POSTCOPY_ADVISE and reads the reply, which must carry
 * the same request type and an empty payload.  A file descriptor is then
 * fetched from the chardev with qemu_chr_fe_get_msgfd(), made non-blocking
 * and registered through postcopy_register_shared_ufd().
 *
 * Returns 0 on success; on failure sets @errp and returns -1.
 */
static int vhost_user_postcopy_advise(struct vhost_dev *dev, Error **errp)
{
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_POSTCOPY_ADVISE,
        .hdr.flags = VHOST_USER_VERSION,
    };
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
    int uffd;

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_setg(errp, "Failed to send postcopy_advise to vhost");
        return -1;
    }

    if (vhost_user_read(dev, &msg) < 0) {
        error_setg(errp, "Failed to get postcopy_advise reply from vhost");
        return -1;
    }

    if (msg.hdr.request != VHOST_USER_POSTCOPY_ADVISE) {
        error_setg(errp, "Unexpected msg type. Expected %d received %d",
                   VHOST_USER_POSTCOPY_ADVISE, msg.hdr.request);
        return -1;
    }

    /* The reply is expected to carry no payload */
    if (msg.hdr.size) {
        error_setg(errp, "Received bad msg size.");
        return -1;
    }

    uffd = qemu_chr_fe_get_msgfd(chr);
    if (uffd < 0) {
        error_setg(errp, "%s: Failed to get ufd", __func__);
        return -1;
    }
    qemu_set_nonblock(uffd);

    /* register ufd with userfault thread */
    u->postcopy_fd.fd = uffd;
    u->postcopy_fd.data = dev;
    u->postcopy_fd.handler = vhost_user_postcopy_fault_handler;
    u->postcopy_fd.waker = vhost_user_postcopy_waker;
    u->postcopy_fd.idstr = "vhost-user"; /* Need to find unique name */
    postcopy_register_shared_ufd(&u->postcopy_fd);

    return 0;
}

1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156
/*
 * Called at the switch to postcopy on reception of the 'listen' command.
 *
 * Marks u->postcopy_listen, then sends VHOST_USER_POSTCOPY_LISTEN with the
 * NEED_REPLY flag and checks the answer with process_message_reply().
 *
 * Returns 0 on success.  On failure @errp is set and -1 (send failed) or
 * the non-zero result of process_message_reply() is returned.
 */
static int vhost_user_postcopy_listen(struct vhost_dev *dev, Error **errp)
{
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_POSTCOPY_LISTEN,
        .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
    };
    struct vhost_user *u = dev->opaque;
    int rc;

    u->postcopy_listen = true;
    trace_vhost_user_postcopy_listen();

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_setg(errp, "Failed to send postcopy_listen to vhost");
        return -1;
    }

    rc = process_message_reply(dev, &msg);
    if (rc != 0) {
        error_setg(errp, "Failed to receive reply to postcopy_listen");
    }

    return rc;
}

1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187
/*
 * Called at the end of postcopy
 *
 * Sends VHOST_USER_POSTCOPY_END with the NEED_REPLY flag and checks the
 * answer with process_message_reply().  On success the shared userfault fd
 * is unregistered and its handler pointer cleared.
 *
 * Returns 0 on success.  On failure @errp is set and -1 (send failed) or
 * the non-zero result of process_message_reply() is returned.
 */
static int vhost_user_postcopy_end(struct vhost_dev *dev, Error **errp)
{
    struct vhost_user *u = dev->opaque;
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_POSTCOPY_END,
        .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
    };
    int rc;

    trace_vhost_user_postcopy_end_entry();

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_setg(errp, "Failed to send postcopy_end to vhost");
        return -1;
    }

    rc = process_message_reply(dev, &msg);
    if (rc != 0) {
        error_setg(errp, "Failed to receive reply to postcopy_end");
        return rc;
    }

    postcopy_unregister_shared_ufd(&u->postcopy_fd);
    u->postcopy_fd.handler = NULL;

    trace_vhost_user_postcopy_end_exit();
    return 0;
}

1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206
/*
 * Postcopy notifier callback; dispatches on pnd->reason.
 *
 * @notifier: the postcopy_notifier member embedded in our struct vhost_user.
 * @opaque:   a struct PostcopyNotifyData describing the event.
 *
 * PROBE fails with -ENOENT (and sets pnd->errp) when the backend did not
 * offer VHOST_USER_PROTOCOL_F_PAGEFAULT.  INBOUND_ADVISE, INBOUND_LISTEN and
 * INBOUND_END forward to the matching vhost_user_postcopy_*() helper and
 * return its result.  Any other reason is ignored and yields 0.
 */
static int vhost_user_postcopy_notifier(NotifierWithReturn *notifier,
                                        void *opaque)
{
    struct vhost_user *u = container_of(notifier, struct vhost_user,
                                        postcopy_notifier);
    struct PostcopyNotifyData *pnd = opaque;
    struct vhost_dev *dev = u->dev;
    int ret = 0;

    switch (pnd->reason) {
    case POSTCOPY_NOTIFY_INBOUND_ADVISE:
        ret = vhost_user_postcopy_advise(dev, pnd->errp);
        break;

    case POSTCOPY_NOTIFY_INBOUND_LISTEN:
        ret = vhost_user_postcopy_listen(dev, pnd->errp);
        break;

    case POSTCOPY_NOTIFY_INBOUND_END:
        ret = vhost_user_postcopy_end(dev, pnd->errp);
        break;

    case POSTCOPY_NOTIFY_PROBE:
        if (!virtio_has_feature(dev->protocol_features,
                                VHOST_USER_PROTOCOL_F_PAGEFAULT)) {
            /* TODO: Get the device name into this error somehow */
            error_setg(pnd->errp,
                       "vhost-user backend not capable of postcopy");
            ret = -ENOENT;
        }
        break;

    default:
        /* We ignore notifications we don't know */
        break;
    }

    return ret;
}

1224 1225
static int vhost_user_init(struct vhost_dev *dev, void *opaque)
{
1226
    uint64_t features, protocol_features;
1227
    struct vhost_user *u;
1228 1229
    int err;

1230 1231
    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

1232 1233
    u = g_new0(struct vhost_user, 1);
    u->chr = opaque;
1234
    u->slave_fd = -1;
1235
    u->dev = dev;
1236
    dev->opaque = u;
1237

1238
    err = vhost_user_get_features(dev, &features);
1239 1240 1241 1242 1243 1244 1245
    if (err < 0) {
        return err;
    }

    if (virtio_has_feature(features, VHOST_USER_F_PROTOCOL_FEATURES)) {
        dev->backend_features |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;

1246
        err = vhost_user_get_u64(dev, VHOST_USER_GET_PROTOCOL_FEATURES,
1247
                                 &protocol_features);
1248 1249 1250 1251
        if (err < 0) {
            return err;
        }

1252 1253
        dev->protocol_features =
            protocol_features & VHOST_USER_PROTOCOL_FEATURE_MASK;
1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264

        if (!dev->config_ops || !dev->config_ops->vhost_dev_config_notifier) {
            /* Don't acknowledge CONFIG feature if device doesn't support it */
            dev->protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_CONFIG);
        } else if (!(protocol_features &
                    (1ULL << VHOST_USER_PROTOCOL_F_CONFIG))) {
            error_report("Device expects VHOST_USER_PROTOCOL_F_CONFIG "
                    "but backend does not support it.");
            return -1;
        }

1265
        err = vhost_user_set_protocol_features(dev, dev->protocol_features);
1266 1267 1268
        if (err < 0) {
            return err;
        }
1269 1270 1271

        /* query the max queues we support if backend supports Multiple Queue */
        if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_MQ)) {
1272 1273
            err = vhost_user_get_u64(dev, VHOST_USER_GET_QUEUE_NUM,
                                     &dev->max_queues);
1274 1275 1276 1277
            if (err < 0) {
                return err;
            }
        }
1278 1279 1280 1281 1282 1283 1284 1285 1286 1287

        if (virtio_has_feature(features, VIRTIO_F_IOMMU_PLATFORM) &&
                !(virtio_has_feature(dev->protocol_features,
                    VHOST_USER_PROTOCOL_F_SLAVE_REQ) &&
                 virtio_has_feature(dev->protocol_features,
                    VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
            error_report("IOMMU support requires reply-ack and "
                         "slave-req protocol features.");
            return -1;
        }
1288 1289
    }

1290 1291 1292 1293 1294 1295 1296 1297
    if (dev->migration_blocker == NULL &&
        !virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_LOG_SHMFD)) {
        error_setg(&dev->migration_blocker,
                   "Migration disabled: vhost-user backend lacks "
                   "VHOST_USER_PROTOCOL_F_LOG_SHMFD feature.");
    }

1298 1299 1300 1301 1302
    err = vhost_setup_slave_channel(dev);
    if (err < 0) {
        return err;
    }

1303 1304 1305
    u->postcopy_notifier.notify = vhost_user_postcopy_notifier;
    postcopy_add_notifier(&u->postcopy_notifier);

1306 1307 1308 1309 1310
    return 0;
}

static int vhost_user_cleanup(struct vhost_dev *dev)
{
1311 1312
    struct vhost_user *u;

1313 1314
    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

1315
    u = dev->opaque;
1316 1317 1318 1319
    if (u->postcopy_notifier.notify) {
        postcopy_remove_notifier(&u->postcopy_notifier);
        u->postcopy_notifier.notify = NULL;
    }
1320
    if (u->slave_fd >= 0) {
1321
        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
1322 1323 1324
        close(u->slave_fd);
        u->slave_fd = -1;
    }
1325 1326 1327 1328 1329
    g_free(u->region_rb);
    u->region_rb = NULL;
    g_free(u->region_rb_offset);
    u->region_rb_offset = NULL;
    u->region_rb_len = 0;
1330
    g_free(u);
1331 1332 1333 1334 1335
    dev->opaque = 0;

    return 0;
}

1336 1337 1338 1339 1340 1341 1342
/*
 * Map a virtqueue index to the index used towards the backend.
 *
 * The index is passed through unchanged; it is only asserted to lie within
 * [dev->vq_index, dev->vq_index + dev->nvqs).
 */
static int vhost_user_get_vq_index(struct vhost_dev *dev, int idx)
{
    int backend_idx = idx;

    assert(idx >= dev->vq_index && idx < dev->vq_index + dev->nvqs);

    return backend_idx;
}

1343 1344 1345 1346 1347
/*
 * Report the memory slot limit for this backend.
 *
 * @dev is not consulted; the value is always VHOST_MEMORY_MAX_NREGIONS.
 */
static int vhost_user_memslots_limit(struct vhost_dev *dev)
{
    return VHOST_MEMORY_MAX_NREGIONS;
}

1348 1349 1350 1351 1352 1353 1354 1355
/*
 * Returns true when VHOST_USER_PROTOCOL_F_LOG_SHMFD is set in
 * dev->protocol_features.
 */
static bool vhost_user_requires_shm_log(struct vhost_dev *dev)
{
    bool has_log_shmfd;

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    has_log_shmfd = virtio_has_feature(dev->protocol_features,
                                       VHOST_USER_PROTOCOL_F_LOG_SHMFD);
    return has_log_shmfd;
}

1356 1357
static int vhost_user_migration_done(struct vhost_dev *dev, char* mac_addr)
{
1358
    VhostUserMsg msg = { };
1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    /* If guest supports GUEST_ANNOUNCE do nothing */
    if (virtio_has_feature(dev->acked_features, VIRTIO_NET_F_GUEST_ANNOUNCE)) {
        return 0;
    }

    /* if backend supports VHOST_USER_PROTOCOL_F_RARP ask it to send the RARP */
    if (virtio_has_feature(dev->protocol_features,
                           VHOST_USER_PROTOCOL_F_RARP)) {
1370 1371
        msg.hdr.request = VHOST_USER_SEND_RARP;
        msg.hdr.flags = VHOST_USER_VERSION;
1372
        memcpy((char *)&msg.payload.u64, mac_addr, 6);
1373
        msg.hdr.size = sizeof(msg.payload.u64);
1374

1375
        return vhost_user_write(dev, &msg, NULL, 0);
1376 1377 1378 1379
    }
    return -1;
}

1380 1381 1382 1383
static bool vhost_user_can_merge(struct vhost_dev *dev,
                                 uint64_t start1, uint64_t size1,
                                 uint64_t start2, uint64_t size2)
{
1384
    ram_addr_t offset;
1385 1386 1387
    int mfd, rfd;
    MemoryRegion *mr;

1388
    mr = memory_region_from_host((void *)(uintptr_t)start1, &offset);
1389
    mfd = memory_region_get_fd(mr);
1390

1391
    mr = memory_region_from_host((void *)(uintptr_t)start2, &offset);
1392
    rfd = memory_region_get_fd(mr);
1393 1394 1395 1396

    return mfd == rfd;
}

1397 1398 1399 1400 1401 1402 1403 1404 1405 1406
static int vhost_user_net_set_mtu(struct vhost_dev *dev, uint16_t mtu)
{
    VhostUserMsg msg;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    if (!(dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_NET_MTU))) {
        return 0;
    }

1407
    msg.hdr.request = VHOST_USER_NET_SET_MTU;
1408
    msg.payload.u64 = mtu;
1409 1410
    msg.hdr.size = sizeof(msg.payload.u64);
    msg.hdr.flags = VHOST_USER_VERSION;
1411
    if (reply_supported) {
1412
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
1413 1414 1415 1416 1417 1418 1419 1420
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    /* If reply_ack supported, slave has to ack specified MTU is valid */
    if (reply_supported) {
1421
        return process_message_reply(dev, &msg);
1422 1423 1424 1425 1426
    }

    return 0;
}

1427 1428 1429 1430
static int vhost_user_send_device_iotlb_msg(struct vhost_dev *dev,
                                            struct vhost_iotlb_msg *imsg)
{
    VhostUserMsg msg = {
1431 1432 1433
        .hdr.request = VHOST_USER_IOTLB_MSG,
        .hdr.size = sizeof(msg.payload.iotlb),
        .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449
        .payload.iotlb = *imsg,
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -EFAULT;
    }

    return process_message_reply(dev, &msg);
}


/*
 * IOTLB callback enable/disable hook.
 *
 * Intentionally empty for vhost-user: both @dev and @enabled are ignored.
 */
static void vhost_user_set_iotlb_callback(struct vhost_dev *dev, int enabled)
{
    /* No-op as the receive channel is not dedicated to IOTLB messages. */
}

1450 1451 1452 1453
static int vhost_user_get_config(struct vhost_dev *dev, uint8_t *config,
                                 uint32_t config_len)
{
    VhostUserMsg msg = {
1454 1455 1456
        .hdr.request = VHOST_USER_GET_CONFIG,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = VHOST_USER_CONFIG_HDR_SIZE + config_len,
1457 1458
    };

1459 1460 1461 1462 1463
    if (!virtio_has_feature(dev->protocol_features,
                VHOST_USER_PROTOCOL_F_CONFIG)) {
        return -1;
    }

1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477
    if (config_len > VHOST_USER_MAX_CONFIG_SIZE) {
        return -1;
    }

    msg.payload.config.offset = 0;
    msg.payload.config.size = config_len;
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    if (vhost_user_read(dev, &msg) < 0) {
        return -1;
    }

1478
    if (msg.hdr.request != VHOST_USER_GET_CONFIG) {
1479
        error_report("Received unexpected msg type. Expected %d received %d",
1480
                     VHOST_USER_GET_CONFIG, msg.hdr.request);
1481 1482 1483
        return -1;
    }

1484
    if (msg.hdr.size != VHOST_USER_CONFIG_HDR_SIZE + config_len) {
1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501
        error_report("Received bad msg size.");
        return -1;
    }

    memcpy(config, msg.payload.config.region, config_len);

    return 0;
}

static int vhost_user_set_config(struct vhost_dev *dev, const uint8_t *data,
                                 uint32_t offset, uint32_t size, uint32_t flags)
{
    uint8_t *p;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    VhostUserMsg msg = {
1502 1503 1504
        .hdr.request = VHOST_USER_SET_CONFIG,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = VHOST_USER_CONFIG_HDR_SIZE + size,
1505 1506
    };

1507 1508 1509 1510 1511
    if (!virtio_has_feature(dev->protocol_features,
                VHOST_USER_PROTOCOL_F_CONFIG)) {
        return -1;
    }

1512
    if (reply_supported) {
1513
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536
    }

    if (size > VHOST_USER_MAX_CONFIG_SIZE) {
        return -1;
    }

    msg.payload.config.offset = offset,
    msg.payload.config.size = size,
    msg.payload.config.flags = flags,
    p = msg.payload.config.region;
    memcpy(p, data, size);

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    if (reply_supported) {
        return process_message_reply(dev, &msg);
    }

    return 0;
}

1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622
static int vhost_user_crypto_create_session(struct vhost_dev *dev,
                                            void *session_info,
                                            uint64_t *session_id)
{
    bool crypto_session = virtio_has_feature(dev->protocol_features,
                                       VHOST_USER_PROTOCOL_F_CRYPTO_SESSION);
    CryptoDevBackendSymSessionInfo *sess_info = session_info;
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_CREATE_CRYPTO_SESSION,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = sizeof(msg.payload.session),
    };

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    if (!crypto_session) {
        error_report("vhost-user trying to send unhandled ioctl");
        return -1;
    }

    memcpy(&msg.payload.session.session_setup_data, sess_info,
              sizeof(CryptoDevBackendSymSessionInfo));
    if (sess_info->key_len) {
        memcpy(&msg.payload.session.key, sess_info->cipher_key,
               sess_info->key_len);
    }
    if (sess_info->auth_key_len > 0) {
        memcpy(&msg.payload.session.auth_key, sess_info->auth_key,
               sess_info->auth_key_len);
    }
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_report("vhost_user_write() return -1, create session failed");
        return -1;
    }

    if (vhost_user_read(dev, &msg) < 0) {
        error_report("vhost_user_read() return -1, create session failed");
        return -1;
    }

    if (msg.hdr.request != VHOST_USER_CREATE_CRYPTO_SESSION) {
        error_report("Received unexpected msg type. Expected %d received %d",
                     VHOST_USER_CREATE_CRYPTO_SESSION, msg.hdr.request);
        return -1;
    }

    if (msg.hdr.size != sizeof(msg.payload.session)) {
        error_report("Received bad msg size.");
        return -1;
    }

    if (msg.payload.session.session_id < 0) {
        error_report("Bad session id: %" PRId64 "",
                              msg.payload.session.session_id);
        return -1;
    }
    *session_id = msg.payload.session.session_id;

    return 0;
}

/*
 * Ask the backend to close the crypto session @session_id.
 *
 * Requires VHOST_USER_PROTOCOL_F_CRYPTO_SESSION.  Sends
 * VHOST_USER_CLOSE_CRYPTO_SESSION with the id as u64 payload; no reply is
 * read.  Returns 0 on success, -1 if the feature is missing or the send
 * fails.
 */
static int
vhost_user_crypto_close_session(struct vhost_dev *dev, uint64_t session_id)
{
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_CLOSE_CRYPTO_SESSION,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = sizeof(msg.payload.u64),
    };

    if (!virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_CRYPTO_SESSION)) {
        error_report("vhost-user trying to send unhandled ioctl");
        return -1;
    }

    msg.payload.u64 = session_id;

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_report("vhost_user_write() return -1, close session failed");
        return -1;
    }

    return 0;
}

1623 1624 1625 1626 1627 1628 1629 1630 1631 1632
/*
 * Memory section filter: accept a section only if its memory region
 * reports a non-negative file descriptor.  @dev is not used.
 */
static bool vhost_user_mem_section_filter(struct vhost_dev *dev,
                                          MemoryRegionSection *section)
{
    return memory_region_get_fd(section->mr) >= 0;
}

1633 1634 1635
const VhostOps user_ops = {
        .backend_type = VHOST_BACKEND_TYPE_USER,
        .vhost_backend_init = vhost_user_init,
1636
        .vhost_backend_cleanup = vhost_user_cleanup,
1637
        .vhost_backend_memslots_limit = vhost_user_memslots_limit,
1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652
        .vhost_set_log_base = vhost_user_set_log_base,
        .vhost_set_mem_table = vhost_user_set_mem_table,
        .vhost_set_vring_addr = vhost_user_set_vring_addr,
        .vhost_set_vring_endian = vhost_user_set_vring_endian,
        .vhost_set_vring_num = vhost_user_set_vring_num,
        .vhost_set_vring_base = vhost_user_set_vring_base,
        .vhost_get_vring_base = vhost_user_get_vring_base,
        .vhost_set_vring_kick = vhost_user_set_vring_kick,
        .vhost_set_vring_call = vhost_user_set_vring_call,
        .vhost_set_features = vhost_user_set_features,
        .vhost_get_features = vhost_user_get_features,
        .vhost_set_owner = vhost_user_set_owner,
        .vhost_reset_device = vhost_user_reset_device,
        .vhost_get_vq_index = vhost_user_get_vq_index,
        .vhost_set_vring_enable = vhost_user_set_vring_enable,
1653
        .vhost_requires_shm_log = vhost_user_requires_shm_log,
1654
        .vhost_migration_done = vhost_user_migration_done,
1655
        .vhost_backend_can_merge = vhost_user_can_merge,
1656
        .vhost_net_set_mtu = vhost_user_net_set_mtu,
1657 1658
        .vhost_set_iotlb_callback = vhost_user_set_iotlb_callback,
        .vhost_send_device_iotlb_msg = vhost_user_send_device_iotlb_msg,
1659 1660
        .vhost_get_config = vhost_user_get_config,
        .vhost_set_config = vhost_user_set_config,
1661 1662
        .vhost_crypto_create_session = vhost_user_crypto_create_session,
        .vhost_crypto_close_session = vhost_user_crypto_close_session,
1663
        .vhost_backend_mem_section_filter = vhost_user_mem_section_filter,
1664
};