stream.c 19.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
/*
 * stream.c: APIs for managing client streams
 *
 * Copyright (C) 2009 Red Hat, Inc.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */


#include <config.h>

#include "stream.h"
27
#include "remote.h"
28 29
#include "memory.h"
#include "logging.h"
30
#include "virnetserverclient.h"
31 32 33
#include "virterror_internal.h"

/* Error domain used by every error reported from this file */
#define VIR_FROM_THIS VIR_FROM_STREAMS

/*
 * Report an error in this file's error domain (VIR_FROM_THIS),
 * recording the source file, function and line of the call site.
 */
#define virNetError(code, ...)                                          \
    virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, __FUNCTION__,   \
                         __LINE__, __VA_ARGS__)

/*
 * Daemon-side state for one stream opened by a client RPC call.
 */
struct daemonClientStream {
    /* Owning client's private data; priv->lock is taken around stream
     * processing and priv->streams links all of the client's streams */
    daemonClientPrivatePtr priv;

    /* Program of the RPC call that opened the stream (reference held) */
    virNetServerProgramPtr prog;

    /* Underlying stream, plus the procedure/serial of the opening call;
     * incoming stream packets must match both to be claimed */
    virStreamPtr st;
    int procedure;
    int serial;

    unsigned int recvEOF : 1;   /* virStreamRecv has reported end of data */
    unsigned int closed : 1;    /* stream finished or aborted */

    int filterID;               /* client message filter ID, -1 if none */

    virNetMessagePtr rx;        /* queued incoming packets, oldest first */
    int tx;                     /* non-zero while we may send data */

    daemonClientStreamPtr next; /* next stream in priv->streams */
};
static int
60 61
daemonStreamHandleWrite(virNetServerClientPtr client,
                        daemonClientStream *stream);
62
static int
63 64
daemonStreamHandleRead(virNetServerClientPtr client,
                       daemonClientStream *stream);
65
static int
66 67 68
daemonStreamHandleFinish(virNetServerClientPtr client,
                         daemonClientStream *stream,
                         virNetMessagePtr msg);
69
static int
70 71 72
daemonStreamHandleAbort(virNetServerClientPtr client,
                        daemonClientStream *stream,
                        virNetMessagePtr msg);
73 74 75 76



static void
77
daemonStreamUpdateEvents(daemonClientStream *stream)
78 79 80 81
{
    int newEvents = 0;
    if (stream->rx)
        newEvents |= VIR_STREAM_EVENT_WRITABLE;
82 83
    if (stream->tx && !stream->recvEOF)
        newEvents |= VIR_STREAM_EVENT_READABLE;
84 85 86 87

    virStreamEventUpdateCallback(stream->st, newEvents);
}

88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
/*
 * Completion callback attached to outgoing data packets by
 * daemonStreamHandleRead; invoked when such a packet has been fully sent.
 * It re-enables transmission of further data and refreshes the event mask.
 *
 * Only one data packet is in flight at a time, so a fast stream feeding
 * a slow client cannot make the daemon grow without bound.
 *
 * NOTE(review): @opaque is the raw stream pointer captured when the packet
 * was queued. Nothing visible here keeps that stream alive until the packet
 * is sent, so a stream removed in the meantime would be used after free.
 * Confirm the lifetime guarantee.
 */
static void
daemonStreamMessageFinished(virNetMessagePtr msg,
                            void *opaque)
{
    daemonClientStream *clientStream = opaque;

    VIR_DEBUG("stream=%p proc=%d serial=%d",
              clientStream, msg->header.proc, msg->header.serial);

    clientStream->tx = 1;
    daemonStreamUpdateEvents(clientStream);
}

/*
 * Callback that gets invoked when a stream becomes writable/readable
 */
static void
111
daemonStreamEvent(virStreamPtr st, int events, void *opaque)
112
{
113 114 115
    virNetServerClientPtr client = opaque;
    daemonClientStream *stream;
    daemonClientPrivatePtr priv = virNetServerClientGetPrivateData(client);
116

117
    virMutexLock(&priv->lock);
118

119 120 121 122 123 124
    stream = priv->streams;
    while (stream) {
        if (stream->st == st)
            break;
        stream = stream->next;
    }
125 126 127 128 129 130 131

    if (!stream) {
        VIR_WARN("event for client=%p stream st=%p, but missing stream state", client, st);
        virStreamEventRemoveCallback(st);
        goto cleanup;
    }

132
    VIR_DEBUG("st=%p events=%d EOF=%d closed=%d", st, events, stream->recvEOF, stream->closed);
133 134

    if (events & VIR_STREAM_EVENT_WRITABLE) {
135 136 137
        if (daemonStreamHandleWrite(client, stream) < 0) {
            daemonRemoveClientStream(client, stream);
            virNetServerClientClose(client);
138 139 140 141
            goto cleanup;
        }
    }

142 143 144
    if (!stream->recvEOF &&
        (events & (VIR_STREAM_EVENT_READABLE | VIR_STREAM_EVENT_HANGUP))) {
        events = events & ~(VIR_STREAM_EVENT_READABLE | VIR_STREAM_EVENT_HANGUP);
145 146 147
        if (daemonStreamHandleRead(client, stream) < 0) {
            daemonRemoveClientStream(client, stream);
            virNetServerClientClose(client);
148 149 150 151
            goto cleanup;
        }
    }

152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
    /* If we have a completion/abort message, always process it */
    if (stream->rx) {
        virNetMessagePtr msg = stream->rx;
        switch (msg->header.status) {
        case VIR_NET_CONTINUE:
            /* nada */
            break;
        case VIR_NET_OK:
            virNetMessageQueueServe(&stream->rx);
            if (daemonStreamHandleFinish(client, stream, msg) < 0) {
                virNetMessageFree(msg);
                daemonRemoveClientStream(client, stream);
                virNetServerClientClose(client);
                goto cleanup;
            }
            break;
        case VIR_NET_ERROR:
        default:
            virNetMessageQueueServe(&stream->rx);
            if (daemonStreamHandleAbort(client, stream, msg) < 0) {
                virNetMessageFree(msg);
                daemonRemoveClientStream(client, stream);
                virNetServerClientClose(client);
                goto cleanup;
            }
            break;
        }
    }

181 182 183
    if (!stream->closed &&
        (events & (VIR_STREAM_EVENT_ERROR | VIR_STREAM_EVENT_HANGUP))) {
        int ret;
184 185 186 187
        virNetMessagePtr msg;
        virNetMessageError rerr;

        memset(&rerr, 0, sizeof(rerr));
188
        stream->closed = 1;
189
        virStreamEventRemoveCallback(stream->st);
190 191
        virStreamAbort(stream->st);
        if (events & VIR_STREAM_EVENT_HANGUP)
192 193
            virNetError(VIR_ERR_RPC,
                        "%s", _("stream had unexpected termination"));
194
        else
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
            virNetError(VIR_ERR_RPC,
                        "%s", _("stream had I/O failure"));

        msg = virNetMessageNew();
        if (!msg) {
            ret = -1;
        } else {
            ret = virNetServerProgramSendStreamError(remoteProgram,
                                                     client,
                                                     msg,
                                                     &rerr,
                                                     stream->procedure,
                                                     stream->serial);
        }
        daemonRemoveClientStream(client, stream);
210
        if (ret < 0)
211
            virNetServerClientClose(client);
212 213 214 215
        goto cleanup;
    }

    if (stream->closed) {
216
        daemonRemoveClientStream(client, stream);
217
    } else {
218
        daemonStreamUpdateEvents(stream);
219 220 221
    }

cleanup:
222
    virMutexUnlock(&priv->lock);
223 224
}

225 226 227 228 229 230 231 232 233 234

/*
 * @client: a locked client object
 *
 * Invoked by the main loop when filtering incoming messages.
 *
 * Returns 1 if the message was processed, 0 if skipped,
 * -1 on fatal client error
 */
static int
235 236 237
daemonStreamFilter(virNetServerClientPtr client,
                   virNetMessagePtr msg,
                   void *opaque)
238
{
239 240
    daemonClientStream *stream = opaque;
    int ret = 0;
241

242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
    virMutexLock(&stream->priv->lock);

    if (msg->header.type != VIR_NET_STREAM)
        goto cleanup;

    if (!virNetServerProgramMatches(stream->prog, msg))
        goto cleanup;

    if (msg->header.proc != stream->procedure ||
        msg->header.serial != stream->serial)
        goto cleanup;

    VIR_DEBUG("Incoming client=%p, rx=%p, serial=%d, proc=%d, status=%d",
              client, stream->rx, msg->header.proc,
              msg->header.serial, msg->header.status);

    virNetMessageQueuePush(&stream->rx, msg);
    daemonStreamUpdateEvents(stream);
    ret = 1;

cleanup:
    virMutexUnlock(&stream->priv->lock);
    return ret;
265 266 267 268 269
}


/*
 * @conn: a connection object to associate the stream with
270
 * @header: the method call to associate with the stream
271 272 273 274 275
 *
 * Creates a new stream for this conn
 *
 * Returns a new stream object, or NULL upon OOM
 */
276 277 278 279 280
daemonClientStream *
daemonCreateClientStream(virNetServerClientPtr client,
                         virStreamPtr st,
                         virNetServerProgramPtr prog,
                         virNetMessageHeaderPtr header)
281
{
282 283
    daemonClientStream *stream;
    daemonClientPrivatePtr priv = virNetServerClientGetPrivateData(client);
284

285 286
    VIR_DEBUG("client=%p, proc=%d, serial=%d, st=%p",
              client, header->proc, header->serial, st);
287

288 289
    if (VIR_ALLOC(stream) < 0) {
        virReportOOMError();
290
        return NULL;
291
    }
292

293 294 295 296 297 298
    stream->priv = priv;
    stream->prog = prog;
    stream->procedure = header->proc;
    stream->serial = header->serial;
    stream->filterID = -1;
    stream->st = st;
299

300
    virNetServerProgramRef(prog);
301 302 303 304 305 306 307 308 309 310

    return stream;
}

/*
 * @stream: an unused client stream
 *
 * Frees the memory associated with this inactive client
 * stream
 */
311 312
int daemonFreeClientStream(virNetServerClientPtr client,
                           daemonClientStream *stream)
313
{
314 315
    virNetMessagePtr msg;
    int ret = 0;
316 317

    if (!stream)
318 319 320 321
        return 0;

    VIR_DEBUG("client=%p, proc=%d, serial=%d",
              client, stream->procedure, stream->serial);
322

323
    virNetServerProgramFree(stream->prog);
324 325 326

    msg = stream->rx;
    while (msg) {
327 328 329 330 331 332 333 334
        virNetMessagePtr tmp = msg->next;
        /* Send a dummy reply to free up 'msg' & unblock client rx */
        memset(msg, 0, sizeof(*msg));
        if (virNetServerClientSendMessage(client, msg) < 0) {
            virNetServerClientMarkClose(client);
            virNetMessageFree(msg);
            ret = -1;
        }
335 336 337 338 339
        msg = tmp;
    }

    virStreamFree(stream->st);
    VIR_FREE(stream);
340 341

    return ret;
342 343 344 345 346 347 348
}


/*
 * @client: a locked client to add the stream to
 * @stream: a stream to add
 */
349 350 351
int daemonAddClientStream(virNetServerClientPtr client,
                          daemonClientStream *stream,
                          bool transmit)
352
{
353 354 355
    VIR_DEBUG("client=%p, proc=%d, serial=%d, st=%p, transmit=%d",
              client, stream->procedure, stream->serial, stream->st, transmit);
    daemonClientPrivatePtr priv = virNetServerClientGetPrivateData(client);
356

357 358 359 360
    if (stream->filterID != -1) {
        VIR_WARN("Filter already added to client %p", client);
        return -1;
    }
361

362
    if (virStreamEventAddCallback(stream->st, 0,
363
                                  daemonStreamEvent, client, NULL) < 0)
364 365
        return -1;

366 367 368 369 370
    if ((stream->filterID = virNetServerClientAddFilter(client,
                                                        daemonStreamFilter,
                                                        stream)) < 0) {
        virStreamEventRemoveCallback(stream->st);
        return -1;
371 372
    }

373 374
    if (transmit)
        stream->tx = 1;
375

376 377 378
    virMutexLock(&priv->lock);
    stream->next = priv->streams;
    priv->streams = stream;
379

380
    daemonStreamUpdateEvents(stream);
381

382
    virMutexUnlock(&priv->lock);
383

384
    return 0;
385 386 387 388 389 390 391 392 393 394 395 396
}


/*
 * @client: a locked client object
 * @stream: an inactive, closed stream object
 *
 * Removes a stream from the list of active streams for the client
 *
 * Returns 0 if the stream was removd, -1 if it doesn't exist
 */
int
397 398
daemonRemoveClientStream(virNetServerClientPtr client,
                         daemonClientStream *stream)
399
{
400 401 402 403 404 405 406 407 408 409
    VIR_DEBUG("client=%p, proc=%d, serial=%d, st=%p",
              client, stream->procedure, stream->serial, stream->st);
    daemonClientPrivatePtr priv = virNetServerClientGetPrivateData(client);
    daemonClientStream *curr = priv->streams;
    daemonClientStream *prev = NULL;

    if (stream->filterID != -1) {
        virNetServerClientRemoveFilter(client,
                                       stream->filterID);
        stream->filterID = -1;
410 411
    }

412 413
    if (!stream->closed) {
        virStreamEventRemoveCallback(stream->st);
414
        virStreamAbort(stream->st);
415
    }
416 417 418 419 420 421

    while (curr) {
        if (curr == stream) {
            if (prev)
                prev->next = curr->next;
            else
422 423
                priv->streams = curr->next;
            return daemonFreeClientStream(client, stream);
424 425 426 427 428 429
        }
        prev = curr;
        curr = curr->next;
    }
    return -1;
}
430 431 432 433 434 435 436 437 438


/*
 * Returns:
 *   -1  if fatal error occurred
 *    0  if message was fully processed
 *    1  if message is still being processed
 */
static int
439 440 441
daemonStreamHandleWriteData(virNetServerClientPtr client,
                            daemonClientStream *stream,
                            virNetMessagePtr msg)
442 443 444
{
    int ret;

445 446 447
    VIR_DEBUG("client=%p, stream=%p, proc=%d, serial=%d, len=%zu, offset=%zu",
              client, stream, msg->header.proc, msg->header.serial,
              msg->bufferLength, msg->bufferOffset);
448 449 450 451 452 453 454 455 456 457 458

    ret = virStreamSend(stream->st,
                        msg->buffer + msg->bufferOffset,
                        msg->bufferLength - msg->bufferOffset);

    if (ret > 0) {
        msg->bufferOffset += ret;

        /* Partial write, so indicate we have more todo later */
        if (msg->bufferOffset < msg->bufferLength)
            return 1;
459 460 461 462

        /* A dummy 'send' just to free up 'msg' object */
        memset(msg, 0, sizeof(*msg));
        return virNetServerClientSendMessage(client, msg);
463 464 465 466
    } else if (ret == -2) {
        /* Blocking, so indicate we have more todo later */
        return 1;
    } else {
467 468 469 470
        virNetMessageError rerr;

        memset(&rerr, 0, sizeof(rerr));

471
        VIR_INFO("Stream send failed");
472
        stream->closed = 1;
473 474 475 476 477
        return virNetServerProgramSendReplyError(stream->prog,
                                                 client,
                                                 msg,
                                                 &rerr,
                                                 &msg->header);
478 479 480 481 482 483 484 485 486
    }

    return 0;
}


/*
 * Process an finish handshake from the client.
 *
487
 * Returns a VIR_NET_OK confirmation if successful, or a VIR_NET_ERROR
488 489 490 491 492
 * if there was a stream error
 *
 * Returns 0 if successfully sent RPC reply, -1 upon fatal error
 */
static int
493 494 495
daemonStreamHandleFinish(virNetServerClientPtr client,
                         daemonClientStream *stream,
                         virNetMessagePtr msg)
496 497 498
{
    int ret;

499 500
    VIR_DEBUG("client=%p, stream=%p, proc=%d, serial=%d",
              client, stream, msg->header.proc, msg->header.serial);
501 502

    stream->closed = 1;
503
    virStreamEventRemoveCallback(stream->st);
504 505 506
    ret = virStreamFinish(stream->st);

    if (ret < 0) {
507 508 509 510 511 512 513
        virNetMessageError rerr;
        memset(&rerr, 0, sizeof(rerr));
        return virNetServerProgramSendReplyError(stream->prog,
                                                 client,
                                                 msg,
                                                 &rerr,
                                                 &msg->header);
514 515
    } else {
        /* Send zero-length confirm */
516 517 518 519 520 521
        return virNetServerProgramSendStreamData(stream->prog,
                                                 client,
                                                 msg,
                                                 stream->procedure,
                                                 stream->serial,
                                                 NULL, 0);
522 523 524 525 526 527 528 529 530 531
    }
}


/*
 * Process an abort request from the client.
 *
 * Returns 0 if successfully aborted, -1 upon error
 */
static int
532 533 534
daemonStreamHandleAbort(virNetServerClientPtr client,
                        daemonClientStream *stream,
                        virNetMessagePtr msg)
535
{
536 537 538
    VIR_DEBUG("client=%p, stream=%p, proc=%d, serial=%d",
              client, stream, msg->header.proc, msg->header.serial);
    virNetMessageError rerr;
539

540
    memset(&rerr, 0, sizeof(rerr));
541 542

    stream->closed = 1;
543
    virStreamEventRemoveCallback(stream->st);
544 545
    virStreamAbort(stream->st);

546 547 548
    if (msg->header.status == VIR_NET_ERROR)
        virNetError(VIR_ERR_RPC,
                    "%s", _("stream aborted at client request"));
549
    else {
550 551 552 553
        VIR_WARN("unexpected stream status %d", msg->header.status);
        virNetError(VIR_ERR_RPC,
                    _("stream aborted with unexpected status %d"),
                    msg->header.status);
554 555
    }

556 557 558 559 560
    return virNetServerProgramSendReplyError(remoteProgram,
                                             client,
                                             msg,
                                             &rerr,
                                             &msg->header);
561 562 563 564 565 566 567 568 569 570 571 572
}



/*
 * Called when the stream is signalled has being able to accept
 * data writes. Will process all pending incoming messages
 * until they're all gone, or I/O blocks
 *
 * Returns 0 on success, or -1 upon fatal error
 */
static int
573 574
daemonStreamHandleWrite(virNetServerClientPtr client,
                        daemonClientStream *stream)
575
{
576
    VIR_DEBUG("client=%p, stream=%p", client, stream);
577

578 579
    while (stream->rx && !stream->closed) {
        virNetMessagePtr msg = stream->rx;
580
        int ret;
581 582 583 584

        switch (msg->header.status) {
        case VIR_NET_OK:
            ret = daemonStreamHandleFinish(client, stream, msg);
585 586
            break;

587 588
        case VIR_NET_CONTINUE:
            ret = daemonStreamHandleWriteData(client, stream, msg);
589 590
            break;

591
        case VIR_NET_ERROR:
592
        default:
593
            ret = daemonStreamHandleAbort(client, stream, msg);
594 595 596
            break;
        }

597 598
        if (ret > 0)
            break;  /* still processing data from msg */
599

600 601 602 603 604 605
        virNetMessageQueueServe(&stream->rx);
        if (ret < 0) {
            virNetMessageFree(msg);
            virNetServerClientMarkClose(client);
            return -1;
        }
606 607 608 609
    }

    return 0;
}
610 611 612 613 614 615 616 617 618 619 620 621 622 623



/*
 * Invoked when a stream is signalled as having data
 * available to read. This reads upto one message
 * worth of data, and then queues that for transmission
 * to the client.
 *
 * Returns 0 if data was queued for TX, or a error RPC
 * was sent, or -1 on fatal error, indicating client should
 * be killed
 */
static int
624 625
daemonStreamHandleRead(virNetServerClientPtr client,
                       daemonClientStream *stream)
626 627
{
    char *buffer;
628
    size_t bufferLen = VIR_NET_MESSAGE_PAYLOAD_MAX;
629 630
    int ret;

631
    VIR_DEBUG("client=%p, stream=%p", client, stream);
632 633 634 635 636 637 638 639 640 641 642 643 644 645 646

    /* Shouldn't ever be called unless we're marked able to
     * transmit, but doesn't hurt to check */
    if (!stream->tx)
        return 0;

    if (VIR_ALLOC_N(buffer, bufferLen) < 0)
        return -1;

    ret = virStreamRecv(stream->st, buffer, bufferLen);
    if (ret == -2) {
        /* Should never get this, since we're only called when we know
         * we're readable, but hey things change... */
        ret = 0;
    } else if (ret < 0) {
647 648 649 650
        virNetMessagePtr msg;
        virNetMessageError rerr;

        memset(&rerr, 0, sizeof(rerr));
651

652 653 654 655 656 657 658 659 660
        if (!(msg = virNetMessageNew()))
            ret = -1;
        else
            ret = virNetServerProgramSendStreamError(remoteProgram,
                                                     client,
                                                     msg,
                                                     &rerr,
                                                     stream->procedure,
                                                     stream->serial);
661
    } else {
662
        virNetMessagePtr msg;
663 664 665
        stream->tx = 0;
        if (ret == 0)
            stream->recvEOF = 1;
666 667 668 669 670 671 672 673 674 675 676 677 678 679
        if (!(msg = virNetMessageNew()))
            ret = -1;

        if (msg) {
            msg->cb = daemonStreamMessageFinished;
            msg->opaque = stream;
            virNetServerClientRef(client);
            ret = virNetServerProgramSendStreamData(remoteProgram,
                                                    client,
                                                    msg,
                                                    stream->procedure,
                                                    stream->serial,
                                                    buffer, ret);
        }
680 681 682 683 684
    }

    VIR_FREE(buffer);
    return ret;
}