提交 36a9c83d 编写于 作者: D Daniel P. Berrange

Add client side support for FD passing

Extend the RPC client code to allow file descriptors to be sent
to the server with calls, and received back with replies.

* src/remote/remote_driver.c: Stub extra args
* src/libvirt_private.syms, src/rpc/virnetclient.c,
  src/rpc/virnetclient.h, src/rpc/virnetclientprogram.c,
  src/rpc/virnetclientprogram.h: Extend APIs to allow
  FD passing
上级 b0f996a6
...@@ -1184,6 +1184,10 @@ virFileFdopen; ...@@ -1184,6 +1184,10 @@ virFileFdopen;
virFileRewrite; virFileRewrite;
# virnetclient.h
virNetClientHasPassFD;
# virnetmessage.h # virnetmessage.h
virNetMessageClear; virNetMessageClear;
virNetMessageDecodeNumFDs; virNetMessageDecodeNumFDs;
......
...@@ -4152,6 +4152,7 @@ call (virConnectPtr conn ATTRIBUTE_UNUSED, ...@@ -4152,6 +4152,7 @@ call (virConnectPtr conn ATTRIBUTE_UNUSED,
client, client,
counter, counter,
proc_nr, proc_nr,
0, NULL, NULL, NULL,
args_filter, args, args_filter, args,
ret_filter, ret); ret_filter, ret);
remoteDriverLock(priv); remoteDriverLock(priv);
......
...@@ -258,6 +258,16 @@ int virNetClientDupFD(virNetClientPtr client, bool cloexec) ...@@ -258,6 +258,16 @@ int virNetClientDupFD(virNetClientPtr client, bool cloexec)
} }
bool virNetClientHasPassFD(virNetClientPtr client)
{
bool hasPassFD;
virNetClientLock(client);
hasPassFD = virNetSocketHasPassFD(client->sock);
virNetClientUnlock(client);
return hasPassFD;
}
void virNetClientFree(virNetClientPtr client) void virNetClientFree(virNetClientPtr client)
{ {
int i; int i;
...@@ -684,6 +694,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client) ...@@ -684,6 +694,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
static int static int
virNetClientCallDispatch(virNetClientPtr client) virNetClientCallDispatch(virNetClientPtr client)
{ {
size_t i;
if (virNetMessageDecodeHeader(&client->msg) < 0) if (virNetMessageDecodeHeader(&client->msg) < 0)
return -1; return -1;
...@@ -697,6 +708,15 @@ virNetClientCallDispatch(virNetClientPtr client) ...@@ -697,6 +708,15 @@ virNetClientCallDispatch(virNetClientPtr client)
case VIR_NET_REPLY: /* Normal RPC replies */ case VIR_NET_REPLY: /* Normal RPC replies */
return virNetClientCallDispatchReply(client); return virNetClientCallDispatchReply(client);
case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
return -1;
for (i = 0 ; i < client->msg.nfds ; i++) {
if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
return -1;
}
return virNetClientCallDispatchReply(client);
case VIR_NET_MESSAGE: /* Async notifications */ case VIR_NET_MESSAGE: /* Async notifications */
return virNetClientCallDispatchMessage(client); return virNetClientCallDispatchMessage(client);
...@@ -728,6 +748,11 @@ virNetClientIOWriteMessage(virNetClientPtr client, ...@@ -728,6 +748,11 @@ virNetClientIOWriteMessage(virNetClientPtr client,
thecall->msg->bufferOffset += ret; thecall->msg->bufferOffset += ret;
if (thecall->msg->bufferOffset == thecall->msg->bufferLength) { if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
size_t i;
for (i = 0 ; i < thecall->msg->nfds ; i++) {
if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
return -1;
}
thecall->msg->bufferOffset = thecall->msg->bufferLength = 0; thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
if (thecall->expectReply) if (thecall->expectReply)
thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX; thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
......
...@@ -56,6 +56,8 @@ void virNetClientRef(virNetClientPtr client); ...@@ -56,6 +56,8 @@ void virNetClientRef(virNetClientPtr client);
int virNetClientGetFD(virNetClientPtr client); int virNetClientGetFD(virNetClientPtr client);
int virNetClientDupFD(virNetClientPtr client, bool cloexec); int virNetClientDupFD(virNetClientPtr client, bool cloexec);
bool virNetClientHasPassFD(virNetClientPtr client);
int virNetClientAddProgram(virNetClientPtr client, int virNetClientAddProgram(virNetClientPtr client,
virNetClientProgramPtr prog); virNetClientProgramPtr prog);
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include <config.h> #include <config.h>
#include <unistd.h>
#include "virnetclientprogram.h" #include "virnetclientprogram.h"
#include "virnetclient.h" #include "virnetclient.h"
#include "virnetprotocol.h" #include "virnetprotocol.h"
...@@ -29,6 +31,8 @@ ...@@ -29,6 +31,8 @@
#include "memory.h" #include "memory.h"
#include "virterror_internal.h" #include "virterror_internal.h"
#include "logging.h" #include "logging.h"
#include "util.h"
#include "virfile.h"
#define VIR_FROM_THIS VIR_FROM_RPC #define VIR_FROM_THIS VIR_FROM_RPC
#define virNetError(code, ...) \ #define virNetError(code, ...) \
...@@ -267,10 +271,20 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -267,10 +271,20 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
virNetClientPtr client, virNetClientPtr client,
unsigned serial, unsigned serial,
int proc, int proc,
size_t noutfds,
int *outfds,
size_t *ninfds,
int **infds,
xdrproc_t args_filter, void *args, xdrproc_t args_filter, void *args,
xdrproc_t ret_filter, void *ret) xdrproc_t ret_filter, void *ret)
{ {
virNetMessagePtr msg; virNetMessagePtr msg;
size_t i;
if (infds)
*infds = NULL;
if (ninfds)
*ninfds = 0;
if (!(msg = virNetMessageNew(false))) if (!(msg = virNetMessageNew(false)))
return -1; return -1;
...@@ -278,13 +292,38 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -278,13 +292,38 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
msg->header.prog = prog->program; msg->header.prog = prog->program;
msg->header.vers = prog->version; msg->header.vers = prog->version;
msg->header.status = VIR_NET_OK; msg->header.status = VIR_NET_OK;
msg->header.type = VIR_NET_CALL; msg->header.type = noutfds ? VIR_NET_CALL_WITH_FDS : VIR_NET_CALL;
msg->header.serial = serial; msg->header.serial = serial;
msg->header.proc = proc; msg->header.proc = proc;
msg->nfds = noutfds;
if (VIR_ALLOC_N(msg->fds, msg->nfds) < 0) {
virReportOOMError();
goto error;
}
for (i = 0 ; i < msg->nfds ; i++)
msg->fds[i] = -1;
for (i = 0 ; i < msg->nfds ; i++) {
if ((msg->fds[i] = dup(outfds[i])) < 0) {
virReportSystemError(errno,
_("Cannot duplicate FD %d"),
outfds[i]);
goto error;
}
if (virSetInherit(msg->fds[i], false) < 0) {
virReportSystemError(errno,
_("Cannot set close-on-exec %d"),
msg->fds[i]);
goto error;
}
}
if (virNetMessageEncodeHeader(msg) < 0) if (virNetMessageEncodeHeader(msg) < 0)
goto error; goto error;
if (msg->nfds &&
virNetMessageEncodeNumFDs(msg) < 0)
goto error;
if (virNetMessageEncodePayload(msg, args_filter, args) < 0) if (virNetMessageEncodePayload(msg, args_filter, args) < 0)
goto error; goto error;
...@@ -295,7 +334,8 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -295,7 +334,8 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
* virNetClientSend should have validated the reply, * virNetClientSend should have validated the reply,
* but it doesn't hurt to check again. * but it doesn't hurt to check again.
*/ */
if (msg->header.type != VIR_NET_REPLY) { if (msg->header.type != VIR_NET_REPLY &&
msg->header.type != VIR_NET_REPLY_WITH_FDS) {
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("Unexpected message type %d"), msg->header.type); _("Unexpected message type %d"), msg->header.type);
goto error; goto error;
...@@ -315,6 +355,30 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -315,6 +355,30 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
switch (msg->header.status) { switch (msg->header.status) {
case VIR_NET_OK: case VIR_NET_OK:
if (infds && ninfds) {
*ninfds = msg->nfds;
if (VIR_ALLOC_N(*infds, *ninfds) < 0) {
virReportOOMError();
goto error;
}
for (i = 0 ; i < *ninfds ; i++)
*infds[i] = -1;
for (i = 0 ; i < *ninfds ; i++) {
if ((*infds[i] = dup(msg->fds[i])) < 0) {
virReportSystemError(errno,
_("Cannot duplicate FD %d"),
msg->fds[i]);
goto error;
}
if (virSetInherit(*infds[i], false) < 0) {
virReportSystemError(errno,
_("Cannot set close-on-exec %d"),
*infds[i]);
goto error;
}
}
}
if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0) if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0)
goto error; goto error;
break; break;
...@@ -335,5 +399,9 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -335,5 +399,9 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
error: error:
virNetMessageFree(msg); virNetMessageFree(msg);
if (infds && ninfds) {
for (i = 0 ; i < *ninfds ; i++)
VIR_FORCE_CLOSE(*infds[i]);
}
return -1; return -1;
} }
...@@ -77,6 +77,10 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -77,6 +77,10 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
virNetClientPtr client, virNetClientPtr client,
unsigned serial, unsigned serial,
int proc, int proc,
size_t noutfds,
int *outfds,
size_t *ninfds,
int **infds,
xdrproc_t args_filter, void *args, xdrproc_t args_filter, void *args,
xdrproc_t ret_filter, void *ret); xdrproc_t ret_filter, void *ret);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册