virnetclientstream.c 14.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/*
 * virnetclientstream.c: generic network RPC client stream
 *
 * Copyright (C) 2006-2011 Red Hat, Inc.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */

#include <config.h>

#include "virnetclientstream.h"
#include "virnetclient.h"
27
#include "viralloc.h"
28
#include "virerror.h"
29
#include "virlog.h"
30
#include "virthread.h"
31 32 33

#define VIR_FROM_THIS VIR_FROM_RPC

34 35
VIR_LOG_INIT("rpc.netclientstream");

36
struct _virNetClientStream {
37
    virObjectLockable parent;
38

39 40 41 42 43 44 45 46 47 48 49 50 51
    virNetClientProgramPtr prog;
    int proc;
    unsigned serial;

    virError err;

    /* XXX this buffer is unbounded if the client
     * app has domain events registered, since packets
     * may be read off wire, while app isn't ready to
     * recv them. Figure out how to address this some
     * time by stopping consuming any incoming data
     * off the socket....
     */
52 53 54
    struct iovec *incomingVec; /* I/O Vector to hold data */
    size_t writeVec;           /* Vectors produced */
    size_t readVec;            /* Vectors consumed */
55
    bool incomingEOF;
56 57 58 59 60 61 62 63 64 65

    virNetClientStreamEventCallback cb;
    void *cbOpaque;
    virFreeCallback cbFree;
    int cbEvents;
    int cbTimer;
    int cbDispatch;
};


66 67 68 69 70
static virClassPtr virNetClientStreamClass;
static void virNetClientStreamDispose(void *obj);

static int virNetClientStreamOnceInit(void)
{
71
    if (!(virNetClientStreamClass = virClassNew(virClassForObjectLockable(),
72
                                                "virNetClientStream",
73 74 75 76 77 78 79 80 81 82
                                                sizeof(virNetClientStream),
                                                virNetClientStreamDispose)))
        return -1;

    return 0;
}

VIR_ONCE_GLOBAL_INIT(virNetClientStream)


83 84 85 86 87 88
static void
virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
{
    if (!st->cb)
        return;

89
    VIR_DEBUG("Check timer readVec %zu writeVec %zu %d", st->readVec, st->writeVec, st->cbEvents);
90

91
    if ((((st->readVec < st->writeVec) || st->incomingEOF) &&
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
         (st->cbEvents & VIR_STREAM_EVENT_READABLE)) ||
        (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) {
        VIR_DEBUG("Enabling event timer");
        virEventUpdateTimeout(st->cbTimer, 0);
    } else {
        VIR_DEBUG("Disabling event timer");
        virEventUpdateTimeout(st->cbTimer, -1);
    }
}


/*
 * Timeout handler registered by virNetClientStreamEventAddCallback().
 * Works out which of the requested events are currently satisfiable and,
 * if any, invokes the application callback with the stream lock dropped.
 */
static void
virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
{
    virNetClientStreamPtr st = opaque;
    virNetClientStreamEventCallback cb;
    void *cbOpaque;
    virFreeCallback cbFree;
    int events = 0;

    virObjectLock(st);

    if (st->cb) {
        if ((st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
            ((st->readVec < st->writeVec) || st->incomingEOF))
            events |= VIR_STREAM_EVENT_READABLE;
        if (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)
            events |= VIR_STREAM_EVENT_WRITABLE;
    }

    VIR_DEBUG("Got Timer dispatch %d %d readVec %zu writeVec %zu", events, st->cbEvents,
              st->readVec, st->writeVec);

    if (!events) {
        virObjectUnlock(st);
        return;
    }

    /* Snapshot the registration: the callback may remove itself while
     * it runs, clearing the st->cb* fields. */
    cb = st->cb;
    cbOpaque = st->cbOpaque;
    cbFree = st->cbFree;

    st->cbDispatch = 1;
    virObjectUnlock(st);
    (cb)(st, events, cbOpaque);
    virObjectLock(st);
    st->cbDispatch = 0;

    /* virNetClientStreamEventRemoveCallback() skips freeing the opaque
     * data while cbDispatch is set, so if the callback was removed during
     * the dispatch, release it here. */
    if (!st->cb && cbFree)
        (cbFree)(cbOpaque);

    virObjectUnlock(st);
}


/*
 * virNetClientStreamNew:
 * @prog: RPC program the stream belongs to (a reference is taken)
 * @proc: procedure number of the stream
 * @serial: serial of the call that opened the stream
 *
 * Returns a new stream object, or NULL if the class could not be
 * initialized or the object could not be allocated.
 */
virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
                                            int proc,
                                            unsigned serial)
{
    virNetClientStreamPtr st = NULL;

    if (virNetClientStreamInitialize() < 0 ||
        !(st = virObjectLockableNew(virNetClientStreamClass)))
        return NULL;

    /* Released again in virNetClientStreamDispose() */
    virObjectRef(prog);
    st->prog = prog;
    st->serial = serial;
    st->proc = proc;

    return st;
}

160
void virNetClientStreamDispose(void *obj)
161
{
162
    virNetClientStreamPtr st = obj;
163 164

    virResetError(&st->err);
165
    VIR_FREE(st->incomingVec);
166
    virObjectUnref(st->prog);
167 168 169 170 171
}

/*
 * Returns true if @msg belongs to @st, i.e. it carries the stream's
 * program, procedure number and serial.
 */
bool virNetClientStreamMatches(virNetClientStreamPtr st,
                               virNetMessagePtr msg)
{
    bool same;

    virObjectLock(st);
    same = virNetClientProgramMatches(st->prog, msg) &&
           msg->header.proc == st->proc &&
           msg->header.serial == st->serial;
    virObjectUnlock(st);

    return same;
}


/*
 * If the stream holds an error recorded by virNetClientStreamSetError(),
 * raise it as the current thread's libvirt error.
 *
 * Returns true if an error was raised, false if the stream has none.
 */
bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
{
    bool raised = false;

    virObjectLock(st);

    if (st->err.code != VIR_ERR_OK) {
        virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
                          st->err.domain,
                          st->err.code,
                          st->err.level,
                          st->err.str1,
                          st->err.str2,
                          st->err.str3,
                          st->err.int1,
                          st->err.int2,
                          "%s", st->err.message ? st->err.message : _("Unknown error"));
        raised = true;
    }

    virObjectUnlock(st);
    return raised;
}


/*
 * virNetClientStreamSetError:
 * @st: the stream
 * @msg: an error message received for the stream
 *
 * Decodes the virNetMessageError payload of @msg and stores it as the
 * stream's pending error, replacing any earlier one. The stream is then
 * flagged as having reached EOF and the event timer is refreshed.
 *
 * Returns 0 on success, or -1 if the payload could not be decoded. The
 * previously stored error has been discarded in either case.
 */
int virNetClientStreamSetError(virNetClientStreamPtr st,
                               virNetMessagePtr msg)
{
    virNetMessageError err;
    bool unknownProc;
    int ret = -1;

    virObjectLock(st);

    if (st->err.code != VIR_ERR_OK)
        VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));

    virResetError(&st->err);
    memset(&err, 0, sizeof(err));

    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
        goto cleanup;

    /* A remote RPC error whose message starts with "unknown procedure"
     * is reported to the application as VIR_ERR_NO_SUPPORT. */
    unknownProc = err.domain == VIR_FROM_REMOTE &&
                  err.code == VIR_ERR_RPC &&
                  err.level == VIR_ERR_ERROR &&
                  err.message &&
                  STRPREFIX(*err.message, "unknown procedure");

    st->err.code = unknownProc ? VIR_ERR_NO_SUPPORT : err.code;
    st->err.domain = err.domain;
    st->err.level = err.level;
    st->err.int1 = err.int1;
    st->err.int2 = err.int2;

    /* Take ownership of the decoded strings; clearing the source pointers
     * keeps the xdr_free() below from releasing them. */
    if (err.message) {
        st->err.message = *err.message;
        *err.message = NULL;
    }
    if (err.str1) {
        st->err.str1 = *err.str1;
        *err.str1 = NULL;
    }
    if (err.str2) {
        st->err.str2 = *err.str2;
        *err.str2 = NULL;
    }
    if (err.str3) {
        st->err.str3 = *err.str3;
        *err.str3 = NULL;
    }

    st->incomingEOF = true;
    virNetClientStreamEventTimerUpdate(st);

    ret = 0;

 cleanup:
    xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
    virObjectUnlock(st);
    return ret;
}


/*
 * virNetClientStreamQueuePacket:
 * @st: the stream receiving the data
 * @msg: an incoming stream data message
 *
 * Copies the payload of @msg (bytes bufferOffset..bufferLength) into
 * freshly allocated pieces of at most 1 MiB each, appended to the
 * stream's incoming I/O vector. An empty payload marks end-of-stream
 * instead. On success the event timer is refreshed so a registered
 * callback can be notified.
 *
 * Returns 0 on success, -1 if allocating or appending a piece failed.
 * Pieces appended before the failure stay queued.
 */
int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
                                  virNetMessagePtr msg)
{
    size_t chunk = 1024*1024;
    size_t total;
    size_t copied = 0;
    int ret = -1;

    virObjectLock(st);

    total = msg->bufferLength - msg->bufferOffset;

    if (total == 0) {
        st->incomingEOF = true;
    } else {
        size_t count = (total + chunk - 1) / chunk;
        size_t n;

        for (n = 0; n < count; n++) {
            struct iovec iov;
            char *buf;

            /* the final piece may be shorter than a full chunk */
            if (chunk > total - copied)
                chunk = total - copied;

            if (VIR_ALLOC_N(buf, chunk) < 0) {
                VIR_DEBUG("Allocation failed");
                goto cleanup;
            }

            memcpy(buf, msg->buffer + msg->bufferOffset + copied, chunk);
            iov.iov_base = buf;
            iov.iov_len = chunk;
            copied += chunk;

            if (VIR_APPEND_ELEMENT(st->incomingVec, st->writeVec, iov) < 0) {
                VIR_DEBUG("Append failed");
                VIR_FREE(buf);
                goto cleanup;
            }
            VIR_DEBUG("Wrote piece of vector. readVec %zu, writeVec %zu size %zu",
                      st->readVec, st->writeVec, chunk);
        }
    }

    virNetClientStreamEventTimerUpdate(st);
    ret = 0;

 cleanup:
    VIR_DEBUG("Stream incoming data readVec %zu writeVec %zu EOF %d",
              st->readVec, st->writeVec, st->incomingEOF);
    virObjectUnlock(st);
    return ret;
}


/*
 * virNetClientStreamSendPacket:
 * @st: the stream
 * @client: the client connection to send through
 * @status: VIR_NET_CONTINUE for a data packet, otherwise a final status
 * @data: payload bytes (only used for VIR_NET_CONTINUE)
 * @nbytes: length of @data
 *
 * Builds a stream packet carrying the stream's program, procedure and
 * serial, and sends it. Data packets are sent without waiting for a
 * reply; any other status is sent empty and waits for a reply.
 *
 * Returns @nbytes on success, -1 on failure.
 */
int virNetClientStreamSendPacket(virNetClientStreamPtr st,
                                 virNetClientPtr client,
                                 int status,
                                 const char *data,
                                 size_t nbytes)
{
    virNetMessagePtr msg;
    bool isData = status == VIR_NET_CONTINUE;
    int ret = -1;

    VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes);

    if (!(msg = virNetMessageNew(false)))
        return -1;

    virObjectLock(st);
    msg->header.prog = virNetClientProgramGetProgram(st->prog);
    msg->header.vers = virNetClientProgramGetVersion(st->prog);
    msg->header.status = status;
    msg->header.type = VIR_NET_STREAM;
    msg->header.serial = st->serial;
    msg->header.proc = st->proc;
    virObjectUnlock(st);

    if (virNetMessageEncodeHeader(msg) < 0)
        goto cleanup;

    /* Data packets are async fire&forget, but OK/ERROR packets
     * need a synchronous confirmation
     */
    if (isData) {
        if (virNetMessageEncodePayloadRaw(msg, data, nbytes) < 0 ||
            virNetClientSendNoReply(client, msg) < 0)
            goto cleanup;
    } else {
        if (virNetMessageEncodePayloadRaw(msg, NULL, 0) < 0 ||
            virNetClientSendWithReply(client, msg) < 0)
            goto cleanup;
    }

    ret = nbytes;

 cleanup:
    virNetMessageFree(msg);
    return ret;
}

/*
 * virNetClientStreamRecvPacket:
 * @st: the stream to read from
 * @client: client connection used to wait for more data
 * @data: caller buffer receiving the bytes
 * @nbytes: capacity of @data
 * @nonblock: if true, never wait for data
 *
 * Copies up to @nbytes of queued incoming data into @data. If nothing is
 * queued and EOF has not been seen, this either returns -2 (when
 * @nonblock is set) or sends a dummy VIR_NET_CONTINUE packet through
 * @client and waits for its reply, so that incoming data can be queued.
 *
 * Returns the number of bytes copied (0 when nothing was queued, e.g. at
 * EOF), -2 in non-blocking mode with no data queued, or -1 on error.
 */
int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
                                 virNetClientPtr client,
                                 char *data,
                                 size_t nbytes,
                                 bool nonblock)
{
    int ret = -1;
    size_t partial, offset;

    virObjectLock(st);

    VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
              st, client, data, nbytes, nonblock);

    if ((st->readVec >= st->writeVec) && !st->incomingEOF) {
        virNetMessagePtr msg;
        int rv;

        if (nonblock) {
            VIR_DEBUG("Non-blocking mode and no data available");
            ret = -2;
            goto cleanup;
        }

        if (!(msg = virNetMessageNew(false)))
            goto cleanup;

        msg->header.prog = virNetClientProgramGetProgram(st->prog);
        msg->header.vers = virNetClientProgramGetVersion(st->prog);
        msg->header.type = VIR_NET_STREAM;
        msg->header.serial = st->serial;
        msg->header.proc = st->proc;
        msg->header.status = VIR_NET_CONTINUE;

        VIR_DEBUG("Dummy packet to wait for stream data");
        /* Don't hold the stream lock while blocked in the client:
         * virNetClientStreamQueuePacket()/SetError() take it to deliver
         * incoming data, presumably while we are waiting here. */
        virObjectUnlock(st);
        rv = virNetClientSendWithReplyStream(client, msg, st);
        virObjectLock(st);
        virNetMessageFree(msg);

        if (rv < 0)
            goto cleanup;
    }

    offset = 0;
    partial = nbytes;

    /* Stop as soon as the caller's buffer is full. Moving on to the next
     * piece with nothing left to copy would only memmove() that whole
     * piece onto itself. */
    while (partial > 0 && st->incomingVec && (st->readVec < st->writeVec)) {
        struct iovec *iov = st->incomingVec + st->readVec;

        if (!iov->iov_base) {
            virReportError(VIR_ERR_INTERNAL_ERROR,
                           "%s", _("NULL pointer encountered"));
            goto cleanup;
        }

        if (partial < iov->iov_len) {
            /* Piece only partly consumed: hand out its head and shift the
             * remainder to the front for the next read. */
            memcpy(data+offset, iov->iov_base, partial);
            memmove(iov->iov_base, (char*)iov->iov_base+partial,
                    iov->iov_len-partial);
            iov->iov_len -= partial;
            offset += partial;
            VIR_DEBUG("Consumed %zu, left %zu", partial, iov->iov_len);
            break;
        }

        memcpy(data+offset, iov->iov_base, iov->iov_len);
        VIR_DEBUG("Consumed %zu. Moving to next piece", iov->iov_len);
        partial -= iov->iov_len;
        offset += iov->iov_len;
        VIR_FREE(iov->iov_base);
        iov->iov_len = 0;
        st->readVec++;

        VIR_DEBUG("Read piece of vector. read %zu, readVec %zu, writeVec %zu",
                  offset, st->readVec, st->writeVec);
    }

    /* Drop fully consumed pieces from the front of the I/O vector to free
     * memory. This is only done once at least 16 pieces have been consumed,
     * so the array is not memmove()d and realloc()ed on every read. */
    if (st->readVec >= 16) {
        memmove(st->incomingVec, st->incomingVec + st->readVec,
                sizeof(*st->incomingVec)*(st->writeVec - st->readVec));
        VIR_SHRINK_N(st->incomingVec, st->writeVec, st->readVec);
        VIR_DEBUG("shrink removed %zu, left %zu", st->readVec, st->writeVec);
        st->readVec = 0;
    }

    ret = offset;
    virNetClientStreamEventTimerUpdate(st);

 cleanup:
    virObjectUnlock(st);
    return ret;
}


/*
 * virNetClientStreamEventAddCallback:
 * @st: the stream
 * @events: VIR_STREAM_EVENT_* mask to watch for
 * @cb: callback to invoke
 * @opaque: data passed to @cb
 * @ff: optional function to release @opaque
 *
 * Registers the single event callback of the stream and creates the
 * (initially disabled) timer used to dispatch it.
 *
 * Returns 0 on success, -1 if a callback is already registered or the
 * timer could not be created.
 */
int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
                                       int events,
                                       virNetClientStreamEventCallback cb,
                                       void *opaque,
                                       virFreeCallback ff)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("multiple stream callbacks not supported"));
        goto cleanup;
    }

    /* The timer holds its own reference on the stream, presumably
     * released through virObjectFreeCallback once the timeout is removed. */
    virObjectRef(st);
    st->cbTimer = virEventAddTimeout(-1,
                                     virNetClientStreamEventTimer,
                                     st,
                                     virObjectFreeCallback);
    if (st->cbTimer < 0) {
        virObjectUnref(st);
        goto cleanup;
    }

    st->cb = cb;
    st->cbOpaque = opaque;
    st->cbFree = ff;
    st->cbEvents = events;

    virNetClientStreamEventTimerUpdate(st);

    ret = 0;

 cleanup:
    virObjectUnlock(st);
    return ret;
}

/*
 * Changes the VIR_STREAM_EVENT_* mask of the registered callback and
 * re-evaluates whether its dispatch timer should be armed.
 *
 * Returns 0 on success, -1 if no callback is registered.
 */
int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
                                          int events)
{
    int ret = -1;

    virObjectLock(st);

    if (st->cb) {
        st->cbEvents = events;
        virNetClientStreamEventTimerUpdate(st);
        ret = 0;
    } else {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("no stream callback registered"));
    }

    virObjectUnlock(st);
    return ret;
}

/*
 * Unregisters the stream's event callback and removes its dispatch timer.
 *
 * Returns 0 on success, -1 if no callback is registered.
 */
int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
{
    int ret = -1;

    virObjectLock(st);

    if (!st->cb) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "%s", _("no stream callback registered"));
        goto cleanup;
    }

    /* While the callback is being dispatched, the timer handler still
     * uses the opaque data and frees it itself once the callback
     * returns. */
    if (st->cbFree && !st->cbDispatch)
        (st->cbFree)(st->cbOpaque);

    st->cbEvents = 0;
    st->cbFree = NULL;
    st->cbOpaque = NULL;
    st->cb = NULL;
    virEventRemoveTimeout(st->cbTimer);

    ret = 0;

 cleanup:
    virObjectUnlock(st);
    return ret;
}
M
Michal Privoznik 已提交
558 559 560 561 562

bool virNetClientStreamEOF(virNetClientStreamPtr st)
{
    return st->incomingEOF;
}