/*
 * virnetclientstream.c: generic network RPC client stream
 *
 * Copyright (C) 2006-2011 Red Hat, Inc.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 */

#include <config.h>

#include "virnetclientstream.h"
#include "virnetclient.h"
#include "viralloc.h"
#include "virerror.h"
#include "virlog.h"
#include "virthread.h"

/* Error domain picked up by the virReportError() calls in this file
 * (libvirt convention). */
#define VIR_FROM_THIS VIR_FROM_RPC

/* Log category for the VIR_DEBUG() messages emitted by this file. */
VIR_LOG_INIT("rpc.netclientstream");

struct _virNetClientStream {
35
    virObjectLockable parent;
36

37 38
    virStreamPtr stream; /* Reverse pointer to parent stream */

39 40 41 42 43 44 45 46 47 48 49 50 51
    virNetClientProgramPtr prog;
    int proc;
    unsigned serial;

    virError err;

    /* XXX this buffer is unbounded if the client
     * app has domain events registered, since packets
     * may be read off wire, while app isn't ready to
     * recv them. Figure out how to address this some
     * time by stopping consuming any incoming data
     * off the socket....
     */
52
    virNetMessagePtr rx;
53
    bool incomingEOF;
54

55
    bool allowSkip;
56
    long long holeLength;  /* Size of incoming hole in stream. */
57

58 59 60 61 62 63 64 65 66
    virNetClientStreamEventCallback cb;
    void *cbOpaque;
    virFreeCallback cbFree;
    int cbEvents;
    int cbTimer;
    int cbDispatch;
};


67 68 69 70 71
static virClassPtr virNetClientStreamClass;
static void virNetClientStreamDispose(void *obj);

static int virNetClientStreamOnceInit(void)
{
72
    if (!VIR_CLASS_NEW(virNetClientStream, virClassForObjectLockable()))
73 74 75 76 77
        return -1;

    return 0;
}

78
VIR_ONCE_GLOBAL_INIT(virNetClientStream);
79 80


81 82 83 84 85 86
/*
 * Arm the stream's event timer (0 ms timeout) when one of the registered
 * events can be delivered right now, otherwise disarm it (-1).  Does
 * nothing while no callback is registered.  Every caller in this file
 * invokes it with the stream locked.
 */
static void
virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
{
    bool readable;
    bool writable;

    if (st->cb == NULL)
        return;

    VIR_DEBUG("Check timer rx=%p cbEvents=%d", st->rx, st->cbEvents);

    /* Readable means queued data or a reached end of stream. */
    readable = (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
               (st->rx != NULL || st->incomingEOF);
    writable = (st->cbEvents & VIR_STREAM_EVENT_WRITABLE) != 0;

    if (!readable && !writable) {
        VIR_DEBUG("Disabling event timer");
        virEventUpdateTimeout(st->cbTimer, -1);
        return;
    }

    VIR_DEBUG("Enabling event timer");
    virEventUpdateTimeout(st->cbTimer, 0);
}


/*
 * Timer handler registered by virNetClientStreamEventAddCallback().
 * Computes which of the registered events are currently deliverable and,
 * if any, invokes the user callback with the stream lock released.
 */
static void
virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
{
    virNetClientStreamPtr st = opaque;
    int ready = 0;

    virObjectLock(st);

    if (st->cb) {
        if ((st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
            (st->rx || st->incomingEOF))
            ready |= VIR_STREAM_EVENT_READABLE;
        if (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)
            ready |= VIR_STREAM_EVENT_WRITABLE;
    }

    VIR_DEBUG("Got Timer dispatch events=%d cbEvents=%d rx=%p", ready, st->cbEvents, st->rx);

    if (ready != 0) {
        /* Snapshot the callback: it may unregister itself (clearing
         * st->cb*) while the lock is released below. */
        virNetClientStreamEventCallback func = st->cb;
        void *funcOpaque = st->cbOpaque;
        virFreeCallback funcFree = st->cbFree;

        /* While set, virNetClientStreamEventRemoveCallback() leaves the
         * free function for us to run once the callback returns. */
        st->cbDispatch = 1;
        virObjectUnlock(st);
        func(st, ready, funcOpaque);
        virObjectLock(st);
        st->cbDispatch = 0;

        if (st->cb == NULL && funcFree)
            funcFree(funcOpaque);
    }

    virObjectUnlock(st);
}


/*
 * Create a client stream object bound to procedure @proc / call @serial
 * of program @prog.  References are taken on @stream and @prog (and
 * dropped again in virNetClientStreamDispose()).  @allowSkip tells
 * whether the server may send VIR_NET_STREAM_HOLE packets.
 *
 * Returns the new stream, or NULL on failure.
 */
virNetClientStreamPtr virNetClientStreamNew(virStreamPtr stream,
                                            virNetClientProgramPtr prog,
                                            int proc,
                                            unsigned serial,
                                            bool allowSkip)
{
    virNetClientStreamPtr st = NULL;

    if (virNetClientStreamInitialize() < 0 ||
        !(st = virObjectLockableNew(virNetClientStreamClass)))
        return NULL;

    st->proc = proc;
    st->serial = serial;
    st->allowSkip = allowSkip;
    st->stream = virObjectRef(stream);
    st->prog = virObjectRef(prog);

    return st;
}

void virNetClientStreamDispose(void *obj)
160
{
161
    virNetClientStreamPtr st = obj;
162 163

    virResetError(&st->err);
164 165 166 167 168
    while (st->rx) {
        virNetMessagePtr msg = st->rx;
        virNetMessageQueueServe(&st->rx);
        virNetMessageFree(msg);
    }
169
    virObjectUnref(st->prog);
170
    virObjectUnref(st->stream);
171 172 173 174 175
}

/*
 * Tell whether @msg is addressed to this stream: same program, same
 * procedure number and same serial.
 */
bool virNetClientStreamMatches(virNetClientStreamPtr st,
                               virNetMessagePtr msg)
{
    bool match;

    virObjectLock(st);
    match = virNetClientProgramMatches(st->prog, msg) &&
            st->proc == msg->header.proc &&
            st->serial == msg->header.serial;
    virObjectUnlock(st);

    return match;
}


/*
 * Report the error stored on the stream (see virNetClientStreamSetError())
 * through virRaiseErrorFull().
 *
 * Returns true if an error was stored and raised, false otherwise.
 */
bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
{
    bool raised = false;

    virObjectLock(st);

    if (st->err.code != VIR_ERR_OK) {
        const char *text = st->err.message ? st->err.message : _("Unknown error");

        virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
                          st->err.domain, st->err.code, st->err.level,
                          st->err.str1, st->err.str2, st->err.str3,
                          st->err.int1, st->err.int2,
                          "%s", text);
        raised = true;
    }

    virObjectUnlock(st);
    return raised;
}


/*
 * Move the string held by a decoded optional XDR string field (@src,
 * may be NULL) into *@dst, leaving *@src NULL so that xdr_free() does
 * not release it.
 */
static void
virNetClientStreamStealErrorString(char **dst,
                                   char **src)
{
    if (!src)
        return;

    *dst = *src;
    *src = NULL;
}


/*
 * Decode the error payload of @msg and store it on the stream, replacing
 * any previously stored error.  The stream is also flagged as having
 * reached EOF, so readers stop waiting for more data.
 *
 * Returns 0 on success, -1 if the payload could not be decoded.
 */
int virNetClientStreamSetError(virNetClientStreamPtr st,
                               virNetMessagePtr msg)
{
    virNetMessageError err;
    int ret = -1;

    virObjectLock(st);

    if (st->err.code != VIR_ERR_OK)
        VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));

    virResetError(&st->err);
    memset(&err, 0, sizeof(err));

    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
        goto cleanup;

    /* A remote "unknown procedure" RPC error is surfaced as
     * VIR_ERR_NO_SUPPORT; anything else keeps its original code. */
    if (err.domain == VIR_FROM_REMOTE &&
        err.code == VIR_ERR_RPC &&
        err.level == VIR_ERR_ERROR &&
        err.message &&
        STRPREFIX(*err.message, "unknown procedure"))
        st->err.code = VIR_ERR_NO_SUPPORT;
    else
        st->err.code = err.code;

    st->err.domain = err.domain;
    st->err.level = err.level;
    st->err.int1 = err.int1;
    st->err.int2 = err.int2;

    virNetClientStreamStealErrorString(&st->err.message, err.message);
    virNetClientStreamStealErrorString(&st->err.str1, err.str1);
    virNetClientStreamStealErrorString(&st->err.str2, err.str2);
    virNetClientStreamStealErrorString(&st->err.str3, err.str3);

    st->incomingEOF = true;
    virNetClientStreamEventTimerUpdate(st);

    ret = 0;

 cleanup:
    xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
    virObjectUnlock(st);
    return ret;
}


/*
 * Queue a stream packet received from the server so that
 * virNetClientStreamRecvPacket() can consume it later.  A packet with an
 * empty payload is not queued; it marks the end of the stream instead.
 *
 * The payload buffer is moved out of @msg (its buffer fields are reset);
 * @msg itself is left to the caller.
 *
 * Returns 0 on success, -1 if a message could not be allocated.
 */
int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
                                  virNetMessagePtr msg)
{
    virNetMessagePtr copy;

    VIR_DEBUG("Incoming stream message: stream=%p message=%p", st, msg);

    if (msg->bufferOffset == msg->bufferLength) {
        /* No payload means end of the stream. */
        virObjectLock(st);
        st->incomingEOF = true;
        virNetClientStreamEventTimerUpdate(st);
        virObjectUnlock(st);
        return 0;
    }

    /* @msg is going to be cleared later by the caller, so the payload has
     * to be transferred into a message of our own. */
    copy = virNetMessageNew(false);
    if (!copy)
        return -1;

    copy->header = msg->header;

    copy->buffer = msg->buffer;
    copy->bufferLength = msg->bufferLength;
    copy->bufferOffset = msg->bufferOffset;
    msg->buffer = NULL;
    msg->bufferOffset = 0;
    msg->bufferLength = 0;

    virObjectLock(st);
    /* Data and hole packets share one queue so they are consumed strictly
     * in arrival order. */
    virNetMessageQueuePush(&st->rx, copy);
    virNetClientStreamEventTimerUpdate(st);
    virObjectUnlock(st);

    return 0;
}


/*
 * Send a VIR_NET_STREAM packet with @status to the server.  For
 * VIR_NET_CONTINUE the payload is @data/@nbytes; for any other status an
 * empty payload is sent.  Data packets are async fire&forget, but
 * OK/ERROR packets need a synchronous confirmation.
 *
 * Returns @nbytes on success, -1 on failure.
 */
int virNetClientStreamSendPacket(virNetClientStreamPtr st,
                                 virNetClientPtr client,
                                 int status,
                                 const char *data,
                                 size_t nbytes)
{
    virNetMessagePtr msg;
    const char *payload = NULL;
    size_t payloadLen = 0;
    int ret = -1;

    VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes);

    if (!(msg = virNetMessageNew(false)))
        return -1;

    /* Only filling in the header needs the stream lock. */
    virObjectLock(st);
    msg->header.prog = virNetClientProgramGetProgram(st->prog);
    msg->header.vers = virNetClientProgramGetVersion(st->prog);
    msg->header.status = status;
    msg->header.type = VIR_NET_STREAM;
    msg->header.serial = st->serial;
    msg->header.proc = st->proc;
    virObjectUnlock(st);

    if (status == VIR_NET_CONTINUE) {
        payload = data;
        payloadLen = nbytes;
    }

    if (virNetMessageEncodeHeader(msg) < 0 ||
        virNetMessageEncodePayloadRaw(msg, payload, payloadLen) < 0 ||
        virNetClientSendStream(client, msg, st) < 0)
        goto cleanup;

    ret = nbytes;

 cleanup:
    virNetMessageFree(msg);
    return ret;
}

/*
 * Record a hole of @length bytes announced by the server.  No @flags
 * are supported yet.  The only caller, virNetClientStreamHandleHole(),
 * runs with the stream locked.
 *
 * Returns 0 on success, -1 if @flags is non-zero, @length is not
 * positive, or a previously announced hole has not been consumed yet.
 */
static int
virNetClientStreamSetHole(virNetClientStreamPtr st,
                          long long length,
                          unsigned int flags)
{
    virCheckFlags(0, -1);
    virCheckPositiveArgReturn(length, -1);

    /* Shouldn't happen, but it's better to be safe than sorry: never
     * merge a new hole into one the reader has not consumed yet. */
    if (st->holeLength != 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       _("unprocessed hole of size %lld already in the queue"),
                       st->holeLength);
        return -1;
    }

    st->holeLength = length;
    return 0;
}


/**
 * virNetClientStreamHandleHole:
 * @client: client
 * @st: stream
 *
 * Called whenever current message processed in the stream is
 * VIR_NET_STREAM_HOLE. The stream @st is expected to be locked
 * already.
 *
 * Returns: 0 on success,
 *          -1 otherwise.
 */
397
static int
398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436
virNetClientStreamHandleHole(virNetClientPtr client,
                             virNetClientStreamPtr st)
{
    virNetMessagePtr msg;
    virNetStreamHole data;
    int ret = -1;

    VIR_DEBUG("client=%p st=%p", client, st);

    msg = st->rx;
    memset(&data, 0, sizeof(data));

    /* We should not be called unless there's VIR_NET_STREAM_HOLE
     * message at the head of the list. But doesn't hurt to check */
    if (!msg) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("No message in the queue"));
        goto cleanup;
    }

    if (msg->header.type != VIR_NET_STREAM_HOLE) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       _("Invalid message prog=%d type=%d serial=%u proc=%d"),
                       msg->header.prog,
                       msg->header.type,
                       msg->header.serial,
                       msg->header.proc);
        goto cleanup;
    }

    /* Server should not send us VIR_NET_STREAM_HOLE unless we
     * have requested so. But does not hurt to check ... */
    if (!st->allowSkip) {
        virReportError(VIR_ERR_RPC, "%s",
                       _("Unexpected stream hole"));
        goto cleanup;
    }

    if (virNetMessageDecodePayload(msg,
437
                                   (xdrproc_t)xdr_virNetStreamHole,
438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458
                                   &data) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Malformed stream hole packet"));
        goto cleanup;
    }

    virNetMessageQueueServe(&st->rx);
    virNetMessageFree(msg);

    if (virNetClientStreamSetHole(st, data.length, data.flags) < 0)
        goto cleanup;

    ret = 0;
 cleanup:
    if (ret < 0) {
        /* Abort stream? */
    }
    return ret;
}


459 460 461 462
/*
 * Read up to @nbytes bytes of stream data into @data.
 *
 * Bytes of a pending hole (announced by a VIR_NET_STREAM_HOLE packet) are
 * returned as zeroes, unless VIR_STREAM_RECV_STOP_AT_HOLE is set in
 * @flags, in which case -3 is returned and the hole is left pending.
 * When nothing is available locally, -2 is returned if @nonblock is set;
 * otherwise a dummy VIR_NET_CONTINUE packet is sent to wait for the
 * server.
 *
 * Returns the number of bytes stored in @data (0 once EOF is reached and
 * everything was consumed), -2 (non-blocking, nothing available), -3
 * (stopped at a hole) or -1 on error.  The stream lock is taken here.
 */
int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
                                 virNetClientPtr client,
                                 char *data,
                                 size_t nbytes,
                                 bool nonblock,
                                 unsigned int flags)
{
    int rv = -1;
    size_t want;

    VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d flags=0x%x",
              st, client, data, nbytes, nonblock, flags);

    virCheckFlags(VIR_STREAM_RECV_STOP_AT_HOLE, -1);

    virObjectLock(st);

 reread:
    /* Only go to the wire when nothing is available locally.  A pending
     * hole counts as available data (it reads as zeroes), exactly as in
     * the re-check after hole handling below; without the holeLength
     * test a reader that consumed only part of a hole would wait for the
     * server (or get -2) although more zeroes were ready. */
    if (!st->rx && !st->incomingEOF && st->holeLength == 0) {
        virNetMessagePtr msg;
        int ret;

        if (nonblock) {
            VIR_DEBUG("Non-blocking mode and no data available");
            rv = -2;
            goto cleanup;
        }

        if (!(msg = virNetMessageNew(false)))
            goto cleanup;

        msg->header.prog = virNetClientProgramGetProgram(st->prog);
        msg->header.vers = virNetClientProgramGetVersion(st->prog);
        msg->header.type = VIR_NET_STREAM;
        msg->header.serial = st->serial;
        msg->header.proc = st->proc;
        msg->header.status = VIR_NET_CONTINUE;

        VIR_DEBUG("Dummy packet to wait for stream data");
        /* The lock is dropped while waiting, presumably so the client I/O
         * code can call back into functions such as
         * virNetClientStreamQueuePacket(), which lock the stream. */
        virObjectUnlock(st);
        ret = virNetClientSendStream(client, msg, st);
        virObjectLock(st);
        virNetMessageFree(msg);

        if (ret < 0)
            goto cleanup;
    }

    VIR_DEBUG("After IO rx=%p", st->rx);

    if (st->rx &&
        st->rx->header.type == VIR_NET_STREAM_HOLE &&
        st->holeLength == 0) {
        /* Handle skip sent to us by server. */

        if (virNetClientStreamHandleHole(client, st) < 0)
            goto cleanup;
    }

    if (!st->rx && !st->incomingEOF && st->holeLength == 0) {
        if (nonblock) {
            VIR_DEBUG("Non-blocking mode and no data available");
            rv = -2;
            goto cleanup;
        }

        /* We have consumed all packets from incoming queue but those
         * were only skip packets, no data. Read the stream again. */
        goto reread;
    }

    want = nbytes;

    if (st->holeLength) {
        /* Pretend holeLength zeroes was read from stream. */
        size_t len = want;

        /* Yes, pretend unless we are asked not to. */
        if (flags & VIR_STREAM_RECV_STOP_AT_HOLE) {
            /* No error reporting here. Caller knows what they are doing. */
            rv = -3;
            goto cleanup;
        }

        if (len > st->holeLength)
            len = st->holeLength;

        memset(data, 0, len);
        st->holeLength -= len;
        want -= len;
    }

    /* Copy queued data packets until @data is full or a non-data packet
     * (e.g. the next hole) reaches the head of the queue. */
    while (want &&
           st->rx &&
           st->rx->header.type == VIR_NET_STREAM) {
        virNetMessagePtr msg = st->rx;
        size_t len = want;

        if (len > msg->bufferLength - msg->bufferOffset)
            len = msg->bufferLength - msg->bufferOffset;

        if (!len)
            break;

        memcpy(data + (nbytes - want), msg->buffer + msg->bufferOffset, len);
        want -= len;
        msg->bufferOffset += len;

        if (msg->bufferOffset == msg->bufferLength) {
            virNetMessageQueueServe(&st->rx);
            virNetMessageFree(msg);
        }
    }
    rv = nbytes - want;

    virNetClientStreamEventTimerUpdate(st);

 cleanup:
    virObjectUnlock(st);
    return rv;
}


/*
 * Tell the server that the next @length bytes of the stream are a hole
 * by sending a VIR_NET_STREAM_HOLE packet.  Only valid for streams
 * created with allowSkip.  @flags is passed to the server unchanged.
 *
 * Returns 0 on success, -1 on failure.
 */
int
virNetClientStreamSendHole(virNetClientStreamPtr st,
                           virNetClientPtr client,
                           long long length,
                           unsigned int flags)
{
    virNetMessagePtr msg = NULL;
    virNetStreamHole data;
    int ret = -1;

    /* @length is a signed long long, so it needs %lld: %llu mismatches
     * the argument type and shows a negative length as a huge number. */
    VIR_DEBUG("st=%p length=%lld", st, length);

    if (!st->allowSkip) {
        virReportError(VIR_ERR_OPERATION_INVALID, "%s",
                       _("Skipping is not supported with this stream"));
        return -1;
    }

    memset(&data, 0, sizeof(data));
    data.length = length;
    data.flags = flags;

    if (!(msg = virNetMessageNew(false)))
        return -1;

    /* Only filling in the header needs the stream lock; the send below
     * runs without it. */
    virObjectLock(st);

    msg->header.prog = virNetClientProgramGetProgram(st->prog);
    msg->header.vers = virNetClientProgramGetVersion(st->prog);
    msg->header.status = VIR_NET_CONTINUE;
    msg->header.type = VIR_NET_STREAM_HOLE;
    msg->header.serial = st->serial;
    msg->header.proc = st->proc;

    virObjectUnlock(st);

    if (virNetMessageEncodeHeader(msg) < 0)
        goto cleanup;

    if (virNetMessageEncodePayload(msg,
                                   (xdrproc_t)xdr_virNetStreamHole,
                                   &data) < 0)
        goto cleanup;

    if (virNetClientSendStream(client, msg, st) < 0)
        goto cleanup;

    ret = 0;
 cleanup:
    virNetMessageFree(msg);
    return ret;
}


/*
 * Hand the currently pending hole length to the caller and clear it;
 * *@length is 0 when no hole is pending.
 *
 * Returns 0 on success, -1 if the stream does not support holes.
 */
int
virNetClientStreamRecvHole(virNetClientPtr client ATTRIBUTE_UNUSED,
                           virNetClientStreamPtr st,
                           long long *length)
{
    if (!st->allowSkip) {
        virReportError(VIR_ERR_OPERATION_INVALID, "%s",
                       _("Holes are not supported with this stream"));
        return -1;
    }

    /* holeLength is set and decremented by virNetClientStreamRecvPacket()
     * with the stream locked; do this read-and-clear under the same lock
     * so it cannot race with a concurrent read. */
    virObjectLock(st);
    *length = st->holeLength;
    st->holeLength = 0;
    virObjectUnlock(st);

    return 0;
}


/*
 * Register @cb to be invoked from an event-loop timer whenever one of
 * @events can be delivered on the stream.  Only one callback may be
 * registered at a time.  @ff, if not NULL, is called with @opaque once
 * the callback is removed.
 *
 * Returns 0 on success, -1 on failure.
 */
int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
                                       int events,
                                       virNetClientStreamEventCallback cb,
                                       void *opaque,
                                       virFreeCallback ff)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("multiple stream callbacks not supported"));
        goto cleanup;
    }

    /* The timer owns a reference on the stream; virObjectFreeCallback
     * releases it when the timer goes away. */
    virObjectRef(st);
    st->cbTimer = virEventAddTimeout(-1,
                                     virNetClientStreamEventTimer,
                                     st,
                                     virObjectFreeCallback);
    if (st->cbTimer < 0) {
        virObjectUnref(st);
        goto cleanup;
    }

    st->cb = cb;
    st->cbOpaque = opaque;
    st->cbFree = ff;
    st->cbEvents = events;

    virNetClientStreamEventTimerUpdate(st);

    ret = 0;

 cleanup:
    virObjectUnlock(st);
    return ret;
}

/*
 * Replace the event mask of the registered callback and re-evaluate the
 * event timer.
 *
 * Returns 0 on success, -1 if no callback is registered.
 */
int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
                                          int events)
{
    int ret = 0;

    virObjectLock(st);

    if (st->cb) {
        st->cbEvents = events;
        virNetClientStreamEventTimerUpdate(st);
    } else {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("no stream callback registered"));
        ret = -1;
    }

    virObjectUnlock(st);
    return ret;
}

/*
 * Unregister the stream's event callback and remove its timer.  The
 * callback's free function runs here, unless the callback is currently
 * being dispatched; in that case the dispatching timer handler runs it
 * after the callback returns.
 *
 * Returns 0 on success, -1 if no callback is registered.
 */
int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb) {
        if (st->cbFree && !st->cbDispatch)
            st->cbFree(st->cbOpaque);

        st->cb = NULL;
        st->cbOpaque = NULL;
        st->cbFree = NULL;
        st->cbEvents = 0;
        virEventRemoveTimeout(st->cbTimer);

        ret = 0;
    } else {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("no stream callback registered"));
    }

    virObjectUnlock(st);
    return ret;
}

/*
 * Report whether the incoming side of the stream has ended, i.e. an
 * empty data packet or a stream error was received.
 */
bool virNetClientStreamEOF(virNetClientStreamPtr st)
{
    bool eof;

    /* incomingEOF is written with the stream locked (see
     * virNetClientStreamQueuePacket() and virNetClientStreamSetError()),
     * so read it under the same lock. */
    virObjectLock(st);
    eof = st->incomingEOF;
    virObjectUnlock(st);

    return eof;
}