diff --git a/nbd.c b/nbd.c index ca18c10a1918a494cc292d4e9cb393e87d67e9d1..6d7d1f8d596963b3ade99d4ad5a3fee4088bd7d6 100644 --- a/nbd.c +++ b/nbd.c @@ -20,6 +20,8 @@ #include "block.h" #include "block_int.h" +#include "qemu-coroutine.h" + #include #include #ifndef _WIN32 @@ -607,6 +609,11 @@ struct NBDClient { NBDExport *exp; int sock; + + Coroutine *recv_coroutine; + + CoMutex send_lock; + Coroutine *send_coroutine; }; static void nbd_client_get(NBDClient *client) @@ -681,13 +688,20 @@ void nbd_export_close(NBDExport *exp) g_free(exp); } -static int nbd_do_send_reply(NBDRequest *req, struct nbd_reply *reply, +static void nbd_read(void *opaque); +static void nbd_restart_write(void *opaque); + +static int nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply, int len) { NBDClient *client = req->client; int csock = client->sock; int rc, ret; + qemu_co_mutex_lock(&client->send_lock); + qemu_set_fd_handler2(csock, NULL, nbd_read, nbd_restart_write, client); + client->send_coroutine = qemu_coroutine_self(); + if (!len) { rc = nbd_send_reply(csock, reply); if (rc == -1) { @@ -697,7 +711,7 @@ static int nbd_do_send_reply(NBDRequest *req, struct nbd_reply *reply, socket_set_cork(csock, 1); rc = nbd_send_reply(csock, reply); if (rc != -1) { - ret = write_sync(csock, req->data, len); + ret = qemu_co_send(csock, req->data, len); if (ret != len) { errno = EIO; rc = -1; @@ -708,15 +722,20 @@ static int nbd_do_send_reply(NBDRequest *req, struct nbd_reply *reply, } socket_set_cork(csock, 0); } + + client->send_coroutine = NULL; + qemu_set_fd_handler2(csock, NULL, nbd_read, NULL, client); + qemu_co_mutex_unlock(&client->send_lock); return rc; } -static int nbd_do_receive_request(NBDRequest *req, struct nbd_request *request) +static int nbd_co_receive_request(NBDRequest *req, struct nbd_request *request) { NBDClient *client = req->client; int csock = client->sock; int rc; + client->recv_coroutine = qemu_coroutine_self(); if (nbd_receive_request(csock, request) == -1) { rc = -EIO; goto out; @@ -741,7 +760,7 @@ static int nbd_do_receive_request(NBDRequest *req, struct nbd_request *request) if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) { TRACE("Reading %u byte(s)", request->len); - if (read_sync(csock, req->data, request->len) != request->len) { + if (qemu_co_recv(csock, req->data, request->len) != request->len) { LOG("reading from socket failed"); rc = -EIO; goto out; @@ -750,21 +769,22 @@ static int nbd_do_receive_request(NBDRequest *req, struct nbd_request *request) rc = 0; out: + client->recv_coroutine = NULL; return rc; } -static int nbd_trip(NBDClient *client) +static void nbd_trip(void *opaque) { + NBDClient *client = opaque; NBDRequest *req = nbd_request_get(client); NBDExport *exp = client->exp; struct nbd_request request; struct nbd_reply reply; - int rc = -1; int ret; TRACE("Reading request."); - ret = nbd_do_receive_request(req, &request); + ret = nbd_co_receive_request(req, &request); if (ret == -EIO) { goto out; } @@ -799,7 +819,7 @@ static int nbd_trip(NBDClient *client) } TRACE("Read %u byte(s)", request.len); - if (nbd_do_send_reply(req, &reply, request.len) < 0) + if (nbd_co_send_reply(req, &reply, request.len) < 0) goto out; break; case NBD_CMD_WRITE: @@ -822,7 +842,7 @@ static int nbd_trip(NBDClient *client) } if (request.type & NBD_CMD_FLAG_FUA) { - ret = bdrv_flush(exp->bs); + ret = bdrv_co_flush(exp->bs); if (ret < 0) { LOG("flush failed"); reply.error = -ret; @@ -830,34 +850,34 @@ static int nbd_trip(NBDClient *client) } } - if (nbd_do_send_reply(req, &reply, 0) < 0) + if (nbd_co_send_reply(req, &reply, 0) < 0) goto out; break; case NBD_CMD_DISC: TRACE("Request type is DISCONNECT"); errno = 0; - return 1; + goto out; case NBD_CMD_FLUSH: TRACE("Request type is FLUSH"); - ret = bdrv_flush(exp->bs); + ret = bdrv_co_flush(exp->bs); if (ret < 0) { LOG("flush failed"); reply.error = -ret; } - if (nbd_do_send_reply(req, &reply, 0) < 0) + if (nbd_co_send_reply(req, &reply, 0) < 0) goto out; break; case NBD_CMD_TRIM: TRACE("Request type is TRIM"); - ret = bdrv_discard(exp->bs, (request.from + exp->dev_offset) / 512, - request.len / 512); + ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512, + request.len / 512); if (ret < 0) { LOG("discard failed"); reply.error = -ret; } - if (nbd_do_send_reply(req, &reply, 0) < 0) + if (nbd_co_send_reply(req, &reply, 0) < 0) goto out; break; default: @@ -865,28 +885,39 @@ static int nbd_trip(NBDClient *client) invalid_request: reply.error = -EINVAL; error_reply: - if (nbd_do_send_reply(req, &reply, 0) == -1) + if (nbd_co_send_reply(req, &reply, 0) == -1) goto out; break; } TRACE("Request/Reply complete"); - rc = 0; + nbd_request_put(req); + return; + out: nbd_request_put(req); - return rc; + nbd_client_close(client); } static void nbd_read(void *opaque) { NBDClient *client = opaque; - if (nbd_trip(client) != 0) { - nbd_client_close(client); + if (client->recv_coroutine) { + qemu_coroutine_enter(client->recv_coroutine, NULL); + } else { + qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client); } } +static void nbd_restart_write(void *opaque) +{ + NBDClient *client = opaque; + + qemu_coroutine_enter(client->send_coroutine, NULL); +} + NBDClient *nbd_client_new(NBDExport *exp, int csock, void (*close)(NBDClient *)) { @@ -899,6 +930,7 @@ NBDClient *nbd_client_new(NBDExport *exp, int csock, client->exp = exp; client->sock = csock; client->close = close; + qemu_co_mutex_init(&client->send_lock); qemu_set_fd_handler2(csock, NULL, nbd_read, NULL, client); return client; }