diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index 2b5f67c4d3573ba1538e4dd8c80e660730c11950..4b7d4a9863a5170ffa179ea4cbd31822fdc87052 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -694,10 +694,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client) static int virNetClientCallDispatch(virNetClientPtr client) { - size_t i; - if (virNetMessageDecodeHeader(&client->msg) < 0) - return -1; - PROBE(RPC_CLIENT_MSG_RX, "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u", client, client->msg.bufferLength, @@ -706,15 +702,7 @@ virNetClientCallDispatch(virNetClientPtr client) switch (client->msg.header.type) { case VIR_NET_REPLY: /* Normal RPC replies */ - return virNetClientCallDispatchReply(client); - case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */ - if (virNetMessageDecodeNumFDs(&client->msg) < 0) - return -1; - for (i = 0 ; i < client->msg.nfds ; i++) { - if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0) - return -1; - } return virNetClientCallDispatchReply(client); case VIR_NET_MESSAGE: /* Async notifications */ @@ -737,22 +725,29 @@ static ssize_t virNetClientIOWriteMessage(virNetClientPtr client, virNetClientCallPtr thecall) { - ssize_t ret; + ssize_t ret = 0; - ret = virNetSocketWrite(client->sock, - thecall->msg->buffer + thecall->msg->bufferOffset, - thecall->msg->bufferLength - thecall->msg->bufferOffset); - if (ret <= 0) - return ret; + if (thecall->msg->bufferOffset < thecall->msg->bufferLength) { + ret = virNetSocketWrite(client->sock, + thecall->msg->buffer + thecall->msg->bufferOffset, + thecall->msg->bufferLength - thecall->msg->bufferOffset); + if (ret <= 0) + return ret; - thecall->msg->bufferOffset += ret; + thecall->msg->bufferOffset += ret; + } if (thecall->msg->bufferOffset == thecall->msg->bufferLength) { size_t i; - for (i = 0 ; i < thecall->msg->nfds ; i++) { - if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0) + for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) { + int rv; + if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i])) < 0) return -1; + if (rv == 0) /* Blocking */ + return 0; + thecall->msg->donefds++; } + thecall->msg->donefds = 0; thecall->msg->bufferOffset = thecall->msg->bufferLength = 0; if (thecall->expectReply) thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX; @@ -821,12 +816,16 @@ virNetClientIOHandleInput(virNetClientPtr client) * EAGAIN */ for (;;) { - ssize_t ret = virNetClientIOReadMessage(client); + ssize_t ret; - if (ret < 0) - return -1; - if (ret == 0) - return 0; /* Blocking on read */ + if (client->msg.nfds == 0) { + ret = virNetClientIOReadMessage(client); + + if (ret < 0) + return -1; + if (ret == 0) + return 0; /* Blocking on read */ + } /* Check for completion of our goal */ if (client->msg.bufferOffset == client->msg.bufferLength) { @@ -842,6 +841,33 @@ virNetClientIOHandleInput(virNetClientPtr client) * next iteration. */ } else { + if (virNetMessageDecodeHeader(&client->msg) < 0) + return -1; + + if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) { + size_t i; + if (virNetMessageDecodeNumFDs(&client->msg) < 0) + return -1; + + for (i = client->msg.donefds ; i < client->msg.nfds ; i++) { + int rv; + if ((rv = virNetSocketRecvFD(client->sock, &(client->msg.fds[i]))) < 0) + return -1; + if (rv == 0) /* Blocking */ + break; + client->msg.donefds++; + } + + if (client->msg.donefds < client->msg.nfds) { + /* Because DecodeHeader/NumFDs reset bufferOffset, we + * put it back to what it was, so everything works + * again next time we run this method + */ + client->msg.bufferOffset = client->msg.bufferLength; + return 0; /* Blocking on more fds */ + } + } + ret = virNetClientCallDispatch(client); client->msg.bufferOffset = client->msg.bufferLength = 0; /* @@ -1257,6 +1283,7 @@ int virNetClientSend(virNetClientPtr client, goto cleanup; } + msg->donefds = 0; if (msg->bufferLength) call->mode = VIR_NET_CLIENT_MODE_WAIT_TX; else diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h index ad634099d83f5218f3f9ea24c3adcd46b8cbb3bc..c54e7c6ea55c35f1746c3f1909dfaf488f676846 100644 --- a/src/rpc/virnetmessage.h +++ b/src/rpc/virnetmessage.h @@ -48,6 +48,7 @@ struct _virNetMessage { size_t nfds; int *fds; + size_t donefds; virNetMessagePtr next; }; diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index 2f5ae8faca676f295621ce495ceab8f6a0e6be2d..cf97b58854dee82c3635c2a2c749605f0dd3f114 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -771,9 +771,11 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client) static void virNetServerClientDispatchRead(virNetServerClientPtr client) { readmore: - if (virNetServerClientRead(client) < 0) { - client->wantClose = true; - return; /* Error */ + if (client->rx->nfds == 0) { + if (virNetServerClientRead(client) < 0) { + client->wantClose = true; + return; /* Error */ + } } if (client->rx->bufferOffset < client->rx->bufferLength) @@ -794,7 +796,7 @@ readmore: goto readmore; } else { /* Grab the completed message */ - virNetMessagePtr msg = virNetMessageQueueServe(&client->rx); + virNetMessagePtr msg = client->rx; virNetServerClientFilterPtr filter; size_t i; @@ -805,20 +807,40 @@ readmore: return; } + /* Now figure out if we need to read more data to get some + * file descriptors */ if (msg->header.type == VIR_NET_CALL_WITH_FDS && virNetMessageDecodeNumFDs(msg) < 0) { virNetMessageFree(msg); client->wantClose = true; - return; + return; /* Error */ } - for (i = 0 ; i < msg->nfds ; i++) { - if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) { + + /* Try getting the file descriptors (may fail if blocking) */ + for (i = msg->donefds ; i < msg->nfds ; i++) { + int rv; + if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) { virNetMessageFree(msg); client->wantClose = true; return; } + if (rv == 0) /* Blocking */ + break; + msg->donefds++; + } + + /* Need to poll() until FDs arrive */ + if (msg->donefds < msg->nfds) { + /* Because DecodeHeader/NumFDs reset bufferOffset, we + * put it back to what it was, so everything works + * again next time we run this method + */ + client->rx->bufferOffset = client->rx->bufferLength; + return; } + /* Definitely finished reading, so remove from queue */ + virNetMessageQueueServe(&client->rx); PROBE(RPC_SERVER_CLIENT_MSG_RX, "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u", client, msg->bufferLength, @@ -912,25 +934,30 @@ static void virNetServerClientDispatchWrite(virNetServerClientPtr client) { while (client->tx) { - ssize_t ret; - - ret = virNetServerClientWrite(client); - if (ret < 0) { - client->wantClose = true; - return; + if (client->tx->bufferOffset < client->tx->bufferLength) { + ssize_t ret; + ret = virNetServerClientWrite(client); + if (ret < 0) { + client->wantClose = true; + return; + } + if (ret == 0) + return; /* Would block on write EAGAIN */ } - if (ret == 0) - return; /* Would block on write EAGAIN */ if (client->tx->bufferOffset == client->tx->bufferLength) { virNetMessagePtr msg; size_t i; - for (i = 0 ; i < client->tx->nfds ; i++) { - if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0) { + for (i = client->tx->donefds ; i < client->tx->nfds ; i++) { + int rv; + if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) { client->wantClose = true; return; } + if (rv == 0) /* Blocking */ + return; + client->tx->donefds++; } #if HAVE_SASL @@ -1041,6 +1068,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client, msg->bufferLength, msg->bufferOffset); virNetServerClientLock(client); + msg->donefds = 0; if (client->sock && !client->wantClose) { PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE, "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u", diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index d832c53fd645690d93fc95fe781dff88b76e1de2..4517d1643583e994a3d70c368fe764d057a1dddc 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -1142,6 +1142,9 @@ ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) } +/* + * Returns 1 if an FD was sent, 0 if it would block, -1 on error + */ int virNetSocketSendFD(virNetSocketPtr sock, int fd) { int ret = -1; @@ -1154,12 +1157,15 @@ int virNetSocketSendFD(virNetSocketPtr sock, int fd) PROBE(RPC_SOCKET_SEND_FD, "sock=%p fd=%d", sock, fd); if (sendfd(sock->fd, fd) < 0) { - virReportSystemError(errno, - _("Failed to send file descriptor %d"), - fd); + if (errno == EAGAIN) + ret = 0; + else + virReportSystemError(errno, + _("Failed to send file descriptor %d"), + fd); goto cleanup; } - ret = 0; + ret = 1; cleanup: virMutexUnlock(&sock->lock); @@ -1167,9 +1173,15 @@ cleanup: } -int virNetSocketRecvFD(virNetSocketPtr sock) +/* + * Returns 1 if an FD was read, 0 if it would block, -1 on error + */ +int virNetSocketRecvFD(virNetSocketPtr sock, int *fd) { int ret = -1; + + *fd = -1; + if (!virNetSocketHasPassFD(sock)) { virNetError(VIR_ERR_INTERNAL_ERROR, _("Receiving file descriptors is not supported on this socket")); @@ -1177,13 +1189,17 @@ int virNetSocketRecvFD(virNetSocketPtr sock) } virMutexLock(&sock->lock); - if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) { - virReportSystemError(errno, "%s", - _("Failed to recv file descriptor")); + if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) { + if (errno == EAGAIN) + ret = 0; + else + virReportSystemError(errno, "%s", + _("Failed to recv file descriptor")); goto cleanup; } PROBE(RPC_SOCKET_RECV_FD, - "sock=%p fd=%d", sock, ret); + "sock=%p fd=%d", sock, *fd); + ret = 1; cleanup: virMutexUnlock(&sock->lock); diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index 13cbb14ddf90c6af368c9af965d06457087ab1c2..e444aef3b00fc389a9162e20fc973fcc6727c419 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -97,7 +97,7 @@ ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); int virNetSocketSendFD(virNetSocketPtr sock, int fd); -int virNetSocketRecvFD(virNetSocketPtr sock); +int virNetSocketRecvFD(virNetSocketPtr sock, int *fd); void virNetSocketSetTLSSession(virNetSocketPtr sock, virNetTLSSessionPtr sess);