/*
 * virnetserverclient.c: generic network RPC server client
 *
 * Copyright (C) 2006-2014 Red Hat, Inc.
 * Copyright (C) 2006 Daniel P. Berrange
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */

#include <config.h>

26
#include "internal.h"
27
#if WITH_SASL
28 29 30
# include <sasl/sasl.h>
#endif

31
#include "virnetserver.h"
32 33
#include "virnetserverclient.h"

34
#include "virlog.h"
35
#include "virerror.h"
36
#include "viralloc.h"
37
#include "virthread.h"
38
#include "virkeepalive.h"
39
#include "virprobe.h"
40 41
#include "virstring.h"
#include "virutil.h"
42 43 44

#define VIR_FROM_THIS VIR_FROM_RPC

45 46
VIR_LOG_INIT("rpc.netserverclient");

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
/* Allow for filtering of incoming messages to a custom
 * dispatch processing queue, instead of the workers.
 * This allows for certain types of messages to be handled
 * strictly "in order"
 */

typedef struct _virNetServerClientFilter virNetServerClientFilter;
typedef virNetServerClientFilter *virNetServerClientFilterPtr;

/* One node of a client's singly linked filter list. Nodes are appended in
 * registration order by virNetServerClientAddFilter() and unlinked by
 * virNetServerClientRemoveFilter(). */
struct _virNetServerClientFilter {
    int id;                             /* handle returned to the registrant,
                                         * later passed to RemoveFilter */
    virNetServerClientFilterFunc func;  /* filter callback supplied at
                                         * registration */
    void *opaque;                       /* caller data registered with @func */

    virNetServerClientFilterPtr next;   /* next filter in the list, or NULL */
};


struct _virNetServerClient
{
67
    virObjectLockable parent;
68

69
    unsigned long long id;
70
    bool wantClose;
71
    bool delayedClose;
72 73 74
    virNetSocketPtr sock;
    int auth;
    bool readonly;
75
#if WITH_GNUTLS
76 77
    virNetTLSContextPtr tlsCtxt;
    virNetTLSSessionPtr tls;
78
#endif
79
#if WITH_SASL
80 81
    virNetSASLSessionPtr sasl;
#endif
M
Michal Privoznik 已提交
82 83
    int sockTimer; /* Timer to be fired upon cached data,
                    * so we jump out from poll() immediately */
84

85 86 87

    virIdentityPtr identity;

88 89 90 91 92 93 94
    /* Connection timestamp, i.e. when a client connected to the daemon (UTC).
     * For old clients restored by post-exec-restart, which did not have this
     * attribute, value of 0 (epoch time) is used to indicate we have no
     * information about their connection time.
     */
    time_t conn_time;

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
    /* Count of messages in the 'tx' queue,
     * and the server worker pool queue
     * ie RPC calls in progress. Does not count
     * async events which are not used for
     * throttling calculations */
    size_t nrequests;
    size_t nrequests_max;
    /* Zero or one messages being received. Zero if
     * nrequests >= max_clients and throttling */
    virNetMessagePtr rx;
    /* Zero or many messages waiting for transmit
     * back to client, including async events */
    virNetMessagePtr tx;

    /* Filters to capture messages that would otherwise
     * end up on the 'dx' queue */
    virNetServerClientFilterPtr filters;
    int nextFilterID;

    virNetServerClientDispatchFunc dispatchFunc;
    void *dispatchOpaque;

    void *privateData;
118
    virFreeCallback privateDataFreeFunc;
119
    virNetServerClientPrivPreExecRestart privateDataPreExecRestart;
120
    virNetServerClientCloseFunc privateDataCloseFunc;
121 122

    virKeepAlivePtr keepalive;
123 124 125
};


126 127 128 129 130
static virClassPtr virNetServerClientClass;
static void virNetServerClientDispose(void *obj);

static int virNetServerClientOnceInit(void)
{
131
    if (!(virNetServerClientClass = virClassNew(virClassForObjectLockable(),
132
                                                "virNetServerClient",
133 134 135 136 137 138 139 140 141 142
                                                sizeof(virNetServerClient),
                                                virNetServerClientDispose)))
        return -1;

    return 0;
}

VIR_ONCE_GLOBAL_INIT(virNetServerClient)


143 144
static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque);
static void virNetServerClientUpdateEvent(virNetServerClientPtr client);
M
Michal Privoznik 已提交
145
static void virNetServerClientDispatchRead(virNetServerClientPtr client);
146 147
static int virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                               virNetMessagePtr msg);
148 149 150 151 152

/*
 * @client: a locked client object
 */
static int
153 154
virNetServerClientCalculateHandleMode(virNetServerClientPtr client)
{
155 156 157 158
    int mode = 0;


    VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
159
#ifdef WITH_GNUTLS
160 161
              client->tls,
              client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
162 163 164
#else
              NULL, -1,
#endif
165 166 167 168 169
              client->rx,
              client->tx);
    if (!client->sock || client->wantClose)
        return 0;

170
#if WITH_GNUTLS
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
    if (client->tls) {
        switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
        case VIR_NET_TLS_HANDSHAKE_RECVING:
            mode |= VIR_EVENT_HANDLE_READABLE;
            break;
        case VIR_NET_TLS_HANDSHAKE_SENDING:
            mode |= VIR_EVENT_HANDLE_WRITABLE;
            break;
        default:
        case VIR_NET_TLS_HANDSHAKE_COMPLETE:
            if (client->rx)
                mode |= VIR_EVENT_HANDLE_READABLE;
            if (client->tx)
                mode |= VIR_EVENT_HANDLE_WRITABLE;
        }
    } else {
187
#endif
188 189
        /* If there is a message on the rx queue, and
         * we're not in middle of a delayedClose, then
190
         * we're wanting more input */
191
        if (client->rx && !client->delayedClose)
192 193 194 195 196 197
            mode |= VIR_EVENT_HANDLE_READABLE;

        /* If there are one or more messages to send back to client,
           then monitor for writability on socket */
        if (client->tx)
            mode |= VIR_EVENT_HANDLE_WRITABLE;
198
#if WITH_GNUTLS
199
    }
200
#endif
201
    VIR_DEBUG("mode=%o", mode);
202 203 204 205 206 207 208 209 210 211 212
    return mode;
}

/*
 * Install virNetServerClientDispatchEvent() as the I/O callback on the
 * client's socket, watching the events computed for its current state.
 *
 * @client: a locked client object
 *
 * The installed callback owns one reference on @client, which is released
 * through virObjectFreeCallback when the callback is removed.
 *
 * Returns 0 on success, -1 if there is no socket or registration failed.
 */
static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
{
    int mode = virNetServerClientCalculateHandleMode(client);

    if (!client->sock)
        return -1;

    /* Reference handed to the socket callback; given back on failure */
    virObjectRef(client);
    VIR_DEBUG("Registering client event callback %d", mode);
    if (virNetSocketAddIOCallback(client->sock, mode,
                                  virNetServerClientDispatchEvent, client,
                                  virObjectFreeCallback) >= 0)
        return 0;

    virObjectUnref(client);
    return -1;
}

/*
 * Re-evaluate the events the client wants and update the existing socket
 * I/O callback accordingly.
 *
 * @client: a locked client object
 *
 * If a receive is pending and the socket layer already holds cached data,
 * sockTimer is fired immediately so that data is processed without waiting
 * for poll() to report the socket readable.
 */
static void virNetServerClientUpdateEvent(virNetServerClientPtr client)
{
    if (client->sock) {
        virNetSocketUpdateIOCallback(client->sock,
                                     virNetServerClientCalculateHandleMode(client));

        if (client->rx && virNetSocketHasCachedData(client->sock))
            virEventUpdateTimeout(client->sockTimer, 0);
    }
}


249 250 251
/*
 * Append a message filter to the end of the client's filter list.
 *
 * @client: the client (locked internally)
 * @func: filter callback
 * @opaque: data stored alongside @func
 *
 * Returns the new filter's ID (for virNetServerClientRemoveFilter), or -1
 * if the filter could not be allocated.
 */
int virNetServerClientAddFilter(virNetServerClientPtr client,
                                virNetServerClientFilterFunc func,
                                void *opaque)
{
    virNetServerClientFilterPtr newFilter;
    virNetServerClientFilterPtr *tail;
    int filterID;

    if (VIR_ALLOC(newFilter) < 0)
        return -1;

    virObjectLock(client);

    filterID = client->nextFilterID++;
    newFilter->id = filterID;
    newFilter->func = func;
    newFilter->opaque = opaque;

    /* Walk to the terminating NULL link so filters keep registration order */
    for (tail = &client->filters; *tail; tail = &(*tail)->next) {
    }
    *tail = newFilter;

    virObjectUnlock(client);

    return filterID;
}

278 279
/*
 * Unlink and free the first filter whose ID equals @filterID.
 * Unknown IDs are silently ignored.
 */
void virNetServerClientRemoveFilter(virNetServerClientPtr client,
                                    int filterID)
{
    virNetServerClientFilterPtr *link;

    virObjectLock(client);

    for (link = &client->filters; *link; link = &(*link)->next) {
        virNetServerClientFilterPtr victim = *link;

        if (victim->id != filterID)
            continue;

        /* Splice the node out by redirecting whichever link pointed at it */
        *link = victim->next;
        VIR_FREE(victim);
        break;
    }

    virObjectUnlock(client);
}


305
#if WITH_GNUTLS
/*
 * Check the client's access after the TLS handshake has completed.
 *
 * @client: a locked client object with a TLS session
 *
 * Verifies the peer certificate, then queues a single '\1' byte as the
 * only tx message to confirm access to the client.
 *
 * Returns 0 on success, -1 if the certificate is rejected, data was already
 * queued for transmit, or memory could not be allocated.
 *
 * Guarded with "#if WITH_GNUTLS" (not #ifdef) to match the guard on the
 * tls/tlsCtxt fields it dereferences and on its caller.
 */
static int
virNetServerClientCheckAccess(virNetServerClientPtr client)
{
    virNetMessagePtr confirm;

    /* Verify client certificate. */
    if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
        return -1;

    if (client->tx) {
        VIR_DEBUG("client had unexpected data pending tx after access check");
        return -1;
    }

    if (!(confirm = virNetMessageNew(false)))
        return -1;

    /* Checks have succeeded.  Write a '\1' byte back to the client to
     * indicate this (otherwise the socket is abruptly closed).
     * (NB. The '\1' byte is sent in an encrypted record).
     */
    confirm->bufferLength = 1;
    if (VIR_ALLOC_N(confirm->buffer, confirm->bufferLength) < 0) {
        virNetMessageFree(confirm);
        return -1;
    }
    confirm->bufferOffset = 0;
    confirm->buffer[0] = '\1';

    client->tx = confirm;

    return 0;
}
#endif

342

M
Michal Privoznik 已提交
343 344 345 346
static void virNetServerClientSockTimerFunc(int timer,
                                            void *opaque)
{
    virNetServerClientPtr client = opaque;
347
    virObjectLock(client);
M
Michal Privoznik 已提交
348 349 350 351 352
    virEventUpdateTimeout(timer, -1);
    /* Although client->rx != NULL when this timer is enabled, it might have
     * changed since the client was unlocked in the meantime. */
    if (client->rx)
        virNetServerClientDispatchRead(client);
353
    virObjectUnlock(client);
M
Michal Privoznik 已提交
354 355
}

356

357
static virNetServerClientPtr
358 359
virNetServerClientNewInternal(unsigned long long id,
                              virNetSocketPtr sock,
360
                              int auth,
361
#ifdef WITH_GNUTLS
362 363
                              virNetTLSContextPtr tls,
#endif
364
                              bool readonly,
365 366
                              size_t nrequests_max,
                              time_t timestamp)
367 368 369
{
    virNetServerClientPtr client;

370
    if (virNetServerClientInitialize() < 0)
371 372
        return NULL;

373
    if (!(client = virObjectLockableNew(virNetServerClientClass)))
374 375
        return NULL;

376
    client->id = id;
377
    client->sock = virObjectRef(sock);
378 379
    client->auth = auth;
    client->readonly = readonly;
380
#ifdef WITH_GNUTLS
381
    client->tlsCtxt = virObjectRef(tls);
382
#endif
383
    client->nrequests_max = nrequests_max;
384
    client->conn_time = timestamp;
385

M
Michal Privoznik 已提交
386 387 388 389 390
    client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
                                           client, NULL);
    if (client->sockTimer < 0)
        goto error;

391
    /* Prepare one for packet receive */
392
    if (!(client->rx = virNetMessageNew(true)))
393 394
        goto error;
    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
395
    if (VIR_ALLOC_N(client->rx->buffer, client->rx->bufferLength) < 0)
396
        goto error;
397 398
    client->nrequests = 1;

399
    PROBE(RPC_SERVER_CLIENT_NEW,
400 401
          "client=%p sock=%p",
          client, client->sock);
402 403 404

    return client;

405
 error:
406
    virObjectUnref(client);
407 408 409 410
    return NULL;
}


411 412
virNetServerClientPtr virNetServerClientNew(unsigned long long id,
                                            virNetSocketPtr sock,
413 414 415
                                            int auth,
                                            bool readonly,
                                            size_t nrequests_max,
416
#ifdef WITH_GNUTLS
417
                                            virNetTLSContextPtr tls,
418
#endif
419
                                            virNetServerClientPrivNew privNew,
420
                                            virNetServerClientPrivPreExecRestart privPreExecRestart,
421 422 423 424
                                            virFreeCallback privFree,
                                            void *privOpaque)
{
    virNetServerClientPtr client;
425
    time_t now;
426

427
    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth,
428
#ifdef WITH_GNUTLS
429 430 431 432 433
              tls
#else
              NULL
#endif
        );
434

435 436 437 438 439
    if ((now = time(NULL)) == (time_t) - 1) {
        virReportSystemError(errno, "%s", _("failed to get current time"));
        return NULL;
    }

440
    if (!(client = virNetServerClientNewInternal(id, sock, auth,
441
#ifdef WITH_GNUTLS
442 443
                                                 tls,
#endif
444 445
                                                 readonly, nrequests_max,
                                                 now)))
446 447 448 449 450 451 452 453
        return NULL;

    if (privNew) {
        if (!(client->privateData = privNew(client, privOpaque))) {
            virObjectUnref(client);
            return NULL;
        }
        client->privateDataFreeFunc = privFree;
454
        client->privateDataPreExecRestart = privPreExecRestart;
455 456 457 458 459 460
    }

    return client;
}


461 462 463 464
/*
 * Recreate a client from the JSON state written by
 * virNetServerClientPreExecRestart().
 *
 * @object: saved JSON state for one client
 * @privNew: optional callback rebuilding private data from "privateData"
 * @privPreExecRestart: stored for the next pre-exec-restart save
 * @privFree: callback used to free the private data
 * @privOpaque: data passed through to @privNew
 * @opaque: the owning virNetServerPtr, used to allocate a fresh ID when the
 *          state document predates the "id" field
 *
 * A missing "conn_time" key yields a connection time of 0 ("unknown").
 *
 * Returns the new client, or NULL with an error reported.
 */
virNetServerClientPtr virNetServerClientNewPostExecRestart(virJSONValuePtr object,
                                                           virNetServerClientPrivNewPostExecRestart privNew,
                                                           virNetServerClientPrivPreExecRestart privPreExecRestart,
                                                           virFreeCallback privFree,
                                                           void *privOpaque,
                                                           void *opaque)
{
    virJSONValuePtr child;
    virNetServerClientPtr client = NULL;
    virNetSocketPtr sock;
    int auth;
    bool readonly;
    unsigned int nrequests_max;
    unsigned long long id;
    long long conn_time = 0;
    time_t timestamp;

    if (virJSONValueObjectGetNumberInt(object, "auth", &auth) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing auth field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetBoolean(object, "readonly", &readonly) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing readonly field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetNumberUint(object, "nrequests_max",
                                        &nrequests_max) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing nrequests_max field in JSON state document"));
        return NULL;
    }

    if (!(child = virJSONValueObjectGet(object, "sock"))) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing sock field in JSON state document"));
        return NULL;
    }

    if (!virJSONValueObjectHasKey(object, "id")) {
        /* no ID found in, a new one must be generated */
        id = virNetServerNextClientID((virNetServerPtr) opaque);
    } else if (virJSONValueObjectGetNumberUlong(object, "id", &id) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Malformed id field in JSON state document"));
        return NULL;
    }

    /* time_t may be narrower than long long (e.g. 32-bit time_t), so parse
     * into a real long long and convert instead of writing through a
     * (long long *) cast of &timestamp, which could overrun it. */
    if (virJSONValueObjectHasKey(object, "conn_time") &&
        virJSONValueObjectGetNumberLong(object, "conn_time", &conn_time) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Malformed conn_time field in JSON "
                         "state document"));
        return NULL;
    }
    timestamp = (time_t) conn_time;

    if (!(sock = virNetSocketNewPostExecRestart(child)))
        return NULL;

    client = virNetServerClientNewInternal(id,
                                           sock,
                                           auth,
#ifdef WITH_GNUTLS
                                           NULL,
#endif
                                           readonly,
                                           nrequests_max,
                                           timestamp);
    /* The client (if created) holds its own reference on sock */
    virObjectUnref(sock);
    if (!client)
        return NULL;

    if (privNew) {
        if (!(child = virJSONValueObjectGet(object, "privateData"))) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Missing privateData field in JSON state document"));
            goto error;
        }
        if (!(client->privateData = privNew(client, child, privOpaque)))
            goto error;
        client->privateDataFreeFunc = privFree;
        client->privateDataPreExecRestart = privPreExecRestart;
    }

    return client;

 error:
    virObjectUnref(client);
    return NULL;
}


/*
 * Serialise the client's state to JSON so it can be restored by
 * virNetServerClientNewPostExecRestart() after the daemon re-execs.
 *
 * "conn_time" is only written when it is non-zero (known). "privateData"
 * is only written when both private data and a privateDataPreExecRestart
 * callback are present.
 *
 * Returns a new JSON object, or NULL on failure.
 */
virJSONValuePtr virNetServerClientPreExecRestart(virNetServerClientPtr client)
{
    virJSONValuePtr object = virJSONValueNewObject();
    virJSONValuePtr sockJSON = NULL;
    virJSONValuePtr privJSON = NULL;

    if (!object)
        return NULL;

    virObjectLock(client);

    if (virJSONValueObjectAppendNumberUlong(object, "id", client->id) < 0 ||
        virJSONValueObjectAppendNumberInt(object, "auth", client->auth) < 0 ||
        virJSONValueObjectAppendBoolean(object, "readonly", client->readonly) < 0 ||
        virJSONValueObjectAppendNumberUint(object, "nrequests_max", client->nrequests_max) < 0)
        goto error;

    if (client->conn_time &&
        virJSONValueObjectAppendNumberLong(object, "conn_time",
                                           client->conn_time) < 0)
        goto error;

    if (!(sockJSON = virNetSocketPreExecRestart(client->sock)))
        goto error;
    if (virJSONValueObjectAppend(object, "sock", sockJSON) < 0)
        goto error;
    sockJSON = NULL; /* now owned by object */

    if (client->privateData && client->privateDataPreExecRestart) {
        if (!(privJSON = client->privateDataPreExecRestart(client, client->privateData)))
            goto error;
        if (virJSONValueObjectAppend(object, "privateData", privJSON) < 0)
            goto error;
        privJSON = NULL; /* now owned by object */
    }

    virObjectUnlock(client);
    return object;

 error:
    /* Children not yet attached to object must be released separately */
    virJSONValueFree(sockJSON);
    virJSONValueFree(privJSON);
    virObjectUnlock(client);
    virJSONValueFree(object);
    return NULL;
}


618 619 620
/* Return the client's current auth scheme, read under the client lock. */
int virNetServerClientGetAuth(virNetServerClientPtr client)
{
    int scheme;

    virObjectLock(client);
    scheme = client->auth;
    virObjectUnlock(client);

    return scheme;
}

/* Replace the client's auth scheme under the client lock. */
void virNetServerClientSetAuth(virNetServerClientPtr client, int auth)
{
    virObjectLock(client);
    client->auth = auth;
    virObjectUnlock(client);
}

/* Report whether the client connection is read-only. */
bool virNetServerClientGetReadonly(virNetServerClientPtr client)
{
    bool isReadonly;

    virObjectLock(client);
    isReadonly = client->readonly;
    virObjectUnlock(client);

    return isReadonly;
}

/* Return the client ID. It is only assigned at construction in this file,
 * so it is read without taking the lock. */
unsigned long long virNetServerClientGetID(virNetServerClientPtr client)
{
    unsigned long long clientID = client->id;

    return clientID;
}

/* Return the connection time (0 if unknown), read without the lock since
 * it is only assigned at construction in this file. */
long long virNetServerClientGetTimestamp(virNetServerClientPtr client)
{
    long long connectedAt = client->conn_time;

    return connectedAt;
}

653
#ifdef WITH_GNUTLS
/* Report whether a TLS session is currently attached to the client. */
bool virNetServerClientHasTLSSession(virNetServerClientPtr client)
{
    bool hasSession;

    virObjectLock(client);
    hasSession = client->tls != NULL;
    virObjectUnlock(client);

    return hasSession;
}


/* Return the client's TLS session (borrowed, no reference taken), or NULL. */
virNetTLSSessionPtr virNetServerClientGetTLSSession(virNetServerClientPtr client)
{
    virNetTLSSessionPtr session;

    virObjectLock(client);
    session = client->tls;
    virObjectUnlock(client);

    return session;
}

/* Return the TLS session's key size, or 0 when there is no TLS session. */
int virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
{
    int keySize;

    virObjectLock(client);
    keySize = client->tls ? virNetTLSSessionGetKeySize(client->tls) : 0;
    virObjectUnlock(client);

    return keySize;
}
#endif
683 684 685

/* Return the socket's file descriptor, or -1 once the client is closed. */
int virNetServerClientGetFD(virNetServerClientPtr client)
{
    int sockFD;

    virObjectLock(client);
    sockFD = client->sock ? virNetSocketGetFD(client->sock) : -1;
    virObjectUnlock(client);

    return sockFD;
}


/* Report whether the client's socket is local; false once closed. */
bool virNetServerClientIsLocal(virNetServerClientPtr client)
{
    bool isLocal;

    virObjectLock(client);
    isLocal = client->sock ? virNetSocketIsLocal(client->sock) : false;
    virObjectUnlock(client);

    return isLocal;
}


/*
 * Fetch the peer's UNIX credentials from the socket.
 * Returns the socket layer's result, or -1 if the client is closed.
 */
int virNetServerClientGetUNIXIdentity(virNetServerClientPtr client,
                                      uid_t *uid, gid_t *gid, pid_t *pid,
                                      unsigned long long *timestamp)
{
    int rc;

    virObjectLock(client);
    rc = client->sock
        ? virNetSocketGetUNIXIdentity(client->sock, uid, gid, pid, timestamp)
        : -1;
    virObjectUnlock(client);

    return rc;
}

720

721 722 723 724 725 726 727 728
/*
 * Build a new identity object describing the peer of @client.
 *
 * @client: the client; virNetServerClientGetIdentity() calls this with the
 *          client lock held
 *
 * Fills in:
 *  - UNIX user/group names and IDs, PID and process time for local sockets
 *  - the SASL username when a SASL session exists
 *  - the X509 distinguished name when a TLS session exists
 *  - the SELinux context when the socket provides one
 *
 * Returns a new identity reference, or NULL on failure.
 */
static virIdentityPtr
virNetServerClientCreateIdentity(virNetServerClientPtr client)
{
    char *userName = NULL;
    char *groupName = NULL;
    char *selinuxLabel = NULL;
    virIdentityPtr ident = NULL;

    if (!(ident = virIdentityNew()))
        goto error;

    if (client->sock && virNetSocketIsLocal(client->sock)) {
        uid_t uid;
        gid_t gid;
        pid_t pid;
        unsigned long long procTime;

        /* Evaluated left to right; the first failure aborts */
        if (virNetSocketGetUNIXIdentity(client->sock,
                                        &uid, &gid, &pid, &procTime) < 0 ||
            !(userName = virGetUserName(uid)) ||
            virIdentitySetUNIXUserName(ident, userName) < 0 ||
            virIdentitySetUNIXUserID(ident, uid) < 0 ||
            !(groupName = virGetGroupName(gid)) ||
            virIdentitySetUNIXGroupName(ident, groupName) < 0 ||
            virIdentitySetUNIXGroupID(ident, gid) < 0 ||
            virIdentitySetUNIXProcessID(ident, pid) < 0 ||
            virIdentitySetUNIXProcessTime(ident, procTime) < 0)
            goto error;
    }

#if WITH_SASL
    if (client->sasl &&
        virIdentitySetSASLUserName(ident,
                                   virNetSASLSessionGetIdentity(client->sasl)) < 0)
        goto error;
#endif

#if WITH_GNUTLS
    if (client->tls &&
        virIdentitySetX509DName(ident,
                                virNetTLSSessionGetX509DName(client->tls)) < 0)
        goto error;
#endif

    if (client->sock &&
        virNetSocketGetSELinuxContext(client->sock, &selinuxLabel) < 0)
        goto error;
    if (selinuxLabel &&
        virIdentitySetSELinuxContext(ident, selinuxLabel) < 0)
        goto error;

    VIR_FREE(userName);
    VIR_FREE(groupName);
    VIR_FREE(selinuxLabel);
    return ident;

 error:
    virObjectUnref(ident);
    VIR_FREE(userName);
    VIR_FREE(groupName);
    VIR_FREE(selinuxLabel);
    return NULL;
}


/*
 * Return a new reference to the client's identity.
 *
 * The identity is created on first use and cached on the client. Returns
 * NULL if it could not be created.
 */
virIdentityPtr virNetServerClientGetIdentity(virNetServerClientPtr client)
{
    virIdentityPtr ident;

    virObjectLock(client);
    if (!client->identity &&
        !(client->identity = virNetServerClientCreateIdentity(client))) {
        virObjectUnlock(client);
        return NULL;
    }
    ident = virObjectRef(client->identity);
    virObjectUnlock(client);

    return ident;
}


811 812
/*
 * Fetch the peer's SELinux context into *@context.
 *
 * *@context is reset to NULL first. Returns 0 when the client is already
 * closed; otherwise returns the socket layer's result.
 */
int virNetServerClientGetSELinuxContext(virNetServerClientPtr client,
                                        char **context)
{
    int rc;

    *context = NULL;
    virObjectLock(client);
    rc = client->sock ? virNetSocketGetSELinuxContext(client->sock, context) : 0;
    virObjectUnlock(client);

    return rc;
}


824 825 826
/*
 * Report whether the connection is considered secure.
 *
 * It is secure if any of these hold: a TLS session is attached, a SASL
 * session is attached, or the socket is local.
 */
bool virNetServerClientIsSecure(virNetServerClientPtr client)
{
    bool secure = false;

    virObjectLock(client);
#if WITH_GNUTLS
    secure |= client->tls != NULL;
#endif
#if WITH_SASL
    secure |= client->sasl != NULL;
#endif
    /* Non-short-circuit |= keeps the socket query unconditional */
    secure |= client->sock && virNetSocketIsLocal(client->sock);
    virObjectUnlock(client);

    return secure;
}


843
#if WITH_SASL
/*
 * Attach a SASL session to the client, taking a new reference on @sasl.
 *
 * We don't set the sasl session on the socket here
 * because we need to send out the auth confirmation
 * in the clear. Only once we complete the next 'tx'
 * operation do we switch to SASL mode
 *
 * NOTE(review): any previously attached session is overwritten without
 * being unreferenced. That is left alone because GetSASLSession hands out
 * borrowed pointers.
 */
void virNetServerClientSetSASLSession(virNetServerClientPtr client,
                                      virNetSASLSessionPtr sasl)
{
    virNetSASLSessionPtr ref = virObjectRef(sasl);

    virObjectLock(client);
    client->sasl = ref;
    virObjectUnlock(client);
}


/* Return the client's SASL session (borrowed, no reference taken). */
virNetSASLSessionPtr virNetServerClientGetSASLSession(virNetServerClientPtr client)
{
    virNetSASLSessionPtr session;

    virObjectLock(client);
    session = client->sasl;
    virObjectUnlock(client);

    return session;
}
#endif


void *virNetServerClientGetPrivateData(virNetServerClientPtr client)
{
    void *data;
872
    virObjectLock(client);
873
    data = client->privateData;
874
    virObjectUnlock(client);
875 876 877 878
    return data;
}


879 880 881
void virNetServerClientSetCloseHook(virNetServerClientPtr client,
                                    virNetServerClientCloseFunc cf)
{
882
    virObjectLock(client);
883
    client->privateDataCloseFunc = cf;
884
    virObjectUnlock(client);
885 886 887
}


888 889 890 891
void virNetServerClientSetDispatcher(virNetServerClientPtr client,
                                     virNetServerClientDispatchFunc func,
                                     void *opaque)
{
892
    virObjectLock(client);
893 894
    client->dispatchFunc = func;
    client->dispatchOpaque = opaque;
895
    virObjectUnlock(client);
896 897 898 899 900
}


/* Local socket address as a string (owned by the socket), or NULL once
 * the client is closed. */
const char *virNetServerClientLocalAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketLocalAddrString(client->sock) : NULL;
}


/* Remote socket address as a string (owned by the socket), or NULL once
 * the client is closed. */
const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketRemoteAddrString(client->sock) : NULL;
}


915
void virNetServerClientDispose(void *obj)
916
{
917
    virNetServerClientPtr client = obj;
918

919 920 921
    PROBE(RPC_SERVER_CLIENT_DISPOSE,
          "client=%p", client);

922 923 924 925
    if (client->privateData &&
        client->privateDataFreeFunc)
        client->privateDataFreeFunc(client->privateData);

926 927
    virObjectUnref(client->identity);

928
#if WITH_SASL
929
    virObjectUnref(client->sasl);
930
#endif
M
Michal Privoznik 已提交
931 932
    if (client->sockTimer > 0)
        virEventRemoveTimeout(client->sockTimer);
933
#if WITH_GNUTLS
934 935
    virObjectUnref(client->tls);
    virObjectUnref(client->tlsCtxt);
936
#endif
937
    virObjectUnref(client->sock);
938 939 940 941 942 943 944 945 946 947 948 949 950
}


/*
 * Disconnect @client from the network without freeing it.
 *
 * In order, this:
 *  - stops and releases the keepalive object
 *  - runs the close hook (with the client unlocked)
 *  - removes the socket I/O callback
 *  - drops any TLS session and sets wantClose
 *  - discards all queued rx/tx messages
 *  - releases the socket
 *
 * Calling it on an already closed client (no socket) does nothing.
 *
 * Full free of the client is done later in a safe point
 * where it can be guaranteed it is no longer in use.
 */
void virNetServerClientClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    VIR_DEBUG("client=%p", client);

    if (client->sock == NULL) {
        virObjectUnlock(client);
        return;
    }

    if (client->keepalive) {
        virKeepAlivePtr ka;

        virKeepAliveStop(client->keepalive);
        ka = client->keepalive;
        client->keepalive = NULL;

        /* Release the keepalive object without holding our lock; the extra
         * reference keeps @client alive across the unlocked window. */
        virObjectRef(client);
        virObjectUnlock(client);
        virObjectUnref(ka);
        virObjectLock(client);
        virObjectUnref(client);
    }

    if (client->privateDataCloseFunc) {
        virNetServerClientCloseFunc closeHook = client->privateDataCloseFunc;

        /* Same pattern: run the hook unlocked while holding a reference */
        virObjectRef(client);
        virObjectUnlock(client);
        closeHook(client);
        virObjectLock(client);
        virObjectUnref(client);
    }

    /* Do now, even though we don't close the socket
     * until end, to ensure we don't get invoked
     * again due to tls shutdown.  The socket may have been
     * released while we were unlocked, so re-check it. */
    if (client->sock)
        virNetSocketRemoveIOCallback(client->sock);

#if WITH_GNUTLS
    virObjectUnref(client->tls);
    client->tls = NULL;
#endif
    client->wantClose = true;

    while (client->rx)
        virNetMessageFree(virNetMessageQueueServe(&client->rx));
    while (client->tx)
        virNetMessageFree(virNetMessageQueueServe(&client->tx));

    virObjectUnref(client->sock);
    client->sock = NULL;

    virObjectUnlock(client);
}


/* Report whether @client has been closed, i.e. its socket was released. */
bool virNetServerClientIsClosed(virNetServerClientPtr client)
{
    bool isClosed;

    virObjectLock(client);
    isClosed = !client->sock;
    virObjectUnlock(client);

    return isClosed;
}

1024 1025
void virNetServerClientDelayedClose(virNetServerClientPtr client)
{
1026
    virObjectLock(client);
1027
    client->delayedClose = true;
1028
    virObjectUnlock(client);
1029 1030 1031
}

/* Flag @client for closing as soon as possible.  This only sets wantClose;
 * the actual teardown is presumably performed by whoever polls
 * virNetServerClientWantClose (not visible here). */
void virNetServerClientImmediateClose(virNetServerClientPtr client)
{
    virObjectLock(client);

    client->wantClose = true;

    virObjectUnlock(client);
}

/* Return whether @client has been flagged for closing. */
bool virNetServerClientWantClose(virNetServerClientPtr client)
{
    bool result;

    virObjectLock(client);
    result = client->wantClose;
    virObjectUnlock(client);

    return result;
}


/*
 * Prepare a freshly accepted client for I/O.
 *
 * For a TLS-enabled client (tlsCtxt set) a TLS session is created,
 * attached to the socket and the handshake is started; if the handshake
 * happens to complete immediately the peer's access is checked right away.
 * In every successful case the client is registered for socket events.
 *
 * Returns 0 on success, -1 on failure (in which case wantClose is set).
 */
int virNetServerClientInit(virNetServerClientPtr client)
{
    virObjectLock(client);

#if WITH_GNUTLS
    if (client->tlsCtxt) {
        int rc;

        if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt, NULL)))
            goto error;

        virNetSocketSetTLSSession(client->sock, client->tls);

        /* Begin the TLS handshake; a positive result means more handshake
         * data still has to flow, a negative one is fatal. */
        rc = virNetTLSSessionHandshake(client->tls);
        if (rc < 0)
            goto error;

        /* Unlikely, but the handshake may already be done: check the
         * certificate before accepting the first message. */
        if (rc == 0 && virNetServerClientCheckAccess(client) < 0)
            goto error;
    }
#endif

    /* Plain socket: wait for the first message.  TLS socket: wait for
     * further handshake data or, if it finished, the first message. */
    if (virNetServerClientRegisterEvent(client) < 0)
        goto error;

    virObjectUnlock(client);
    return 0;

 error:
    client->wantClose = true;
    virObjectUnlock(client);
    return -1;
}



/*
 * Read more bytes of the current receive message (client->rx) from the
 * socket; any plain/TLS wire decoding is done by virNetSocketRead.
 *
 * Returns:
 *   -1 on error or EOF (wantClose is also set if the buffer state is bogus)
 *    0 on EAGAIN
 *    n number of bytes read, already added to rx->bufferOffset
 */
static ssize_t virNetServerClientRead(virNetServerClientPtr client)
{
    virNetMessagePtr rx = client->rx;
    ssize_t nread;

    /* Nothing left to fill: the caller should not have asked for a read. */
    if (rx->bufferOffset >= rx->bufferLength) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(rx->bufferLength - rx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    nread = virNetSocketRead(client->sock,
                             rx->buffer + rx->bufferOffset,
                             rx->bufferLength - rx->bufferOffset);
    if (nread > 0)
        rx->bufferOffset += nread;

    return nread;
}


/*
 * Read data until we get a complete message to process.
 *
 * Must be called with the client object locked (the caller in this file,
 * virNetServerClientDispatchEvent, holds the lock).  Failures are reported
 * by setting client->wantClose rather than through a return value.
 *
 * While client->rx->bufferLength equals VIR_NET_MESSAGE_LEN_MAX the buffer
 * holds only the length word; after virNetMessageDecodeLength succeeds the
 * payload is read straight away.  A completed message is then routed, in
 * order, to the keepalive handler, the registered filters and finally the
 * dispatch function.
 */
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
 readmore:
    /* Skip the socket read while still waiting for passed FDs */
    if (client->rx->nfds == 0) {
        if (virNetServerClientRead(client) < 0) {
            client->wantClose = true;
            return; /* Error */
        }
    }

    if (client->rx->bufferOffset < client->rx->bufferLength)
        return; /* Still not read enough */

    /* Either done with length word header */
    if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
        if (virNetMessageDecodeLength(client->rx) < 0) {
            client->wantClose = true;
            return;
        }

        virNetServerClientUpdateEvent(client);

        /* Try and read payload immediately instead of going back
           into poll() because chances are the data is already
           waiting for us */
        goto readmore;
    } else {
        /* Grab the completed message */
        virNetMessagePtr msg = client->rx;
        virNetMessagePtr response = NULL;
        virNetServerClientFilterPtr filter;
        size_t i;

        /* Decode the header so we can use it for routing decisions */
        if (virNetMessageDecodeHeader(msg) < 0) {
            virNetMessageQueueServe(&client->rx);
            virNetMessageFree(msg);
            client->wantClose = true;
            return;
        }

        /* Now figure out if we need to read more data to get some
         * file descriptors */
        if (msg->header.type == VIR_NET_CALL_WITH_FDS) {
            if (msg->nfds == 0 &&
                virNetMessageDecodeNumFDs(msg) < 0) {
                virNetMessageQueueServe(&client->rx);
                virNetMessageFree(msg);
                client->wantClose = true;
                return; /* Error */
            }

            /* Try getting the file descriptors (may fail if blocking) */
            for (i = msg->donefds; i < msg->nfds; i++) {
                int rv;
                if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
                    virNetMessageQueueServe(&client->rx);
                    virNetMessageFree(msg);
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    break;
                msg->donefds++;
            }

            /* Need to poll() until FDs arrive */
            if (msg->donefds < msg->nfds) {
                /* Because DecodeHeader/NumFDs reset bufferOffset, we
                 * put it back to what it was, so everything works
                 * again next time we run this method
                 */
                client->rx->bufferOffset = client->rx->bufferLength;
                return;
            }
        }

        /* Definitely finished reading, so remove from queue */
        virNetMessageQueueServe(&client->rx);
        PROBE(RPC_SERVER_CLIENT_MSG_RX,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);

        /* Keepalive messages are consumed here and never reach filters
         * or workers; any reply is queued directly. */
        if (virKeepAliveCheckMessage(client->keepalive, msg, &response)) {
            virNetMessageFree(msg);
            client->nrequests--;
            msg = NULL;

            if (response &&
                virNetServerClientSendMessageLocked(client, response) < 0)
                virNetMessageFree(response);
        }

        /* Maybe send off for queue against a filter */
        if (msg) {
            filter = client->filters;
            while (filter) {
                int ret = filter->func(client, msg, filter->opaque);
                if (ret < 0) {
                    /* Filter failed: drop the message and the client */
                    virNetMessageFree(msg);
                    msg = NULL;
                    client->wantClose = true;
                    break;
                }
                if (ret > 0) {
                    /* Filter kept the message; it is no longer ours */
                    msg = NULL;
                    break;
                }

                filter = filter->next;
            }
        }

        /* Send off for normal dispatch to workers */
        if (msg) {
            /* Extra client reference for the dispatched message; presumably
             * released by the worker.  Dropped here if dispatch fails. */
            virObjectRef(client);
            if (!client->dispatchFunc ||
                client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) {
                virNetMessageFree(msg);
                client->wantClose = true;
                virObjectUnref(client);
                return;
            }
        }

        /* Possibly need to create another receive buffer */
        if (client->nrequests < client->nrequests_max) {
            if (!(client->rx = virNetMessageNew(true))) {
                client->wantClose = true;
            } else {
                client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                if (VIR_ALLOC_N(client->rx->buffer,
                                client->rx->bufferLength) < 0) {
                    client->wantClose = true;
                } else {
                    client->nrequests++;
                }
            }
        }
        virNetServerClientUpdateEvent(client);
    }
}


/*
 * Send more bytes of the head of the transmit queue (client->tx) using
 * virNetSocketWrite.
 *
 * Returns:
 *   -1 on error or EOF (wantClose is also set if the buffer state is bogus)
 *    0 on EAGAIN
 *    1 if the message buffer had already been fully written
 *    n number of bytes written, already added to tx->bufferOffset
 */
static ssize_t virNetServerClientWrite(virNetServerClientPtr client)
{
    virNetMessagePtr tx = client->tx;
    ssize_t written;

    /* Offset beyond the end means the message bookkeeping is corrupt. */
    if (tx->bufferOffset > tx->bufferLength) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(tx->bufferLength - tx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    /* Everything already sent for this message. */
    if (tx->bufferOffset == tx->bufferLength)
        return 1;

    written = virNetSocketWrite(client->sock,
                                tx->buffer + tx->bufferOffset,
                                tx->bufferLength - tx->bufferOffset);
    if (written > 0)
        tx->bufferOffset += written;

    return written; /* -1 error, 0 = EAGAIN, else bytes written */
}


/*
 * Process all queued client->tx messages until
 * we would block on I/O.
 *
 * Must be called with the client object locked.  For each message the
 * buffer is written first, then any file descriptors it carries.  A
 * completed tracked message frees a request slot; if the receive side was
 * throttled (client->rx == NULL) the message is recycled as the new
 * receive buffer.  Failures set client->wantClose.
 */
static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
    while (client->tx) {
        if (client->tx->bufferOffset < client->tx->bufferLength) {
            ssize_t ret;
            ret = virNetServerClientWrite(client);
            if (ret < 0) {
                client->wantClose = true;
                return;
            }
            if (ret == 0)
                return; /* Would block on write EAGAIN */
        }

        if (client->tx->bufferOffset == client->tx->bufferLength) {
            virNetMessagePtr msg;
            size_t i;

            for (i = client->tx->donefds; i < client->tx->nfds; i++) {
                int rv;
                if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    return;
                client->tx->donefds++;
            }

#if WITH_SASL
            /* Completed this 'tx' operation, so now read for all
             * future rx/tx to be under a SASL SSF layer
             */
            if (client->sasl) {
                virNetSocketSetSASLSession(client->sock, client->sasl);
                virObjectUnref(client->sasl);
                client->sasl = NULL;
            }
#endif

            /* Get finished msg from head of tx queue */
            msg = virNetMessageQueueServe(&client->tx);

            if (msg->tracked) {
                client->nrequests--;
                /* See if the recv queue is currently throttled */
                if (!client->rx &&
                    client->nrequests < client->nrequests_max) {
                    /* Ready to recv more messages */
                    virNetMessageClear(msg);
                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
                        virNetMessageFree(msg);
                        /* client->rx is still NULL, so nothing would ever
                         * be read from this client again; close it, as
                         * DispatchRead does on the same failure. */
                        client->wantClose = true;
                        return;
                    }
                    client->rx = msg;
                    msg = NULL;
                    client->nrequests++;
                }
            }

            virNetMessageFree(msg);

            virNetServerClientUpdateEvent(client);

            if (client->delayedClose)
                client->wantClose = true;
        }
    }
}

#if WITH_GNUTLS
/*
 * Advance an in-progress TLS handshake after a socket event.
 *
 * On completion the peer's access is checked; on a fatal handshake error
 * or a failed access check the client is flagged with wantClose.
 * Otherwise the event registration is refreshed.
 */
static void
virNetServerClientDispatchHandshake(virNetServerClientPtr client)
{
    /* Continue the handshake. */
    int rc = virNetTLSSessionHandshake(client->tls);

    if (rc < 0) {
        /* Fatal error in handshake */
        client->wantClose = true;
        return;
    }

    /* Finished: next step is to check the certificate. */
    if (rc == 0 && virNetServerClientCheckAccess(client) < 0) {
        client->wantClose = true;
        return;
    }

    /* Either done, or still waiting for more handshake data; update the
     * events in case the handshake data flow direction has changed. */
    virNetServerClientUpdateEvent(client);
}
#endif

/*
 * Socket event callback for a client.
 *
 * Events for a socket the client no longer owns are dropped (and the stale
 * callback removed).  Readable/writable events drive either the pending
 * TLS handshake or the normal write-then-read processing; error and
 * hangup events flag the client for closing.
 */
static void
virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
{
    virNetServerClientPtr client = opaque;

    virObjectLock(client);

    if (client->sock != sock) {
        virNetSocketRemoveIOCallback(sock);
        virObjectUnlock(client);
        return;
    }

    if (events & (VIR_EVENT_HANDLE_WRITABLE | VIR_EVENT_HANDLE_READABLE)) {
#if WITH_GNUTLS
        if (client->tls &&
            virNetTLSSessionGetHandshakeStatus(client->tls) !=
            VIR_NET_TLS_HANDSHAKE_COMPLETE) {
            virNetServerClientDispatchHandshake(client);
        } else
#endif
        {
            if (events & VIR_EVENT_HANDLE_WRITABLE)
                virNetServerClientDispatchWrite(client);

            /* client->rx is checked after writing, since completing a
             * write may have re-created the receive buffer. */
            if ((events & VIR_EVENT_HANDLE_READABLE) && client->rx)
                virNetServerClientDispatchRead(client);
        }
    }

    /* NB, will get HANGUP + READABLE at same time upon
     * disconnect */
    if (events & (VIR_EVENT_HANDLE_ERROR | VIR_EVENT_HANDLE_HANGUP))
        client->wantClose = true;

    virObjectUnlock(client);
}


1460 1461 1462
static int
virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                    virNetMessagePtr msg)
1463 1464 1465 1466 1467
{
    int ret = -1;
    VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
              msg, msg->header.proc,
              msg->bufferLength, msg->bufferOffset);
1468

1469
    msg->donefds = 0;
1470
    if (client->sock && !client->wantClose) {
1471 1472 1473 1474 1475
        PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);
1476 1477 1478 1479 1480 1481
        virNetMessageQueuePush(&client->tx, msg);

        virNetServerClientUpdateEvent(client);
        ret = 0;
    }

1482 1483 1484 1485 1486 1487 1488 1489
    return ret;
}

/* Locking wrapper around virNetServerClientSendMessageLocked; same return
 * value and ownership rules. */
int virNetServerClientSendMessage(virNetServerClientPtr client,
                                  virNetMessagePtr msg)
{
    int rv;

    virObjectLock(client);
    rv = virNetServerClientSendMessageLocked(client, msg);
    virObjectUnlock(client);

    return rv;
}


/* Return true while client->auth is non-zero, i.e. authentication is still
 * required for this client. */
bool virNetServerClientNeedAuth(virNetServerClientPtr client)
{
    bool needAuth;

    virObjectLock(client);
    needAuth = client->auth != 0;
    virObjectUnlock(client);

    return needAuth;
}


static void
virNetServerClientKeepAliveDeadCB(void *opaque)
{
    virNetServerClientImmediateClose(opaque);
}

/* Keepalive callback used to queue a keepalive message on the client
 * passed as @opaque. */
static int
virNetServerClientKeepAliveSendCB(void *opaque,
                                  virNetMessagePtr msg)
{
    virNetServerClientPtr client = opaque;

    return virNetServerClientSendMessage(client, msg);
}


/*
 * Create the keepalive object for @client with the given @interval and
 * @count and store it in client->keepalive (it is not started here; see
 * virNetServerClientStartKeepAlive).
 *
 * Returns 0 on success, -1 if the keepalive object could not be created.
 */
int
virNetServerClientInitKeepAlive(virNetServerClientPtr client,
                                int interval,
                                unsigned int count)
{
    virKeepAlivePtr ka;
    int ret = -1;

    virObjectLock(client);

    if (!(ka = virKeepAliveNew(interval, count, client,
                               virNetServerClientKeepAliveSendCB,
                               virNetServerClientKeepAliveDeadCB,
                               virObjectFreeCallback)))
        goto cleanup;
    /* keepalive object has a reference to client */
    virObjectRef(client);

    client->keepalive = ka;
    /* Previously missing: without this the function reported failure
     * even after successfully installing the keepalive object. */
    ret = 0;

 cleanup:
    virObjectUnlock(client);

    return ret;
}

/*
 * Start the keepalive object previously set up for @client.
 *
 * Returns the result of virKeepAliveStart, or -1 with an error reported if
 * the client no longer has a keepalive object.
 */
int
virNetServerClientStartKeepAlive(virNetServerClientPtr client)
{
    int ret = -1;

    virObjectLock(client);

    /* The connection might have been closed before we got here and thus the
     * keepalive object could have been removed too.
     */
    if (client->keepalive)
        ret = virKeepAliveStart(client->keepalive, 0, 0);
    else
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("connection not open"));

    virObjectUnlock(client);

    return ret;
}