diff --git a/nbd.c b/nbd.c index 4aeb80ae1881c92634a5ed66882ffa89711312b8..eb72f4a6e57ac92742a218b74257d307affd4a82 100644 --- a/nbd.c +++ b/nbd.c @@ -109,6 +109,7 @@ struct NBDClient { Coroutine *send_coroutine; int nb_requests; + bool closing; }; /* That's all folks */ @@ -655,19 +656,35 @@ void nbd_client_get(NBDClient *client) void nbd_client_put(NBDClient *client) { if (--client->refcount == 0) { + /* The last reference should be dropped by client->close, + * which is called by nbd_client_close. + */ + assert(client->closing); + + qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL); + close(client->sock); + client->sock = -1; g_free(client); } } -static void nbd_client_close(NBDClient *client) +void nbd_client_close(NBDClient *client) { - qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL); - close(client->sock); - client->sock = -1; + if (client->closing) { + return; + } + + client->closing = true; + + /* Force requests to finish. They will drop their own references, + * then we'll close the socket and free the NBDClient. + */ + shutdown(client->sock, 2); + + /* Also tell the client, so that they release their reference. */ if (client->close) { client->close(client); } - nbd_client_put(client); } static NBDRequest *nbd_request_get(NBDClient *client) @@ -810,14 +827,18 @@ out: static void nbd_trip(void *opaque) { NBDClient *client = opaque; - NBDRequest *req = nbd_request_get(client); NBDExport *exp = client->exp; + NBDRequest *req; struct nbd_request request; struct nbd_reply reply; ssize_t ret; TRACE("Reading request."); + if (client->closing) { + return; + } + req = nbd_request_get(client); ret = nbd_co_receive_request(req, &request); if (ret == -EAGAIN) { goto done; diff --git a/nbd.h b/nbd.h index a9038dc1960705c4db9cc257ac02804bea0fc558..8b84a50ed49e932773498736b794ccb418b92d75 100644 --- a/nbd.h +++ b/nbd.h @@ -84,6 +84,7 @@ void nbd_export_close(NBDExport *exp); NBDClient *nbd_client_new(NBDExport *exp, int csock, void (*close)(NBDClient *)); +void nbd_client_close(NBDClient *client); void nbd_client_get(NBDClient *client); void nbd_client_put(NBDClient *client);