diff --git a/src/libvirt_private.syms b/src/libvirt_private.syms index 8c74f56bce819e71423620e04203dd14b790f4f5..d5368877d67bd503def52b547b1c0f732b9485e1 100644 --- a/src/libvirt_private.syms +++ b/src/libvirt_private.syms @@ -1184,6 +1184,10 @@ virFileFdopen; virFileRewrite; +# virnetclient.h +virNetClientHasPassFD; + + # virnetmessage.h virNetMessageClear; virNetMessageDecodeNumFDs; diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index e98ebd737e235580029ce30eb0fc77080b60887c..382bb421af025f6e4ee3abe58d7ede6803d918b8 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -4152,6 +4152,7 @@ call (virConnectPtr conn ATTRIBUTE_UNUSED, client, counter, proc_nr, + 0, NULL, NULL, NULL, args_filter, args, ret_filter, ret); remoteDriverLock(priv); diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index 085dc8d97e05bacb93fc8d47b2f102bc0e30ce86..2b5f67c4d3573ba1538e4dd8c80e660730c11950 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -258,6 +258,16 @@ int virNetClientDupFD(virNetClientPtr client, bool cloexec) } +bool virNetClientHasPassFD(virNetClientPtr client) +{ + bool hasPassFD; + virNetClientLock(client); + hasPassFD = virNetSocketHasPassFD(client->sock); + virNetClientUnlock(client); + return hasPassFD; +} + + void virNetClientFree(virNetClientPtr client) { int i; @@ -684,6 +694,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client) static int virNetClientCallDispatch(virNetClientPtr client) { + size_t i; if (virNetMessageDecodeHeader(&client->msg) < 0) return -1; @@ -697,6 +708,15 @@ virNetClientCallDispatch(virNetClientPtr client) case VIR_NET_REPLY: /* Normal RPC replies */ return virNetClientCallDispatchReply(client); + case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */ + if (virNetMessageDecodeNumFDs(&client->msg) < 0) + return -1; + for (i = 0 ; i < client->msg.nfds ; i++) { + if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0) + return -1; + } + return virNetClientCallDispatchReply(client); + case VIR_NET_MESSAGE: /* Async notifications */ return virNetClientCallDispatchMessage(client); @@ -728,6 +748,11 @@ virNetClientIOWriteMessage(virNetClientPtr client, thecall->msg->bufferOffset += ret; if (thecall->msg->bufferOffset == thecall->msg->bufferLength) { + size_t i; + for (i = 0 ; i < thecall->msg->nfds ; i++) { + if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0) + return -1; + } thecall->msg->bufferOffset = thecall->msg->bufferLength = 0; if (thecall->expectReply) thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX; diff --git a/src/rpc/virnetclient.h b/src/rpc/virnetclient.h index 1fabcfde86b52eace924cd6648582583a1d241e5..fb679e897318c948b0ebb8af8a1e604ac52210f0 100644 --- a/src/rpc/virnetclient.h +++ b/src/rpc/virnetclient.h @@ -56,6 +56,8 @@ void virNetClientRef(virNetClientPtr client); int virNetClientGetFD(virNetClientPtr client); int virNetClientDupFD(virNetClientPtr client, bool cloexec); +bool virNetClientHasPassFD(virNetClientPtr client); + int virNetClientAddProgram(virNetClientPtr client, virNetClientProgramPtr prog); diff --git a/src/rpc/virnetclientprogram.c b/src/rpc/virnetclientprogram.c index 33fa5078b78ba39941e72b57dcd2269d242aa691..36e23841e1021822179310afcf797be70f651231 100644 --- a/src/rpc/virnetclientprogram.c +++ b/src/rpc/virnetclientprogram.c @@ -22,6 +22,8 @@ #include +#include + #include "virnetclientprogram.h" #include "virnetclient.h" #include "virnetprotocol.h" @@ -29,6 +31,8 @@ #include "memory.h" #include "virterror_internal.h" #include "logging.h" +#include "util.h" +#include "virfile.h" #define VIR_FROM_THIS VIR_FROM_RPC #define virNetError(code, ...) \ @@ -267,10 +271,20 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, virNetClientPtr client, unsigned serial, int proc, + size_t noutfds, + int *outfds, + size_t *ninfds, + int **infds, xdrproc_t args_filter, void *args, xdrproc_t ret_filter, void *ret) { virNetMessagePtr msg; + size_t i; + + if (infds) + *infds = NULL; + if (ninfds) + *ninfds = 0; if (!(msg = virNetMessageNew(false))) return -1; @@ -278,13 +292,38 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, msg->header.prog = prog->program; msg->header.vers = prog->version; msg->header.status = VIR_NET_OK; - msg->header.type = VIR_NET_CALL; + msg->header.type = noutfds ? VIR_NET_CALL_WITH_FDS : VIR_NET_CALL; msg->header.serial = serial; msg->header.proc = proc; + msg->nfds = noutfds; + if (VIR_ALLOC_N(msg->fds, msg->nfds) < 0) { + virReportOOMError(); + goto error; + } + for (i = 0 ; i < msg->nfds ; i++) + msg->fds[i] = -1; + for (i = 0 ; i < msg->nfds ; i++) { + if ((msg->fds[i] = dup(outfds[i])) < 0) { + virReportSystemError(errno, + _("Cannot duplicate FD %d"), + outfds[i]); + goto error; + } + if (virSetInherit(msg->fds[i], false) < 0) { + virReportSystemError(errno, + _("Cannot set close-on-exec %d"), + msg->fds[i]); + goto error; + } + } if (virNetMessageEncodeHeader(msg) < 0) goto error; + if (msg->nfds && + virNetMessageEncodeNumFDs(msg) < 0) + goto error; + if (virNetMessageEncodePayload(msg, args_filter, args) < 0) goto error; @@ -295,7 +334,8 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, * virNetClientSend should have validated the reply, * but it doesn't hurt to check again. */ - if (msg->header.type != VIR_NET_REPLY) { + if (msg->header.type != VIR_NET_REPLY && + msg->header.type != VIR_NET_REPLY_WITH_FDS) { virNetError(VIR_ERR_INTERNAL_ERROR, _("Unexpected message type %d"), msg->header.type); goto error; @@ -315,6 +355,30 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, switch (msg->header.status) { case VIR_NET_OK: + if (infds && ninfds) { + *ninfds = msg->nfds; + if (VIR_ALLOC_N(*infds, *ninfds) < 0) { + virReportOOMError(); + goto error; + } + for (i = 0 ; i < *ninfds ; i++) + *infds[i] = -1; + for (i = 0 ; i < *ninfds ; i++) { + if ((*infds[i] = dup(msg->fds[i])) < 0) { + virReportSystemError(errno, + _("Cannot duplicate FD %d"), + msg->fds[i]); + goto error; + } + if (virSetInherit(*infds[i], false) < 0) { + virReportSystemError(errno, + _("Cannot set close-on-exec %d"), + *infds[i]); + goto error; + } + } + + } if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0) goto error; break; @@ -335,5 +399,9 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, error: virNetMessageFree(msg); + if (infds && ninfds) { + for (i = 0 ; i < *ninfds ; i++) + VIR_FORCE_CLOSE(*infds[i]); + } return -1; } diff --git a/src/rpc/virnetclientprogram.h b/src/rpc/virnetclientprogram.h index 82ae2c66fbb05a7b09da7db423d73e89814579a5..14a4c9650077711317638f12decad52d66913166 100644 --- a/src/rpc/virnetclientprogram.h +++ b/src/rpc/virnetclientprogram.h @@ -77,6 +77,10 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, virNetClientPtr client, unsigned serial, int proc, + size_t noutfds, + int *outfds, + size_t *ninfds, + int **infds, xdrproc_t args_filter, void *args, xdrproc_t ret_filter, void *ret);