提交 b3fb288e 编写于 作者: D Daniel P. Berrange

Fix tracking of RPC messages wrt streams

Commit 2c85644b attempted to
fix a problem with tracking RPC messages from streams by doing

-            if (msg->header.type == VIR_NET_REPLY) {
+            if (msg->header.type == VIR_NET_REPLY ||
+                (msg->header.type == VIR_NET_STREAM &&
+                 msg->header.status != VIR_NET_CONTINUE)) {
                 client->nrequests--;

In other words any stream packet, with status NET_OK or NET_ERROR
would cause nrequests to be decremented. This is great if the
packet from from a synchronous virStreamFinish or virStreamAbort
API call, but wildly wrong if from a server initiated abort.
The latter resulted in 'nrequests' being decremented below zero.
This then causes all I/O for that client to be stopped.

Instead of trying to infer whether we need to decrement the
nrequests field, from the message type/status, introduce an
explicit 'bool tracked' field to mark whether the virNetMessagePtr
object is subject to tracking.

Also add a virNetMessageClear function to allow a message
contents to be cleared out, without adversely impacting the
'tracked' field as a naive memset() would do

* src/rpc/virnetmessage.c, src/rpc/virnetmessage.h: Add
  a 'bool tracked' field and virNetMessageClear() API
* daemon/remote.c, daemon/stream.c, src/rpc/virnetclientprogram.c,
  src/rpc/virnetclientstream.c, src/rpc/virnetserverclient.c,
  src/rpc/virnetserverprogram.c: Switch over to use
  virNetMessageClear() and pass in the 'bool tracked' value
  when creating messages.
上级 1b72ad2e
...@@ -2495,7 +2495,7 @@ remoteDispatchDomainEventSend(virNetServerClientPtr client, ...@@ -2495,7 +2495,7 @@ remoteDispatchDomainEventSend(virNetServerClientPtr client,
{ {
virNetMessagePtr msg; virNetMessagePtr msg;
if (!(msg = virNetMessageNew())) if (!(msg = virNetMessageNew(false)))
goto cleanup; goto cleanup;
msg->header.prog = virNetServerProgramGetID(program); msg->header.prog = virNetServerProgramGetID(program);
......
...@@ -207,7 +207,7 @@ daemonStreamEvent(virStreamPtr st, int events, void *opaque) ...@@ -207,7 +207,7 @@ daemonStreamEvent(virStreamPtr st, int events, void *opaque)
virNetError(VIR_ERR_RPC, virNetError(VIR_ERR_RPC,
"%s", _("stream had I/O failure")); "%s", _("stream had I/O failure"));
msg = virNetMessageNew(); msg = virNetMessageNew(false);
if (!msg) { if (!msg) {
ret = -1; ret = -1;
} else { } else {
...@@ -344,7 +344,7 @@ int daemonFreeClientStream(virNetServerClientPtr client, ...@@ -344,7 +344,7 @@ int daemonFreeClientStream(virNetServerClientPtr client,
virNetMessagePtr tmp = msg->next; virNetMessagePtr tmp = msg->next;
if (client) { if (client) {
/* Send a dummy reply to free up 'msg' & unblock client rx */ /* Send a dummy reply to free up 'msg' & unblock client rx */
memset(msg, 0, sizeof(*msg)); virNetMessageClear(msg);
msg->header.type = VIR_NET_REPLY; msg->header.type = VIR_NET_REPLY;
if (virNetServerClientSendMessage(client, msg) < 0) { if (virNetServerClientSendMessage(client, msg) < 0) {
virNetServerClientImmediateClose(client); virNetServerClientImmediateClose(client);
...@@ -653,7 +653,7 @@ daemonStreamHandleWrite(virNetServerClientPtr client, ...@@ -653,7 +653,7 @@ daemonStreamHandleWrite(virNetServerClientPtr client,
* its active request count / throttling * its active request count / throttling
*/ */
if (msg->header.status == VIR_NET_CONTINUE) { if (msg->header.status == VIR_NET_CONTINUE) {
memset(msg, 0, sizeof(*msg)); virNetMessageClear(msg);
msg->header.type = VIR_NET_REPLY; msg->header.type = VIR_NET_REPLY;
if (virNetServerClientSendMessage(client, msg) < 0) { if (virNetServerClientSendMessage(client, msg) < 0) {
virNetMessageFree(msg); virNetMessageFree(msg);
...@@ -715,7 +715,7 @@ daemonStreamHandleRead(virNetServerClientPtr client, ...@@ -715,7 +715,7 @@ daemonStreamHandleRead(virNetServerClientPtr client,
memset(&rerr, 0, sizeof(rerr)); memset(&rerr, 0, sizeof(rerr));
if (!(msg = virNetMessageNew())) if (!(msg = virNetMessageNew(false)))
ret = -1; ret = -1;
else else
ret = virNetServerProgramSendStreamError(remoteProgram, ret = virNetServerProgramSendStreamError(remoteProgram,
...@@ -729,7 +729,7 @@ daemonStreamHandleRead(virNetServerClientPtr client, ...@@ -729,7 +729,7 @@ daemonStreamHandleRead(virNetServerClientPtr client,
stream->tx = 0; stream->tx = 0;
if (ret == 0) if (ret == 0)
stream->recvEOF = 1; stream->recvEOF = 1;
if (!(msg = virNetMessageNew())) if (!(msg = virNetMessageNew(false)))
ret = -1; ret = -1;
if (msg) { if (msg) {
......
...@@ -272,7 +272,7 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, ...@@ -272,7 +272,7 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
{ {
virNetMessagePtr msg; virNetMessagePtr msg;
if (!(msg = virNetMessageNew())) if (!(msg = virNetMessageNew(false)))
return -1; return -1;
msg->header.prog = prog->program; msg->header.prog = prog->program;
......
...@@ -328,7 +328,7 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st, ...@@ -328,7 +328,7 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
bool wantReply; bool wantReply;
VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes); VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes);
if (!(msg = virNetMessageNew())) if (!(msg = virNetMessageNew(false)))
return -1; return -1;
virMutexLock(&st->lock); virMutexLock(&st->lock);
...@@ -390,7 +390,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st, ...@@ -390,7 +390,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
goto cleanup; goto cleanup;
} }
if (!(msg = virNetMessageNew())) { if (!(msg = virNetMessageNew(false))) {
virReportOOMError(); virReportOOMError();
goto cleanup; goto cleanup;
} }
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \
__FUNCTION__, __LINE__, __VA_ARGS__) __FUNCTION__, __LINE__, __VA_ARGS__)
virNetMessagePtr virNetMessageNew(void) virNetMessagePtr virNetMessageNew(bool tracked)
{ {
virNetMessagePtr msg; virNetMessagePtr msg;
...@@ -41,11 +41,21 @@ virNetMessagePtr virNetMessageNew(void) ...@@ -41,11 +41,21 @@ virNetMessagePtr virNetMessageNew(void)
return NULL; return NULL;
} }
VIR_DEBUG("msg=%p", msg); msg->tracked = tracked;
VIR_DEBUG("msg=%p tracked=%d", msg, tracked);
return msg; return msg;
} }
void virNetMessageClear(virNetMessagePtr msg)
{
bool tracked = msg->tracked;
memset(msg, 0, sizeof(*msg));
msg->tracked = tracked;
}
void virNetMessageFree(virNetMessagePtr msg) void virNetMessageFree(virNetMessagePtr msg)
{ {
if (!msg) if (!msg)
......
...@@ -35,6 +35,8 @@ typedef void (*virNetMessageFreeCallback)(virNetMessagePtr msg, void *opaque); ...@@ -35,6 +35,8 @@ typedef void (*virNetMessageFreeCallback)(virNetMessagePtr msg, void *opaque);
* use virNetMessageNew() to allocate on the heap * use virNetMessageNew() to allocate on the heap
*/ */
struct _virNetMessage { struct _virNetMessage {
bool tracked;
char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX]; char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX];
size_t bufferLength; size_t bufferLength;
size_t bufferOffset; size_t bufferOffset;
...@@ -48,7 +50,9 @@ struct _virNetMessage { ...@@ -48,7 +50,9 @@ struct _virNetMessage {
}; };
virNetMessagePtr virNetMessageNew(void); virNetMessagePtr virNetMessageNew(bool tracked);
void virNetMessageClear(virNetMessagePtr);
void virNetMessageFree(virNetMessagePtr msg); void virNetMessageFree(virNetMessagePtr msg);
......
...@@ -277,7 +277,7 @@ virNetServerClientCheckAccess(virNetServerClientPtr client) ...@@ -277,7 +277,7 @@ virNetServerClientCheckAccess(virNetServerClientPtr client)
return -1; return -1;
} }
if (!(confirm = virNetMessageNew())) if (!(confirm = virNetMessageNew(false)))
return -1; return -1;
/* Checks have succeeded. Write a '\1' byte back to the client to /* Checks have succeeded. Write a '\1' byte back to the client to
...@@ -323,7 +323,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, ...@@ -323,7 +323,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
virNetTLSContextRef(tls); virNetTLSContextRef(tls);
/* Prepare one for packet receive */ /* Prepare one for packet receive */
if (!(client->rx = virNetMessageNew())) if (!(client->rx = virNetMessageNew(true)))
goto error; goto error;
client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
client->nrequests = 1; client->nrequests = 1;
...@@ -805,7 +805,7 @@ readmore: ...@@ -805,7 +805,7 @@ readmore:
/* Possibly need to create another receive buffer */ /* Possibly need to create another receive buffer */
if (client->nrequests < client->nrequests_max) { if (client->nrequests < client->nrequests_max) {
if (!(client->rx = virNetMessageNew())) { if (!(client->rx = virNetMessageNew(true))) {
client->wantClose = true; client->wantClose = true;
} else { } else {
client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
...@@ -885,16 +885,14 @@ virNetServerClientDispatchWrite(virNetServerClientPtr client) ...@@ -885,16 +885,14 @@ virNetServerClientDispatchWrite(virNetServerClientPtr client)
/* Get finished msg from head of tx queue */ /* Get finished msg from head of tx queue */
msg = virNetMessageQueueServe(&client->tx); msg = virNetMessageQueueServe(&client->tx);
if (msg->header.type == VIR_NET_REPLY || if (msg->tracked) {
(msg->header.type == VIR_NET_STREAM &&
msg->header.status != VIR_NET_CONTINUE)) {
client->nrequests--; client->nrequests--;
/* See if the recv queue is currently throttled */ /* See if the recv queue is currently throttled */
if (!client->rx && if (!client->rx &&
client->nrequests < client->nrequests_max) { client->nrequests < client->nrequests_max) {
/* Ready to recv more messages */ /* Ready to recv more messages */
virNetMessageClear(msg);
client->rx = msg; client->rx = msg;
memset(client->rx, 0, sizeof(*client->rx));
client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
msg = NULL; msg = NULL;
client->nrequests++; client->nrequests++;
......
...@@ -284,7 +284,7 @@ int virNetServerProgramDispatch(virNetServerProgramPtr prog, ...@@ -284,7 +284,7 @@ int virNetServerProgramDispatch(virNetServerProgramPtr prog,
VIR_INFO("Ignoring unexpected stream data serial=%d proc=%d status=%d", VIR_INFO("Ignoring unexpected stream data serial=%d proc=%d status=%d",
msg->header.serial, msg->header.proc, msg->header.status); msg->header.serial, msg->header.proc, msg->header.status);
/* Send a dummy reply to free up 'msg' & unblock client rx */ /* Send a dummy reply to free up 'msg' & unblock client rx */
memset(msg, 0, sizeof(*msg)); virNetMessageClear(msg);
msg->header.type = VIR_NET_REPLY; msg->header.type = VIR_NET_REPLY;
if (virNetServerClientSendMessage(client, msg) < 0) { if (virNetServerClientSendMessage(client, msg) < 0) {
ret = -1; ret = -1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册