vhost-user.c 47.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
/*
 * vhost-user
 *
 * Copyright (c) 2013 Virtual Open Systems Sarl.
 *
 * This work is licensed under the terms of the GNU GPL, version 2 or later.
 * See the COPYING file in the top-level directory.
 *
 */

P
Peter Maydell 已提交
11
#include "qemu/osdep.h"
12
#include "qapi/error.h"
13 14
#include "hw/virtio/vhost.h"
#include "hw/virtio/vhost-backend.h"
15
#include "hw/virtio/virtio-net.h"
16
#include "chardev/char-fe.h"
17 18 19
#include "sysemu/kvm.h"
#include "qemu/error-report.h"
#include "qemu/sockets.h"
20
#include "sysemu/cryptodev.h"
21 22
#include "migration/migration.h"
#include "migration/postcopy-ram.h"
23
#include "trace.h"
24 25 26 27 28

#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <linux/vhost.h>
29
#include <linux/userfaultfd.h>
30 31

/* Number of entries in the memory-region and fd arrays of a message */
#define VHOST_MEMORY_MAX_NREGIONS    8
/* Feature bit tested before vring enable messages are sent */
#define VHOST_USER_F_PROTOCOL_FEATURES 30

/*
 * Maximum size of virtio device config space
 */
#define VHOST_USER_MAX_CONFIG_SIZE 256

39 40 41 42
/*
 * Bit positions of the optional vhost-user protocol features.  Each value is
 * a bit index (tested against dev->protocol_features), not a mask.
 * VHOST_USER_PROTOCOL_F_MAX is one past the highest defined bit.
 */
enum VhostUserProtocolFeature {
    VHOST_USER_PROTOCOL_F_MQ = 0,
    VHOST_USER_PROTOCOL_F_LOG_SHMFD = 1,
    VHOST_USER_PROTOCOL_F_RARP = 2,
    VHOST_USER_PROTOCOL_F_REPLY_ACK = 3,
    VHOST_USER_PROTOCOL_F_NET_MTU = 4,
    VHOST_USER_PROTOCOL_F_SLAVE_REQ = 5,
    VHOST_USER_PROTOCOL_F_CROSS_ENDIAN = 6,
    VHOST_USER_PROTOCOL_F_CRYPTO_SESSION = 7,
    VHOST_USER_PROTOCOL_F_PAGEFAULT = 8,
    VHOST_USER_PROTOCOL_F_CONFIG = 9,
    VHOST_USER_PROTOCOL_F_MAX
};

/* Mask with one bit set for every protocol feature defined above */
#define VHOST_USER_PROTOCOL_FEATURE_MASK ((1 << VHOST_USER_PROTOCOL_F_MAX) - 1)
54 55 56 57 58 59

/*
 * Request codes placed in VhostUserHeader.request for messages sent to the
 * backend.  VHOST_USER_MAX is one past the highest defined request.
 */
typedef enum VhostUserRequest {
    VHOST_USER_NONE = 0,
    VHOST_USER_GET_FEATURES = 1,
    VHOST_USER_SET_FEATURES = 2,
    VHOST_USER_SET_OWNER = 3,
    VHOST_USER_RESET_OWNER = 4,
    VHOST_USER_SET_MEM_TABLE = 5,
    VHOST_USER_SET_LOG_BASE = 6,
    VHOST_USER_SET_LOG_FD = 7,
    VHOST_USER_SET_VRING_NUM = 8,
    VHOST_USER_SET_VRING_ADDR = 9,
    VHOST_USER_SET_VRING_BASE = 10,
    VHOST_USER_GET_VRING_BASE = 11,
    VHOST_USER_SET_VRING_KICK = 12,
    VHOST_USER_SET_VRING_CALL = 13,
    VHOST_USER_SET_VRING_ERR = 14,
    VHOST_USER_GET_PROTOCOL_FEATURES = 15,
    VHOST_USER_SET_PROTOCOL_FEATURES = 16,
    VHOST_USER_GET_QUEUE_NUM = 17,
    VHOST_USER_SET_VRING_ENABLE = 18,
    VHOST_USER_SEND_RARP = 19,
    VHOST_USER_NET_SET_MTU = 20,
    VHOST_USER_SET_SLAVE_REQ_FD = 21,
    VHOST_USER_IOTLB_MSG = 22,
    VHOST_USER_SET_VRING_ENDIAN = 23,
    VHOST_USER_GET_CONFIG = 24,
    VHOST_USER_SET_CONFIG = 25,
    VHOST_USER_CREATE_CRYPTO_SESSION = 26,
    VHOST_USER_CLOSE_CRYPTO_SESSION = 27,
    VHOST_USER_POSTCOPY_ADVISE  = 28,
    VHOST_USER_POSTCOPY_LISTEN  = 29,
    VHOST_USER_POSTCOPY_END     = 30,
    VHOST_USER_MAX
} VhostUserRequest;

90 91
/*
 * Request codes for messages arriving on the slave channel; slave_read()
 * dispatches on these.  VHOST_USER_SLAVE_MAX is one past the highest code.
 */
typedef enum VhostUserSlaveRequest {
    VHOST_USER_SLAVE_NONE = 0,
    VHOST_USER_SLAVE_IOTLB_MSG = 1,
    VHOST_USER_SLAVE_CONFIG_CHANGE_MSG = 2,
    VHOST_USER_SLAVE_MAX
} VhostUserSlaveRequest;

97 98 99 100
/*
 * One entry of the memory table sent with VHOST_USER_SET_MEM_TABLE.  The
 * first three fields are copied from the matching vhost_memory_region;
 * mmap_offset is the offset reported by memory_region_from_host().
 */
typedef struct VhostUserMemoryRegion {
    uint64_t guest_phys_addr;
    uint64_t memory_size;
    uint64_t userspace_addr;
    uint64_t mmap_offset;
} VhostUserMemoryRegion;

/*
 * Payload of VHOST_USER_SET_MEM_TABLE: a count followed by the region
 * descriptors.  Senders only transmit the first nregions entries of regions[].
 */
typedef struct VhostUserMemory {
    uint32_t nregions;
    uint32_t padding;
    VhostUserMemoryRegion regions[VHOST_MEMORY_MAX_NREGIONS];
} VhostUserMemory;

110 111 112 113 114
/*
 * Payload of VHOST_USER_SET_LOG_BASE: size of the dirty log in bytes and the
 * offset at which it starts (the sender in this file always uses 0).
 */
typedef struct VhostUserLog {
    uint64_t mmap_size;
    uint64_t mmap_offset;
} VhostUserLog;

115 116 117 118 119 120 121
/*
 * Device config space payload: a three-word header followed by up to
 * VHOST_USER_MAX_CONFIG_SIZE bytes of data.
 * NOTE(review): presumably used by VHOST_USER_GET_CONFIG/SET_CONFIG; the
 * users are not visible in this part of the file.
 */
typedef struct VhostUserConfig {
    uint32_t offset;
    uint32_t size;
    uint32_t flags;
    uint8_t region[VHOST_USER_MAX_CONFIG_SIZE];
} VhostUserConfig;

122 123 124 125 126 127 128 129 130 131 132
#define VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN    512
#define VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN  64

/*
 * Crypto session payload.  The key buffers are fixed-size, bounded by
 * VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN and VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN.
 */
typedef struct VhostUserCryptoSession {
    /* session id for success, -1 on errors */
    int64_t session_id;
    CryptoDevBackendSymSessionInfo session_setup_data;
    uint8_t key[VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN];
    uint8_t auth_key[VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN];
} VhostUserCryptoSession;

133 134 135 136 137
/*
 * Dummy object that exists only so sizeof() can be applied to individual
 * VhostUserConfig members below; it is never read or written.
 */
static VhostUserConfig c __attribute__ ((unused));
/* Size of the VhostUserConfig fields that precede the region[] data */
#define VHOST_USER_CONFIG_HDR_SIZE (sizeof(c.offset) \
                                   + sizeof(c.size) \
                                   + sizeof(c.flags))

138
typedef struct {
139 140 141 142
    VhostUserRequest request;

#define VHOST_USER_VERSION_MASK     (0x3)
#define VHOST_USER_REPLY_MASK       (0x1<<2)
143
#define VHOST_USER_NEED_REPLY_MASK  (0x1 << 3)
144 145
    uint32_t flags;
    uint32_t size; /* the following payload size */
146 147 148
} QEMU_PACKED VhostUserHeader;

typedef union {
149 150 151 152 153 154
#define VHOST_USER_VRING_IDX_MASK   (0xff)
#define VHOST_USER_VRING_NOFD_MASK  (0x1<<8)
        uint64_t u64;
        struct vhost_vring_state state;
        struct vhost_vring_addr addr;
        VhostUserMemory memory;
155
        VhostUserLog log;
156
        struct vhost_iotlb_msg iotlb;
157
        VhostUserConfig config;
158
        VhostUserCryptoSession session;
159 160 161 162 163
} VhostUserPayload;

/*
 * A complete message: header immediately followed by the payload.  The struct
 * is packed so it can be written to and read from the channel as one buffer.
 */
typedef struct VhostUserMsg {
    VhostUserHeader hdr;
    VhostUserPayload payload;
} QEMU_PACKED VhostUserMsg;

static VhostUserMsg m __attribute__ ((unused));
/* Number of bytes in the fixed message header */
#define VHOST_USER_HDR_SIZE (sizeof(VhostUserHeader))

/* Largest payload a message can carry */
#define VHOST_USER_PAYLOAD_SIZE (sizeof(VhostUserPayload))

/* The version of the protocol we support */
#define VHOST_USER_VERSION    (0x1)

174
struct vhost_user {
175
    struct vhost_dev *dev;
176
    CharBackend *chr;
177
    int slave_fd;
178
    NotifierWithReturn postcopy_notifier;
179
    struct PostCopyFD  postcopy_fd;
180
    uint64_t           postcopy_client_bases[VHOST_MEMORY_MAX_NREGIONS];
181 182 183 184 185 186 187 188 189
    /* Length of the region_rb and region_rb_offset arrays */
    size_t             region_rb_len;
    /* RAMBlock associated with a given region */
    RAMBlock         **region_rb;
    /* The offset from the start of the RAMBlock to the start of the
     * vhost region.
     */
    ram_addr_t        *region_rb_offset;

190 191
    /* True once we've entered postcopy_listen */
    bool               postcopy_listen;
192 193
};

194 195 196 197 198 199 200
/* Report whether KVM is in use and has eventfd support enabled. */
static bool ioeventfd_enabled(void)
{
    if (!kvm_enabled()) {
        return false;
    }

    return kvm_eventfds_enabled();
}

/*
 * Read one reply message from the backend into @msg.
 *
 * The fixed-size header is read first and checked: its flags must be exactly
 * the reply bit plus the supported protocol version, and the announced
 * payload size must fit in VhostUserPayload.  Any payload is then read into
 * the bytes directly behind the header.
 *
 * Returns 0 on success, -1 on a short read or an invalid header.
 */
static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
{
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
    uint8_t *buf = (uint8_t *) msg;
    int want = VHOST_USER_HDR_SIZE;
    int got = qemu_chr_fe_read_all(chr, buf, want);

    if (got != want) {
        error_report("Failed to read msg header. Read %d instead of %d."
                     " Original request %d.", got, want, msg->hdr.request);
        return -1;
    }

    /* validate received flags */
    if (msg->hdr.flags != (VHOST_USER_REPLY_MASK | VHOST_USER_VERSION)) {
        error_report("Failed to read msg header."
                " Flags 0x%x instead of 0x%x.", msg->hdr.flags,
                VHOST_USER_REPLY_MASK | VHOST_USER_VERSION);
        return -1;
    }

    /* validate message size is sane */
    if (msg->hdr.size > VHOST_USER_PAYLOAD_SIZE) {
        error_report("Failed to read msg header."
                " Size %d exceeds the maximum %zu.", msg->hdr.size,
                VHOST_USER_PAYLOAD_SIZE);
        return -1;
    }

    /* A header-only message has nothing further to read */
    if (msg->hdr.size == 0) {
        return 0;
    }

    want = msg->hdr.size;
    got = qemu_chr_fe_read_all(chr, buf + VHOST_USER_HDR_SIZE, want);
    if (got != want) {
        error_report("Failed to read msg payload."
                     " Read %d instead of %d.", got, msg->hdr.size);
        return -1;
    }

    return 0;
}

246
static int process_message_reply(struct vhost_dev *dev,
247
                                 const VhostUserMsg *msg)
248
{
249
    VhostUserMsg msg_reply;
250

251
    if ((msg->hdr.flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
252 253 254 255
        return 0;
    }

    if (vhost_user_read(dev, &msg_reply) < 0) {
256 257 258
        return -1;
    }

259
    if (msg_reply.hdr.request != msg->hdr.request) {
260 261
        error_report("Received unexpected msg type."
                     "Expected %d received %d",
262
                     msg->hdr.request, msg_reply.hdr.request);
263 264 265
        return -1;
    }

266
    return msg_reply.payload.u64 ? -1 : 0;
267 268
}

269 270 271 272
/*
 * Tell whether @request is one of the requests that is sent only once per
 * device rather than once per vhost_dev (see vhost_user_write()).
 */
static bool vhost_user_one_time_request(VhostUserRequest request)
{
    return request == VHOST_USER_SET_OWNER ||
           request == VHOST_USER_RESET_OWNER ||
           request == VHOST_USER_SET_MEM_TABLE ||
           request == VHOST_USER_GET_QUEUE_NUM ||
           request == VHOST_USER_NET_SET_MTU;
}

/* most non-init callers ignore the error */
284 285 286
static int vhost_user_write(struct vhost_dev *dev, VhostUserMsg *msg,
                            int *fds, int fd_num)
{
287 288
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
289
    int ret, size = VHOST_USER_HDR_SIZE + msg->hdr.size;
290

291 292 293 294 295
    /*
     * For non-vring specific requests, like VHOST_USER_SET_MEM_TABLE,
     * we just need send it once in the first time. For later such
     * request, we just ignore it.
     */
296 297
    if (vhost_user_one_time_request(msg->hdr.request) && dev->vq_index != 0) {
        msg->hdr.flags &= ~VHOST_USER_NEED_REPLY_MASK;
298 299 300
        return 0;
    }

301
    if (qemu_chr_fe_set_msgfds(chr, fds, fd_num) < 0) {
302
        error_report("Failed to set msg fds.");
303 304
        return -1;
    }
305

306 307 308 309 310 311 312 313
    ret = qemu_chr_fe_write_all(chr, (const uint8_t *) msg, size);
    if (ret != size) {
        error_report("Failed to write msg."
                     " Wrote %d instead of %d.", ret, size);
        return -1;
    }

    return 0;
314 315
}

316 317
static int vhost_user_set_log_base(struct vhost_dev *dev, uint64_t base,
                                   struct vhost_log *log)
318
{
319 320 321 322 323
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    size_t fd_num = 0;
    bool shmfd = virtio_has_feature(dev->protocol_features,
                                    VHOST_USER_PROTOCOL_F_LOG_SHMFD);
    VhostUserMsg msg = {
324 325
        .hdr.request = VHOST_USER_SET_LOG_BASE,
        .hdr.flags = VHOST_USER_VERSION,
M
Michael S. Tsirkin 已提交
326
        .payload.log.mmap_size = log->size * sizeof(*(log->log)),
327
        .payload.log.mmap_offset = 0,
328
        .hdr.size = sizeof(msg.payload.log),
329 330 331 332 333 334
    };

    if (shmfd && log->fd != -1) {
        fds[fd_num++] = log->fd;
    }

335 336 337
    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }
338 339

    if (shmfd) {
340
        msg.hdr.size = 0;
341
        if (vhost_user_read(dev, &msg) < 0) {
342
            return -1;
343 344
        }

345
        if (msg.hdr.request != VHOST_USER_SET_LOG_BASE) {
346 347
            error_report("Received unexpected msg type. "
                         "Expected %d received %d",
348
                         VHOST_USER_SET_LOG_BASE, msg.hdr.request);
349 350
            return -1;
        }
351
    }
352 353

    return 0;
354 355
}

356 357 358
static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev,
                                             struct vhost_memory *mem)
{
359
    struct vhost_user *u = dev->opaque;
360 361 362 363 364
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    int i, fd;
    size_t fd_num = 0;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);
365 366 367
    VhostUserMsg msg_reply;
    int region_i, msg_i;

368 369 370 371 372 373 374 375 376
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_SET_MEM_TABLE,
        .hdr.flags = VHOST_USER_VERSION,
    };

    if (reply_supported) {
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
    }

377 378 379 380 381 382 383 384 385 386 387
    if (u->region_rb_len < dev->mem->nregions) {
        u->region_rb = g_renew(RAMBlock*, u->region_rb, dev->mem->nregions);
        u->region_rb_offset = g_renew(ram_addr_t, u->region_rb_offset,
                                      dev->mem->nregions);
        memset(&(u->region_rb[u->region_rb_len]), '\0',
               sizeof(RAMBlock *) * (dev->mem->nregions - u->region_rb_len));
        memset(&(u->region_rb_offset[u->region_rb_len]), '\0',
               sizeof(ram_addr_t) * (dev->mem->nregions - u->region_rb_len));
        u->region_rb_len = dev->mem->nregions;
    }

388 389 390 391 392 393 394 395 396 397
    for (i = 0; i < dev->mem->nregions; ++i) {
        struct vhost_memory_region *reg = dev->mem->regions + i;
        ram_addr_t offset;
        MemoryRegion *mr;

        assert((uintptr_t)reg->userspace_addr == reg->userspace_addr);
        mr = memory_region_from_host((void *)(uintptr_t)reg->userspace_addr,
                                     &offset);
        fd = memory_region_get_fd(mr);
        if (fd > 0) {
398 399 400 401 402 403
            trace_vhost_user_set_mem_table_withfd(fd_num, mr->name,
                                                  reg->memory_size,
                                                  reg->guest_phys_addr,
                                                  reg->userspace_addr, offset);
            u->region_rb_offset[i] = offset;
            u->region_rb[i] = mr->ram_block;
404 405 406 407 408 409 410 411
            msg.payload.memory.regions[fd_num].userspace_addr =
                reg->userspace_addr;
            msg.payload.memory.regions[fd_num].memory_size  = reg->memory_size;
            msg.payload.memory.regions[fd_num].guest_phys_addr =
                reg->guest_phys_addr;
            msg.payload.memory.regions[fd_num].mmap_offset = offset;
            assert(fd_num < VHOST_MEMORY_MAX_NREGIONS);
            fds[fd_num++] = fd;
412 413 414
        } else {
            u->region_rb_offset[i] = 0;
            u->region_rb[i] = NULL;
415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
        }
    }

    msg.payload.memory.nregions = fd_num;

    if (!fd_num) {
        error_report("Failed initializing vhost-user memory map, "
                     "consider using -object memory-backend-file share=on");
        return -1;
    }

    msg.hdr.size = sizeof(msg.payload.memory.nregions);
    msg.hdr.size += sizeof(msg.payload.memory.padding);
    msg.hdr.size += fd_num * sizeof(VhostUserMemoryRegion);

    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }

434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
    if (vhost_user_read(dev, &msg_reply) < 0) {
        return -1;
    }

    if (msg_reply.hdr.request != VHOST_USER_SET_MEM_TABLE) {
        error_report("%s: Received unexpected msg type."
                     "Expected %d received %d", __func__,
                     VHOST_USER_SET_MEM_TABLE, msg_reply.hdr.request);
        return -1;
    }
    /* We're using the same structure, just reusing one of the
     * fields, so it should be the same size.
     */
    if (msg_reply.hdr.size != msg.hdr.size) {
        error_report("%s: Unexpected size for postcopy reply "
                     "%d vs %d", __func__, msg_reply.hdr.size, msg.hdr.size);
        return -1;
    }

    memset(u->postcopy_client_bases, 0,
           sizeof(uint64_t) * VHOST_MEMORY_MAX_NREGIONS);

    /* They're in the same order as the regions that were sent
     * but some of the regions were skipped (above) if they
     * didn't have fd's
    */
    for (msg_i = 0, region_i = 0;
         region_i < dev->mem->nregions;
        region_i++) {
        if (msg_i < fd_num &&
            msg_reply.payload.memory.regions[msg_i].guest_phys_addr ==
            dev->mem->regions[region_i].guest_phys_addr) {
            u->postcopy_client_bases[region_i] =
                msg_reply.payload.memory.regions[msg_i].userspace_addr;
            trace_vhost_user_set_mem_table_postcopy(
                msg_reply.payload.memory.regions[msg_i].userspace_addr,
                msg.payload.memory.regions[msg_i].userspace_addr,
                msg_i, region_i);
            msg_i++;
        }
    }
    if (msg_i != fd_num) {
        error_report("%s: postcopy reply not fully consumed "
                     "%d vs %zd",
                     __func__, msg_i, fd_num);
        return -1;
    }
    /* Now we've registered this with the postcopy code, we ack to the client,
     * because now we're in the position to be able to deal with any faults
     * it generates.
     */
    /* TODO: Use this for failure cases as well with a bad value */
    msg.hdr.size = sizeof(msg.payload.u64);
    msg.payload.u64 = 0; /* OK */
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

492 493 494 495 496 497 498
    if (reply_supported) {
        return process_message_reply(dev, &msg);
    }

    return 0;
}

499 500 501
static int vhost_user_set_mem_table(struct vhost_dev *dev,
                                    struct vhost_memory *mem)
{
502
    struct vhost_user *u = dev->opaque;
503 504 505
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    int i, fd;
    size_t fd_num = 0;
506
    bool do_postcopy = u->postcopy_listen && u->postcopy_fd.handler;
507
    bool reply_supported = virtio_has_feature(dev->protocol_features,
508 509
                                          VHOST_USER_PROTOCOL_F_REPLY_ACK) &&
                                          !do_postcopy;
510

511 512 513 514 515 516 517
    if (do_postcopy) {
        /* Postcopy has enough differences that it's best done in it's own
         * version
         */
        return vhost_user_set_mem_table_postcopy(dev, mem);
    }

518
    VhostUserMsg msg = {
519 520
        .hdr.request = VHOST_USER_SET_MEM_TABLE,
        .hdr.flags = VHOST_USER_VERSION,
521 522 523
    };

    if (reply_supported) {
524
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
525 526 527 528 529 530 531 532 533 534 535 536
    }

    for (i = 0; i < dev->mem->nregions; ++i) {
        struct vhost_memory_region *reg = dev->mem->regions + i;
        ram_addr_t offset;
        MemoryRegion *mr;

        assert((uintptr_t)reg->userspace_addr == reg->userspace_addr);
        mr = memory_region_from_host((void *)(uintptr_t)reg->userspace_addr,
                                     &offset);
        fd = memory_region_get_fd(mr);
        if (fd > 0) {
537 538 539 540
            if (fd_num == VHOST_MEMORY_MAX_NREGIONS) {
                error_report("Failed preparing vhost-user memory table msg");
                return -1;
            }
541 542
            msg.payload.memory.regions[fd_num].userspace_addr =
                reg->userspace_addr;
543
            msg.payload.memory.regions[fd_num].memory_size  = reg->memory_size;
544 545
            msg.payload.memory.regions[fd_num].guest_phys_addr =
                reg->guest_phys_addr;
546 547 548 549 550 551 552 553 554 555 556 557 558
            msg.payload.memory.regions[fd_num].mmap_offset = offset;
            fds[fd_num++] = fd;
        }
    }

    msg.payload.memory.nregions = fd_num;

    if (!fd_num) {
        error_report("Failed initializing vhost-user memory map, "
                     "consider using -object memory-backend-file share=on");
        return -1;
    }

559 560 561
    msg.hdr.size = sizeof(msg.payload.memory.nregions);
    msg.hdr.size += sizeof(msg.payload.memory.padding);
    msg.hdr.size += fd_num * sizeof(VhostUserMemoryRegion);
562 563 564 565 566 567

    if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
        return -1;
    }

    if (reply_supported) {
568
        return process_message_reply(dev, &msg);
569 570 571 572 573
    }

    return 0;
}

574 575 576 577
static int vhost_user_set_vring_addr(struct vhost_dev *dev,
                                     struct vhost_vring_addr *addr)
{
    VhostUserMsg msg = {
578 579
        .hdr.request = VHOST_USER_SET_VRING_ADDR,
        .hdr.flags = VHOST_USER_VERSION,
580
        .payload.addr = *addr,
581
        .hdr.size = sizeof(msg.payload.addr),
582
    };
583

584 585 586
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
587

588 589
    return 0;
}
590

591 592 593
static int vhost_user_set_vring_endian(struct vhost_dev *dev,
                                       struct vhost_vring_state *ring)
{
594 595 596
    bool cross_endian = virtio_has_feature(dev->protocol_features,
                                           VHOST_USER_PROTOCOL_F_CROSS_ENDIAN);
    VhostUserMsg msg = {
597 598
        .hdr.request = VHOST_USER_SET_VRING_ENDIAN,
        .hdr.flags = VHOST_USER_VERSION,
599
        .payload.state = *ring,
600
        .hdr.size = sizeof(msg.payload.state),
601 602 603 604 605 606 607 608 609 610 611 612
    };

    if (!cross_endian) {
        error_report("vhost-user trying to send unhandled ioctl");
        return -1;
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    return 0;
613
}
614

615 616 617 618 619
static int vhost_set_vring(struct vhost_dev *dev,
                           unsigned long int request,
                           struct vhost_vring_state *ring)
{
    VhostUserMsg msg = {
620 621
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
622
        .payload.state = *ring,
623
        .hdr.size = sizeof(msg.payload.state),
624 625
    };

626 627 628
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646

    return 0;
}

/* Send @ring to the backend as a VHOST_USER_SET_VRING_NUM message. */
static int vhost_user_set_vring_num(struct vhost_dev *dev,
                                    struct vhost_vring_state *ring)
{
    int rc = vhost_set_vring(dev, VHOST_USER_SET_VRING_NUM, ring);

    return rc;
}

/* Send @ring to the backend as a VHOST_USER_SET_VRING_BASE message. */
static int vhost_user_set_vring_base(struct vhost_dev *dev,
                                     struct vhost_vring_state *ring)
{
    int rc = vhost_set_vring(dev, VHOST_USER_SET_VRING_BASE, ring);

    return rc;
}

static int vhost_user_set_vring_enable(struct vhost_dev *dev, int enable)
{
647
    int i;
648

649
    if (!virtio_has_feature(dev->features, VHOST_USER_F_PROTOCOL_FEATURES)) {
650 651 652
        return -1;
    }

653 654 655 656 657 658 659 660
    for (i = 0; i < dev->nvqs; ++i) {
        struct vhost_vring_state state = {
            .index = dev->vq_index + i,
            .num   = enable,
        };

        vhost_set_vring(dev, VHOST_USER_SET_VRING_ENABLE, &state);
    }
661

662 663
    return 0;
}
664 665 666 667 668

static int vhost_user_get_vring_base(struct vhost_dev *dev,
                                     struct vhost_vring_state *ring)
{
    VhostUserMsg msg = {
669 670
        .hdr.request = VHOST_USER_GET_VRING_BASE,
        .hdr.flags = VHOST_USER_VERSION,
671
        .payload.state = *ring,
672
        .hdr.size = sizeof(msg.payload.state),
673 674
    };

675 676 677
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
678 679

    if (vhost_user_read(dev, &msg) < 0) {
680
        return -1;
681 682
    }

683
    if (msg.hdr.request != VHOST_USER_GET_VRING_BASE) {
684
        error_report("Received unexpected msg type. Expected %d received %d",
685
                     VHOST_USER_GET_VRING_BASE, msg.hdr.request);
686 687
        return -1;
    }
688

689
    if (msg.hdr.size != sizeof(msg.payload.state)) {
690 691
        error_report("Received bad msg size.");
        return -1;
692 693
    }

694
    *ring = msg.payload.state;
695

696 697 698
    return 0;
}

699 700 701
/*
 * Send a vring file message (@request) for the vring named by @file->index.
 *
 * The descriptor in @file is passed with the message only when ioeventfd is
 * available and the fd is positive; otherwise the NOFD bit is set in the
 * payload.  Returns 0 on success, -1 if the write fails.
 */
static int vhost_set_vring_file(struct vhost_dev *dev,
                                VhostUserRequest request,
                                struct vhost_vring_file *file)
{
    int fds[VHOST_MEMORY_MAX_NREGIONS];
    size_t nfds = 0;
    VhostUserMsg msg = {
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = sizeof(msg.payload.u64),
        .payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK,
    };
    bool pass_fd = ioeventfd_enabled() && file->fd > 0;

    if (pass_fd) {
        fds[nfds] = file->fd;
        nfds++;
    } else {
        msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK;
    }

    return vhost_user_write(dev, &msg, fds, nfds) < 0 ? -1 : 0;
}
724

725 726 727 728 729 730 731 732 733 734 735 736 737 738 739
/* Send @file to the backend as a VHOST_USER_SET_VRING_KICK message. */
static int vhost_user_set_vring_kick(struct vhost_dev *dev,
                                     struct vhost_vring_file *file)
{
    int rc = vhost_set_vring_file(dev, VHOST_USER_SET_VRING_KICK, file);

    return rc;
}

/* Send @file to the backend as a VHOST_USER_SET_VRING_CALL message. */
static int vhost_user_set_vring_call(struct vhost_dev *dev,
                                     struct vhost_vring_file *file)
{
    int rc = vhost_set_vring_file(dev, VHOST_USER_SET_VRING_CALL, file);

    return rc;
}

static int vhost_user_set_u64(struct vhost_dev *dev, int request, uint64_t u64)
{
    VhostUserMsg msg = {
740 741
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
742
        .payload.u64 = u64,
743
        .hdr.size = sizeof(msg.payload.u64),
744 745
    };

746 747 748
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767

    return 0;
}

/* Send @features to the backend as a VHOST_USER_SET_FEATURES message. */
static int vhost_user_set_features(struct vhost_dev *dev,
                                   uint64_t features)
{
    int rc = vhost_user_set_u64(dev, VHOST_USER_SET_FEATURES, features);

    return rc;
}

/* Send @features as a VHOST_USER_SET_PROTOCOL_FEATURES message. */
static int vhost_user_set_protocol_features(struct vhost_dev *dev,
                                            uint64_t features)
{
    int rc = vhost_user_set_u64(dev, VHOST_USER_SET_PROTOCOL_FEATURES,
                                features);

    return rc;
}

static int vhost_user_get_u64(struct vhost_dev *dev, int request, uint64_t *u64)
{
    VhostUserMsg msg = {
768 769
        .hdr.request = request,
        .hdr.flags = VHOST_USER_VERSION,
770 771 772 773
    };

    if (vhost_user_one_time_request(request) && dev->vq_index != 0) {
        return 0;
774
    }
775

776 777 778
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
779 780

    if (vhost_user_read(dev, &msg) < 0) {
781
        return -1;
782 783
    }

784
    if (msg.hdr.request != request) {
785
        error_report("Received unexpected msg type. Expected %d received %d",
786
                     request, msg.hdr.request);
787 788 789
        return -1;
    }

790
    if (msg.hdr.size != sizeof(msg.payload.u64)) {
791 792 793 794
        error_report("Received bad msg size.");
        return -1;
    }

795
    *u64 = msg.payload.u64;
796 797 798 799 800 801 802 803 804 805 806 807

    return 0;
}

/* Fetch the backend's feature bits into *@features (GET_FEATURES). */
static int vhost_user_get_features(struct vhost_dev *dev, uint64_t *features)
{
    int rc = vhost_user_get_u64(dev, VHOST_USER_GET_FEATURES, features);

    return rc;
}

static int vhost_user_set_owner(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
808 809
        .hdr.request = VHOST_USER_SET_OWNER,
        .hdr.flags = VHOST_USER_VERSION,
810 811
    };

812 813 814
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
815 816 817 818 819 820 821

    return 0;
}

static int vhost_user_reset_device(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
822 823
        .hdr.request = VHOST_USER_RESET_OWNER,
        .hdr.flags = VHOST_USER_VERSION,
824 825
    };

826 827 828
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }
829

830 831 832
    return 0;
}

833 834 835 836 837 838 839 840 841 842 843 844 845 846 847
/*
 * Forward a config-change request from the backend to the device's
 * vhost_dev_config_notifier callback.
 *
 * Returns the callback's result, or -1 when the device has no config_ops or
 * no notifier registered.
 */
static int vhost_user_slave_handle_config_change(struct vhost_dev *dev)
{
    if (!dev->config_ops || !dev->config_ops->vhost_dev_config_notifier) {
        return -1;
    }

    return dev->config_ops->vhost_dev_config_notifier(dev);
}

848 849 850 851
static void slave_read(void *opaque)
{
    struct vhost_dev *dev = opaque;
    struct vhost_user *u = dev->opaque;
852 853
    VhostUserHeader hdr = { 0, };
    VhostUserPayload payload = { 0, };
854 855 856
    int size, ret = 0;

    /* Read header */
857
    size = read(u->slave_fd, &hdr, VHOST_USER_HDR_SIZE);
858 859 860 861 862
    if (size != VHOST_USER_HDR_SIZE) {
        error_report("Failed to read from slave.");
        goto err;
    }

863
    if (hdr.size > VHOST_USER_PAYLOAD_SIZE) {
864
        error_report("Failed to read msg header."
865
                " Size %d exceeds the maximum %zu.", hdr.size,
866 867 868 869 870
                VHOST_USER_PAYLOAD_SIZE);
        goto err;
    }

    /* Read payload */
871 872
    size = read(u->slave_fd, &payload, hdr.size);
    if (size != hdr.size) {
873 874 875 876
        error_report("Failed to read payload from slave.");
        goto err;
    }

877
    switch (hdr.request) {
878
    case VHOST_USER_SLAVE_IOTLB_MSG:
879
        ret = vhost_backend_handle_iotlb_msg(dev, &payload.iotlb);
880
        break;
881 882 883
    case VHOST_USER_SLAVE_CONFIG_CHANGE_MSG :
        ret = vhost_user_slave_handle_config_change(dev);
        break;
884 885 886 887 888 889 890 891 892
    default:
        error_report("Received unexpected msg type.");
        ret = -EINVAL;
    }

    /*
     * REPLY_ACK feature handling. Other reply types has to be managed
     * directly in their request handlers.
     */
893 894
    if (hdr.flags & VHOST_USER_NEED_REPLY_MASK) {
        struct iovec iovec[2];
895 896


897 898 899 900 901 902 903 904 905 906 907 908 909
        hdr.flags &= ~VHOST_USER_NEED_REPLY_MASK;
        hdr.flags |= VHOST_USER_REPLY_MASK;

        payload.u64 = !!ret;
        hdr.size = sizeof(payload.u64);

        iovec[0].iov_base = &hdr;
        iovec[0].iov_len = VHOST_USER_HDR_SIZE;
        iovec[1].iov_base = &payload;
        iovec[1].iov_len = hdr.size;

        size = writev(u->slave_fd, iovec, ARRAY_SIZE(iovec));
        if (size != VHOST_USER_HDR_SIZE + hdr.size) {
910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926
            error_report("Failed to send msg reply to slave.");
            goto err;
        }
    }

    return;

err:
    qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
    close(u->slave_fd);
    u->slave_fd = -1;
    return;
}

static int vhost_setup_slave_channel(struct vhost_dev *dev)
{
    VhostUserMsg msg = {
927 928
        .hdr.request = VHOST_USER_SET_SLAVE_REQ_FD,
        .hdr.flags = VHOST_USER_VERSION,
929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948
    };
    struct vhost_user *u = dev->opaque;
    int sv[2], ret = 0;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    if (!virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
        return 0;
    }

    if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) == -1) {
        error_report("socketpair() failed");
        return -1;
    }

    u->slave_fd = sv[0];
    qemu_set_fd_handler(u->slave_fd, slave_read, NULL, dev);

    if (reply_supported) {
949
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971
    }

    ret = vhost_user_write(dev, &msg, &sv[1], 1);
    if (ret) {
        goto out;
    }

    if (reply_supported) {
        ret = process_message_reply(dev, &msg);
    }

out:
    close(sv[1]);
    if (ret) {
        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
        close(u->slave_fd);
        u->slave_fd = -1;
    }

    return ret;
}

972 973 974 975 976 977 978 979
/*
 * Called back from the postcopy fault thread when a fault is received on our
 * ufd.
 * TODO: This is Linux specific
 */
static int vhost_user_postcopy_fault_handler(struct PostCopyFD *pcfd,
                                             void *ufd)
{
980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008
    struct vhost_dev *dev = pcfd->data;
    struct vhost_user *u = dev->opaque;
    struct uffd_msg *msg = ufd;
    uint64_t faultaddr = msg->arg.pagefault.address;
    RAMBlock *rb = NULL;
    uint64_t rb_offset;
    int i;

    trace_vhost_user_postcopy_fault_handler(pcfd->idstr, faultaddr,
                                            dev->mem->nregions);
    for (i = 0; i < MIN(dev->mem->nregions, u->region_rb_len); i++) {
        trace_vhost_user_postcopy_fault_handler_loop(i,
                u->postcopy_client_bases[i], dev->mem->regions[i].memory_size);
        if (faultaddr >= u->postcopy_client_bases[i]) {
            /* Ofset of the fault address in the vhost region */
            uint64_t region_offset = faultaddr - u->postcopy_client_bases[i];
            if (region_offset < dev->mem->regions[i].memory_size) {
                rb_offset = region_offset + u->region_rb_offset[i];
                trace_vhost_user_postcopy_fault_handler_found(i,
                        region_offset, rb_offset);
                rb = u->region_rb[i];
                return postcopy_request_shared_page(pcfd, rb, faultaddr,
                                                    rb_offset);
            }
        }
    }
    error_report("%s: Failed to find region for fault %" PRIx64,
                 __func__, faultaddr);
    return -1;
1009 1010
}

1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039
/*
 * Waker callback installed on our PostCopyFD.
 *
 * Given an offset into RAMBlock rb, find the vhost region backed by that
 * part of the block, convert the offset into the client's address space and
 * pass it to postcopy_wake_shared(), whose result is returned.
 * Returns 0 when there is no vhost-user state or no region matches.
 */
static int vhost_user_postcopy_waker(struct PostCopyFD *pcfd, RAMBlock *rb,
                                     uint64_t offset)
{
    struct vhost_dev *dev = pcfd->data;
    struct vhost_user *u = dev->opaque;
    int i;

    trace_vhost_user_postcopy_waker(qemu_ram_get_idstr(rb), offset);

    if (!u) {
        return 0;
    }

    /* Translate the offset into an address in the clients address space */
    for (i = 0; i < MIN(dev->mem->nregions, u->region_rb_len); i++) {
        uint64_t client_addr;

        if (u->region_rb[i] != rb) {
            continue;
        }
        if (offset < u->region_rb_offset[i]) {
            continue;
        }
        if (offset >= (u->region_rb_offset[i] +
                       dev->mem->regions[i].memory_size)) {
            continue;
        }

        client_addr = (offset - u->region_rb_offset[i]) +
                      u->postcopy_client_bases[i];
        trace_vhost_user_postcopy_waker_found(client_addr);
        return postcopy_wake_shared(pcfd, client_addr, rb);
    }

    trace_vhost_user_postcopy_waker_nomatch(qemu_ram_get_idstr(rb), offset);
    return 0;
}

1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078
/*
 * Called at the start of an inbound postcopy on reception of the
 * 'advise' command.
 *
 * Sends VHOST_USER_POSTCOPY_ADVISE, validates the reply header, fetches the
 * userfault fd that accompanies the reply, makes it non-blocking and
 * registers it with the postcopy code.
 *
 * Returns 0 on success; on failure sets *errp and returns -1.
 */
static int vhost_user_postcopy_advise(struct vhost_dev *dev, Error **errp)
{
    struct vhost_user *u = dev->opaque;
    CharBackend *chr = u->chr;
    int ufd;
    int fl;
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_POSTCOPY_ADVISE,
        .hdr.flags = VHOST_USER_VERSION,
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_setg(errp, "Failed to send postcopy_advise to vhost");
        return -1;
    }

    if (vhost_user_read(dev, &msg) < 0) {
        error_setg(errp, "Failed to get postcopy_advise reply from vhost");
        return -1;
    }

    if (msg.hdr.request != VHOST_USER_POSTCOPY_ADVISE) {
        error_setg(errp, "Unexpected msg type. Expected %d received %d",
                   VHOST_USER_POSTCOPY_ADVISE, msg.hdr.request);
        return -1;
    }

    if (msg.hdr.size) {
        error_setg(errp, "Received bad msg size.");
        return -1;
    }

    ufd = qemu_chr_fe_get_msgfd(chr);
    if (ufd < 0) {
        error_setg(errp, "%s: Failed to get ufd", __func__);
        return -1;
    }

    /*
     * Add O_NONBLOCK on top of the flags already set rather than replacing
     * them, and do not carry on with an fd we could not configure.
     */
    fl = fcntl(ufd, F_GETFL);
    if (fl == -1 || fcntl(ufd, F_SETFL, fl | O_NONBLOCK) == -1) {
        error_setg(errp, "%s: Failed to make ufd non-blocking", __func__);
        close(ufd);
        return -1;
    }

    /* register ufd with userfault thread */
    u->postcopy_fd.fd = ufd;
    u->postcopy_fd.data = dev;
    u->postcopy_fd.handler = vhost_user_postcopy_fault_handler;
    u->postcopy_fd.waker = vhost_user_postcopy_waker;
    u->postcopy_fd.idstr = "vhost-user"; /* Need to find unique name */
    postcopy_register_shared_ufd(&u->postcopy_fd);
    return 0;
}

1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117
/*
 * Called at the switch to postcopy on reception of the 'listen' command.
 *
 * Marks the device as listening, sends VHOST_USER_POSTCOPY_LISTEN and waits
 * for the backend's reply.  Returns 0 on success; on failure sets *errp and
 * returns a non-zero value.
 */
static int vhost_user_postcopy_listen(struct vhost_dev *dev, Error **errp)
{
    struct vhost_user *u = dev->opaque;
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_POSTCOPY_LISTEN,
        .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
    };
    int rc;

    /* The flag is raised before the request goes out. */
    u->postcopy_listen = true;
    trace_vhost_user_postcopy_listen();

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_setg(errp, "Failed to send postcopy_listen to vhost");
        return -1;
    }

    rc = process_message_reply(dev, &msg);
    if (rc == 0) {
        return 0;
    }

    error_setg(errp, "Failed to receive reply to postcopy_listen");
    return rc;
}

1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148
/*
 * Called at the end of postcopy
 *
 * Sends VHOST_USER_POSTCOPY_END, waits for the backend's reply, then
 * unregisters the shared userfault fd and closes it.
 *
 * Returns 0 on success; on failure sets *errp and returns a non-zero value
 * (the userfault fd is left registered in that case).
 */
static int vhost_user_postcopy_end(struct vhost_dev *dev, Error **errp)
{
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_POSTCOPY_END,
        .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
    };
    int ret;
    struct vhost_user *u = dev->opaque;

    trace_vhost_user_postcopy_end_entry();
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_setg(errp, "Failed to send postcopy_end to vhost");
        return -1;
    }

    ret = process_message_reply(dev, &msg);
    if (ret) {
        error_setg(errp, "Failed to receive reply to postcopy_end");
        return ret;
    }
    postcopy_unregister_shared_ufd(&u->postcopy_fd);

    /*
     * postcopy_fd.fd is the descriptor fetched from the chardev when the
     * 'advise' reply was handled; nothing else releases it, so close it
     * here.  A non-NULL handler is what marks the fd as having been set up,
     * which keeps us from closing a descriptor we never received.
     */
    if (u->postcopy_fd.handler) {
        close(u->postcopy_fd.fd);
    }
    u->postcopy_fd.handler = NULL;

    trace_vhost_user_postcopy_end_exit();

    return 0;
}

1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167
/*
 * Postcopy notifier entry point.
 *
 * Recovers the vhost_user state from the embedded notifier and dispatches
 * on the notification reason.  PROBE fails with -ENOENT when the backend
 * did not negotiate VHOST_USER_PROTOCOL_F_PAGEFAULT; ADVISE/LISTEN/END are
 * forwarded to the matching handler; any other reason is ignored and
 * reported as success (0).
 */
static int vhost_user_postcopy_notifier(NotifierWithReturn *notifier,
                                        void *opaque)
{
    struct vhost_user *u = container_of(notifier, struct vhost_user,
                                         postcopy_notifier);
    struct vhost_dev *dev = u->dev;
    struct PostcopyNotifyData *pnd = opaque;

    switch (pnd->reason) {
    case POSTCOPY_NOTIFY_INBOUND_ADVISE:
        return vhost_user_postcopy_advise(dev, pnd->errp);

    case POSTCOPY_NOTIFY_INBOUND_LISTEN:
        return vhost_user_postcopy_listen(dev, pnd->errp);

    case POSTCOPY_NOTIFY_INBOUND_END:
        return vhost_user_postcopy_end(dev, pnd->errp);

    case POSTCOPY_NOTIFY_PROBE:
        if (virtio_has_feature(dev->protocol_features,
                               VHOST_USER_PROTOCOL_F_PAGEFAULT)) {
            break;
        }
        /* TODO: Get the device name into this error somehow */
        error_setg(pnd->errp,
                   "vhost-user backend not capable of postcopy");
        return -ENOENT;

    default:
        /* We ignore notifications we don't know */
        break;
    }

    return 0;
}

1185 1186
static int vhost_user_init(struct vhost_dev *dev, void *opaque)
{
1187
    uint64_t features, protocol_features;
1188
    struct vhost_user *u;
1189 1190
    int err;

1191 1192
    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

1193 1194
    u = g_new0(struct vhost_user, 1);
    u->chr = opaque;
1195
    u->slave_fd = -1;
1196
    u->dev = dev;
1197
    dev->opaque = u;
1198

1199
    err = vhost_user_get_features(dev, &features);
1200 1201 1202 1203 1204 1205 1206
    if (err < 0) {
        return err;
    }

    if (virtio_has_feature(features, VHOST_USER_F_PROTOCOL_FEATURES)) {
        dev->backend_features |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;

1207
        err = vhost_user_get_u64(dev, VHOST_USER_GET_PROTOCOL_FEATURES,
1208
                                 &protocol_features);
1209 1210 1211 1212
        if (err < 0) {
            return err;
        }

1213 1214
        dev->protocol_features =
            protocol_features & VHOST_USER_PROTOCOL_FEATURE_MASK;
1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225

        if (!dev->config_ops || !dev->config_ops->vhost_dev_config_notifier) {
            /* Don't acknowledge CONFIG feature if device doesn't support it */
            dev->protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_CONFIG);
        } else if (!(protocol_features &
                    (1ULL << VHOST_USER_PROTOCOL_F_CONFIG))) {
            error_report("Device expects VHOST_USER_PROTOCOL_F_CONFIG "
                    "but backend does not support it.");
            return -1;
        }

1226
        err = vhost_user_set_protocol_features(dev, dev->protocol_features);
1227 1228 1229
        if (err < 0) {
            return err;
        }
1230 1231 1232

        /* query the max queues we support if backend supports Multiple Queue */
        if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_MQ)) {
1233 1234
            err = vhost_user_get_u64(dev, VHOST_USER_GET_QUEUE_NUM,
                                     &dev->max_queues);
1235 1236 1237 1238
            if (err < 0) {
                return err;
            }
        }
1239 1240 1241 1242 1243 1244 1245 1246 1247 1248

        if (virtio_has_feature(features, VIRTIO_F_IOMMU_PLATFORM) &&
                !(virtio_has_feature(dev->protocol_features,
                    VHOST_USER_PROTOCOL_F_SLAVE_REQ) &&
                 virtio_has_feature(dev->protocol_features,
                    VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
            error_report("IOMMU support requires reply-ack and "
                         "slave-req protocol features.");
            return -1;
        }
1249 1250
    }

1251 1252 1253 1254 1255 1256 1257 1258
    if (dev->migration_blocker == NULL &&
        !virtio_has_feature(dev->protocol_features,
                            VHOST_USER_PROTOCOL_F_LOG_SHMFD)) {
        error_setg(&dev->migration_blocker,
                   "Migration disabled: vhost-user backend lacks "
                   "VHOST_USER_PROTOCOL_F_LOG_SHMFD feature.");
    }

1259 1260 1261 1262 1263
    err = vhost_setup_slave_channel(dev);
    if (err < 0) {
        return err;
    }

1264 1265 1266
    u->postcopy_notifier.notify = vhost_user_postcopy_notifier;
    postcopy_add_notifier(&u->postcopy_notifier);

1267 1268 1269 1270 1271
    return 0;
}

static int vhost_user_cleanup(struct vhost_dev *dev)
{
1272 1273
    struct vhost_user *u;

1274 1275
    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

1276
    u = dev->opaque;
1277 1278 1279 1280
    if (u->postcopy_notifier.notify) {
        postcopy_remove_notifier(&u->postcopy_notifier);
        u->postcopy_notifier.notify = NULL;
    }
1281
    if (u->slave_fd >= 0) {
1282
        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
1283 1284 1285
        close(u->slave_fd);
        u->slave_fd = -1;
    }
1286 1287 1288 1289 1290
    g_free(u->region_rb);
    u->region_rb = NULL;
    g_free(u->region_rb_offset);
    u->region_rb_offset = NULL;
    u->region_rb_len = 0;
1291
    g_free(u);
1292 1293 1294 1295 1296
    dev->opaque = 0;

    return 0;
}

1297 1298 1299 1300 1301 1302 1303
/*
 * vhost_get_vq_index hook: vhost-user uses the index unchanged.
 *
 * idx must lie in [dev->vq_index, dev->vq_index + dev->nvqs); this is
 * asserted.  Returns idx.
 */
static int vhost_user_get_vq_index(struct vhost_dev *dev, int idx)
{
    /* Reject indexes outside this device's window of virtqueues. */
    assert(idx >= dev->vq_index && idx < dev->vq_index + dev->nvqs);

    return idx;
}

1304 1305 1306 1307 1308
/*
 * vhost_backend_memslots_limit hook.
 *
 * The limit is the fixed constant VHOST_MEMORY_MAX_NREGIONS; dev is unused.
 */
static int vhost_user_memslots_limit(struct vhost_dev *dev)
{
    const int limit = VHOST_MEMORY_MAX_NREGIONS;

    return limit;
}

1309 1310 1311 1312 1313 1314 1315 1316
/*
 * vhost_requires_shm_log hook.
 *
 * Returns true when VHOST_USER_PROTOCOL_F_LOG_SHMFD is among the device's
 * protocol features.
 */
static bool vhost_user_requires_shm_log(struct vhost_dev *dev)
{
    bool has_log_shmfd;

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    has_log_shmfd = virtio_has_feature(dev->protocol_features,
                                       VHOST_USER_PROTOCOL_F_LOG_SHMFD);
    return has_log_shmfd;
}

1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330
static int vhost_user_migration_done(struct vhost_dev *dev, char* mac_addr)
{
    VhostUserMsg msg = { 0 };

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    /* If guest supports GUEST_ANNOUNCE do nothing */
    if (virtio_has_feature(dev->acked_features, VIRTIO_NET_F_GUEST_ANNOUNCE)) {
        return 0;
    }

    /* if backend supports VHOST_USER_PROTOCOL_F_RARP ask it to send the RARP */
    if (virtio_has_feature(dev->protocol_features,
                           VHOST_USER_PROTOCOL_F_RARP)) {
1331 1332
        msg.hdr.request = VHOST_USER_SEND_RARP;
        msg.hdr.flags = VHOST_USER_VERSION;
1333
        memcpy((char *)&msg.payload.u64, mac_addr, 6);
1334
        msg.hdr.size = sizeof(msg.payload.u64);
1335

1336
        return vhost_user_write(dev, &msg, NULL, 0);
1337 1338 1339 1340
    }
    return -1;
}

1341 1342 1343 1344
static bool vhost_user_can_merge(struct vhost_dev *dev,
                                 uint64_t start1, uint64_t size1,
                                 uint64_t start2, uint64_t size2)
{
1345
    ram_addr_t offset;
1346 1347 1348
    int mfd, rfd;
    MemoryRegion *mr;

1349
    mr = memory_region_from_host((void *)(uintptr_t)start1, &offset);
1350
    mfd = memory_region_get_fd(mr);
1351

1352
    mr = memory_region_from_host((void *)(uintptr_t)start2, &offset);
1353
    rfd = memory_region_get_fd(mr);
1354 1355 1356 1357

    return mfd == rfd;
}

1358 1359 1360 1361 1362 1363 1364 1365 1366 1367
static int vhost_user_net_set_mtu(struct vhost_dev *dev, uint16_t mtu)
{
    VhostUserMsg msg;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    if (!(dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_NET_MTU))) {
        return 0;
    }

1368
    msg.hdr.request = VHOST_USER_NET_SET_MTU;
1369
    msg.payload.u64 = mtu;
1370 1371
    msg.hdr.size = sizeof(msg.payload.u64);
    msg.hdr.flags = VHOST_USER_VERSION;
1372
    if (reply_supported) {
1373
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
1374 1375 1376 1377 1378 1379 1380 1381
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    /* If reply_ack supported, slave has to ack specified MTU is valid */
    if (reply_supported) {
1382
        return process_message_reply(dev, &msg);
1383 1384 1385 1386 1387
    }

    return 0;
}

1388 1389 1390 1391
static int vhost_user_send_device_iotlb_msg(struct vhost_dev *dev,
                                            struct vhost_iotlb_msg *imsg)
{
    VhostUserMsg msg = {
1392 1393 1394
        .hdr.request = VHOST_USER_IOTLB_MSG,
        .hdr.size = sizeof(msg.payload.iotlb),
        .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410
        .payload.iotlb = *imsg,
    };

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -EFAULT;
    }

    return process_message_reply(dev, &msg);
}


/*
 * vhost_set_iotlb_callback hook for vhost-user; intentionally does nothing.
 * Neither dev nor enabled is used.
 */
static void vhost_user_set_iotlb_callback(struct vhost_dev *dev, int enabled)
{
    /* No-op as the receive channel is not dedicated to IOTLB messages. */
}

1411 1412 1413 1414
static int vhost_user_get_config(struct vhost_dev *dev, uint8_t *config,
                                 uint32_t config_len)
{
    VhostUserMsg msg = {
1415 1416 1417
        .hdr.request = VHOST_USER_GET_CONFIG,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = VHOST_USER_CONFIG_HDR_SIZE + config_len,
1418 1419
    };

1420 1421 1422 1423 1424
    if (!virtio_has_feature(dev->protocol_features,
                VHOST_USER_PROTOCOL_F_CONFIG)) {
        return -1;
    }

1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438
    if (config_len > VHOST_USER_MAX_CONFIG_SIZE) {
        return -1;
    }

    msg.payload.config.offset = 0;
    msg.payload.config.size = config_len;
    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    if (vhost_user_read(dev, &msg) < 0) {
        return -1;
    }

1439
    if (msg.hdr.request != VHOST_USER_GET_CONFIG) {
1440
        error_report("Received unexpected msg type. Expected %d received %d",
1441
                     VHOST_USER_GET_CONFIG, msg.hdr.request);
1442 1443 1444
        return -1;
    }

1445
    if (msg.hdr.size != VHOST_USER_CONFIG_HDR_SIZE + config_len) {
1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462
        error_report("Received bad msg size.");
        return -1;
    }

    memcpy(config, msg.payload.config.region, config_len);

    return 0;
}

static int vhost_user_set_config(struct vhost_dev *dev, const uint8_t *data,
                                 uint32_t offset, uint32_t size, uint32_t flags)
{
    uint8_t *p;
    bool reply_supported = virtio_has_feature(dev->protocol_features,
                                              VHOST_USER_PROTOCOL_F_REPLY_ACK);

    VhostUserMsg msg = {
1463 1464 1465
        .hdr.request = VHOST_USER_SET_CONFIG,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = VHOST_USER_CONFIG_HDR_SIZE + size,
1466 1467
    };

1468 1469 1470 1471 1472
    if (!virtio_has_feature(dev->protocol_features,
                VHOST_USER_PROTOCOL_F_CONFIG)) {
        return -1;
    }

1473
    if (reply_supported) {
1474
        msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497
    }

    if (size > VHOST_USER_MAX_CONFIG_SIZE) {
        return -1;
    }

    msg.payload.config.offset = offset,
    msg.payload.config.size = size,
    msg.payload.config.flags = flags,
    p = msg.payload.config.region;
    memcpy(p, data, size);

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        return -1;
    }

    if (reply_supported) {
        return process_message_reply(dev, &msg);
    }

    return 0;
}

1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583
/*
 * vhost_crypto_create_session hook.
 *
 * Packs the session description (and the cipher/auth keys, when their
 * lengths are non-zero) into a VHOST_USER_CREATE_CRYPTO_SESSION request,
 * sends it and reads the reply.  On success the session id from the reply
 * is stored in *session_id and 0 is returned.
 *
 * Returns -1 when VHOST_USER_PROTOCOL_F_CRYPTO_SESSION is not negotiated,
 * the exchange fails, the reply has an unexpected type or size, or the
 * returned session id is negative.
 */
static int vhost_user_crypto_create_session(struct vhost_dev *dev,
                                            void *session_info,
                                            uint64_t *session_id)
{
    CryptoDevBackendSymSessionInfo *sess_info = session_info;
    bool have_crypto = virtio_has_feature(dev->protocol_features,
                                       VHOST_USER_PROTOCOL_F_CRYPTO_SESSION);
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_CREATE_CRYPTO_SESSION,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = sizeof(msg.payload.session),
    };

    assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);

    if (!have_crypto) {
        error_report("vhost-user trying to send unhandled ioctl");
        return -1;
    }

    /* Session description first, then whichever keys are present. */
    memcpy(&msg.payload.session.session_setup_data, sess_info,
           sizeof(CryptoDevBackendSymSessionInfo));

    if (sess_info->key_len) {
        memcpy(&msg.payload.session.key, sess_info->cipher_key,
               sess_info->key_len);
    }

    if (sess_info->auth_key_len > 0) {
        memcpy(&msg.payload.session.auth_key, sess_info->auth_key,
               sess_info->auth_key_len);
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_report("vhost_user_write() return -1, create session failed");
        return -1;
    }

    if (vhost_user_read(dev, &msg) < 0) {
        error_report("vhost_user_read() return -1, create session failed");
        return -1;
    }

    /* Validate the reply header before trusting its payload. */
    if (msg.hdr.request != VHOST_USER_CREATE_CRYPTO_SESSION) {
        error_report("Received unexpected msg type. Expected %d received %d",
                     VHOST_USER_CREATE_CRYPTO_SESSION, msg.hdr.request);
        return -1;
    }

    if (msg.hdr.size != sizeof(msg.payload.session)) {
        error_report("Received bad msg size.");
        return -1;
    }

    if (msg.payload.session.session_id < 0) {
        error_report("Bad session id: %" PRId64 "",
                     msg.payload.session.session_id);
        return -1;
    }

    *session_id = msg.payload.session.session_id;
    return 0;
}

/*
 * vhost_crypto_close_session hook.
 *
 * Sends VHOST_USER_CLOSE_CRYPTO_SESSION with session_id in the u64 payload;
 * no reply is requested.  Returns 0 on success and -1 when
 * VHOST_USER_PROTOCOL_F_CRYPTO_SESSION is not negotiated or the write fails.
 */
static int
vhost_user_crypto_close_session(struct vhost_dev *dev, uint64_t session_id)
{
    bool have_crypto = virtio_has_feature(dev->protocol_features,
                                       VHOST_USER_PROTOCOL_F_CRYPTO_SESSION);
    VhostUserMsg msg = {
        .hdr.request = VHOST_USER_CLOSE_CRYPTO_SESSION,
        .hdr.flags = VHOST_USER_VERSION,
        .hdr.size = sizeof(msg.payload.u64),
        .payload.u64 = session_id,
    };

    if (!have_crypto) {
        error_report("vhost-user trying to send unhandled ioctl");
        return -1;
    }

    if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
        error_report("vhost_user_write() return -1, close session failed");
        return -1;
    }

    return 0;
}

1584 1585 1586
const VhostOps user_ops = {
        .backend_type = VHOST_BACKEND_TYPE_USER,
        .vhost_backend_init = vhost_user_init,
1587
        .vhost_backend_cleanup = vhost_user_cleanup,
1588
        .vhost_backend_memslots_limit = vhost_user_memslots_limit,
1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603
        .vhost_set_log_base = vhost_user_set_log_base,
        .vhost_set_mem_table = vhost_user_set_mem_table,
        .vhost_set_vring_addr = vhost_user_set_vring_addr,
        .vhost_set_vring_endian = vhost_user_set_vring_endian,
        .vhost_set_vring_num = vhost_user_set_vring_num,
        .vhost_set_vring_base = vhost_user_set_vring_base,
        .vhost_get_vring_base = vhost_user_get_vring_base,
        .vhost_set_vring_kick = vhost_user_set_vring_kick,
        .vhost_set_vring_call = vhost_user_set_vring_call,
        .vhost_set_features = vhost_user_set_features,
        .vhost_get_features = vhost_user_get_features,
        .vhost_set_owner = vhost_user_set_owner,
        .vhost_reset_device = vhost_user_reset_device,
        .vhost_get_vq_index = vhost_user_get_vq_index,
        .vhost_set_vring_enable = vhost_user_set_vring_enable,
1604
        .vhost_requires_shm_log = vhost_user_requires_shm_log,
1605
        .vhost_migration_done = vhost_user_migration_done,
1606
        .vhost_backend_can_merge = vhost_user_can_merge,
1607
        .vhost_net_set_mtu = vhost_user_net_set_mtu,
1608 1609
        .vhost_set_iotlb_callback = vhost_user_set_iotlb_callback,
        .vhost_send_device_iotlb_msg = vhost_user_send_device_iotlb_msg,
1610 1611
        .vhost_get_config = vhost_user_get_config,
        .vhost_set_config = vhost_user_set_config,
1612 1613
        .vhost_crypto_create_session = vhost_user_crypto_create_session,
        .vhost_crypto_close_session = vhost_user_crypto_close_session,
1614
};