/*
 * virnetserverclient.c: generic network RPC server client
 *
 * Copyright (C) 2006-2014 Red Hat, Inc.
 * Copyright (C) 2006 Daniel P. Berrange
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */

#include <config.h>

26
#include "internal.h"
27
#if WITH_SASL
28 29 30
# include <sasl/sasl.h>
#endif

31
#include "virnetserver.h"
32 33
#include "virnetserverclient.h"

34
#include "virlog.h"
35
#include "virerror.h"
36
#include "viralloc.h"
37
#include "virthread.h"
38
#include "virkeepalive.h"
39
#include "virprobe.h"
40 41
#include "virstring.h"
#include "virutil.h"
42 43 44

#define VIR_FROM_THIS VIR_FROM_RPC

45 46
VIR_LOG_INIT("rpc.netserverclient");

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
/* Allow for filtering of incoming messages to a custom
 * dispatch processing queue, instead of the workers.
 * This allows for certain types of messages to be handled
 * strictly "in order"
 */

typedef struct _virNetServerClientFilter virNetServerClientFilter;
typedef virNetServerClientFilter *virNetServerClientFilterPtr;

/* One node of a client's singly linked filter list, appended by
 * virNetServerClientAddFilter and unlinked/freed by
 * virNetServerClientRemoveFilter. */
struct _virNetServerClientFilter {
    int id;                            /* handle returned by AddFilter, matched by RemoveFilter */
    virNetServerClientFilterFunc func; /* filter callback supplied to AddFilter */
    void *opaque;                      /* caller data supplied to AddFilter; RemoveFilter does not free it */

    virNetServerClientFilterPtr next;  /* next filter in the list, or NULL at the tail */
};


struct _virNetServerClient
{
67
    virObjectLockable parent;
68

69
    unsigned long long id;
70
    bool wantClose;
71
    bool delayedClose;
72 73 74
    virNetSocketPtr sock;
    int auth;
    bool readonly;
75
#if WITH_GNUTLS
76 77
    virNetTLSContextPtr tlsCtxt;
    virNetTLSSessionPtr tls;
78
#endif
79
#if WITH_SASL
80 81
    virNetSASLSessionPtr sasl;
#endif
M
Michal Privoznik 已提交
82 83
    int sockTimer; /* Timer to be fired upon cached data,
                    * so we jump out from poll() immediately */
84

85 86 87

    virIdentityPtr identity;

88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
    /* Count of messages in the 'tx' queue,
     * and the server worker pool queue
     * ie RPC calls in progress. Does not count
     * async events which are not used for
     * throttling calculations */
    size_t nrequests;
    size_t nrequests_max;
    /* Zero or one messages being received. Zero if
     * nrequests >= max_clients and throttling */
    virNetMessagePtr rx;
    /* Zero or many messages waiting for transmit
     * back to client, including async events */
    virNetMessagePtr tx;

    /* Filters to capture messages that would otherwise
     * end up on the 'dx' queue */
    virNetServerClientFilterPtr filters;
    int nextFilterID;

    virNetServerClientDispatchFunc dispatchFunc;
    void *dispatchOpaque;

    void *privateData;
111
    virFreeCallback privateDataFreeFunc;
112
    virNetServerClientPrivPreExecRestart privateDataPreExecRestart;
113
    virNetServerClientCloseFunc privateDataCloseFunc;
114 115

    virKeepAlivePtr keepalive;
116 117 118
};


119 120 121 122 123
static virClassPtr virNetServerClientClass;
static void virNetServerClientDispose(void *obj);

static int virNetServerClientOnceInit(void)
{
124
    if (!(virNetServerClientClass = virClassNew(virClassForObjectLockable(),
125
                                                "virNetServerClient",
126 127 128 129 130 131 132 133 134 135
                                                sizeof(virNetServerClient),
                                                virNetServerClientDispose)))
        return -1;

    return 0;
}

VIR_ONCE_GLOBAL_INIT(virNetServerClient)


136 137
static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque);
static void virNetServerClientUpdateEvent(virNetServerClientPtr client);
M
Michal Privoznik 已提交
138
static void virNetServerClientDispatchRead(virNetServerClientPtr client);
139 140
static int virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                               virNetMessagePtr msg);
141 142 143 144 145

/*
 * @client: a locked client object
 */
static int
146 147
virNetServerClientCalculateHandleMode(virNetServerClientPtr client)
{
148 149 150 151
    int mode = 0;


    VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
152
#ifdef WITH_GNUTLS
153 154
              client->tls,
              client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
155 156 157
#else
              NULL, -1,
#endif
158 159 160 161 162
              client->rx,
              client->tx);
    if (!client->sock || client->wantClose)
        return 0;

163
#if WITH_GNUTLS
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
    if (client->tls) {
        switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
        case VIR_NET_TLS_HANDSHAKE_RECVING:
            mode |= VIR_EVENT_HANDLE_READABLE;
            break;
        case VIR_NET_TLS_HANDSHAKE_SENDING:
            mode |= VIR_EVENT_HANDLE_WRITABLE;
            break;
        default:
        case VIR_NET_TLS_HANDSHAKE_COMPLETE:
            if (client->rx)
                mode |= VIR_EVENT_HANDLE_READABLE;
            if (client->tx)
                mode |= VIR_EVENT_HANDLE_WRITABLE;
        }
    } else {
180
#endif
181 182
        /* If there is a message on the rx queue, and
         * we're not in middle of a delayedClose, then
183
         * we're wanting more input */
184
        if (client->rx && !client->delayedClose)
185 186 187 188 189 190
            mode |= VIR_EVENT_HANDLE_READABLE;

        /* If there are one or more messages to send back to client,
           then monitor for writability on socket */
        if (client->tx)
            mode |= VIR_EVENT_HANDLE_WRITABLE;
191
#if WITH_GNUTLS
192
    }
193
#endif
194
    VIR_DEBUG("mode=%o", mode);
195 196 197 198 199 200 201 202 203 204 205
    return mode;
}

/*
 * Register virNetServerClientDispatchEvent as the I/O callback on the
 * client's socket, watching the events computed by
 * virNetServerClientCalculateHandleMode.
 *
 * A reference on @client is taken for the callback. virObjectFreeCallback
 * is passed as its free function, so that reference is presumably
 * dropped when the callback is removed.
 *
 * @client: a locked client object
 *
 * Returns 0 on success, -1 if the client has no socket or the callback
 * could not be added.
 */
static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
{
    int mode;

    if (!client->sock)
        return -1;

    mode = virNetServerClientCalculateHandleMode(client);

    virObjectRef(client);
    VIR_DEBUG("Registering client event callback %d", mode);
    if (virNetSocketAddIOCallback(client->sock,
                                  mode,
                                  virNetServerClientDispatchEvent,
                                  client,
                                  virObjectFreeCallback) < 0) {
        /* Not registered: drop the reference taken for the callback */
        virObjectUnref(client);
        return -1;
    }

    return 0;
}

/*
 * Refresh the set of events the socket's I/O callback watches.
 *
 * When a message is being received and the socket already holds cached
 * data, sockTimer is armed with a 0 timeout. That data is then processed
 * without waiting for poll() to report the fd readable.
 *
 * @client: a locked client object
 */
static void virNetServerClientUpdateEvent(virNetServerClientPtr client)
{
    if (client->sock) {
        virNetSocketUpdateIOCallback(client->sock,
                                     virNetServerClientCalculateHandleMode(client));

        if (client->rx && virNetSocketHasCachedData(client->sock))
            virEventUpdateTimeout(client->sockTimer, 0);
    }
}


242 243 244
/*
 * Append a message filter (@func with @opaque) to the tail of the
 * client's filter list.
 *
 * Returns the new filter's ID, to be passed to
 * virNetServerClientRemoveFilter, or -1 if allocation failed.
 */
int virNetServerClientAddFilter(virNetServerClientPtr client,
                                virNetServerClientFilterFunc func,
                                void *opaque)
{
    virNetServerClientFilterPtr node;
    virNetServerClientFilterPtr *tail;
    int filterID;

    if (VIR_ALLOC(node) < 0)
        return -1;

    /* Not yet reachable from the client, so no lock needed here */
    node->func = func;
    node->opaque = opaque;

    virObjectLock(client);

    filterID = client->nextFilterID++;
    node->id = filterID;

    for (tail = &client->filters; *tail; tail = &(*tail)->next)
        ;
    *tail = node;

    virObjectUnlock(client);

    return filterID;
}

271 272
/*
 * Unlink and free the first filter whose ID is @filterID. The filter's
 * opaque data is not freed. Does nothing if no filter has that ID.
 */
void virNetServerClientRemoveFilter(virNetServerClientPtr client,
                                    int filterID)
{
    virNetServerClientFilterPtr *link;

    virObjectLock(client);

    for (link = &client->filters; *link; link = &(*link)->next) {
        virNetServerClientFilterPtr victim = *link;

        if (victim->id != filterID)
            continue;

        *link = victim->next;
        VIR_FREE(victim);
        break;
    }

    virObjectUnlock(client);
}


298
#ifdef WITH_GNUTLS
/*
 * Check the client's access after the TLS handshake: verify the client
 * certificate against the TLS context. On success, queue a one-byte
 * confirmation message on the (required to be empty) tx queue.
 *
 * Returns 0 on success, -1 on failure.
 */
static int
virNetServerClientCheckAccess(virNetServerClientPtr client)
{
    virNetMessagePtr ack;

    /* Verify client certificate. */
    if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
        return -1;

    if (client->tx) {
        VIR_DEBUG("client had unexpected data pending tx after access check");
        return -1;
    }

    if (!(ack = virNetMessageNew(false)))
        return -1;

    /* Checks have succeeded.  Write a '\1' byte back to the client to
     * indicate this (otherwise the socket is abruptly closed).
     * (NB. The '\1' byte is sent in an encrypted record).
     */
    ack->bufferLength = 1;
    ack->bufferOffset = 0;
    if (VIR_ALLOC_N(ack->buffer, ack->bufferLength) < 0) {
        virNetMessageFree(ack);
        return -1;
    }
    ack->buffer[0] = '\1';

    client->tx = ack;
    return 0;
}
#endif

335

M
Michal Privoznik 已提交
336 337 338 339
/*
 * sockTimer callback. virNetServerClientUpdateEvent arms it when the
 * socket has cached data. It disables itself again and, if a message is
 * still being received, reads via virNetServerClientDispatchRead.
 */
static void virNetServerClientSockTimerFunc(int timer,
                                            void *opaque)
{
    virNetServerClientPtr client = opaque;

    virObjectLock(client);

    /* Back to disabled until UpdateEvent arms it again */
    virEventUpdateTimeout(timer, -1);

    /* client->rx was non-NULL when the timer was armed, but the client
     * may have been unlocked since then, so look again. */
    if (client->rx != NULL)
        virNetServerClientDispatchRead(client);

    virObjectUnlock(client);
}

349

350
static virNetServerClientPtr
351 352
virNetServerClientNewInternal(unsigned long long id,
                              virNetSocketPtr sock,
353
                              int auth,
354
#ifdef WITH_GNUTLS
355 356
                              virNetTLSContextPtr tls,
#endif
357
                              bool readonly,
358
                              size_t nrequests_max)
359 360 361
{
    virNetServerClientPtr client;

362
    if (virNetServerClientInitialize() < 0)
363 364
        return NULL;

365
    if (!(client = virObjectLockableNew(virNetServerClientClass)))
366 367
        return NULL;

368
    client->id = id;
369
    client->sock = virObjectRef(sock);
370 371
    client->auth = auth;
    client->readonly = readonly;
372
#ifdef WITH_GNUTLS
373
    client->tlsCtxt = virObjectRef(tls);
374
#endif
375
    client->nrequests_max = nrequests_max;
376

M
Michal Privoznik 已提交
377 378 379 380 381
    client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
                                           client, NULL);
    if (client->sockTimer < 0)
        goto error;

382
    /* Prepare one for packet receive */
383
    if (!(client->rx = virNetMessageNew(true)))
384 385
        goto error;
    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
386
    if (VIR_ALLOC_N(client->rx->buffer, client->rx->bufferLength) < 0)
387
        goto error;
388 389
    client->nrequests = 1;

390
    PROBE(RPC_SERVER_CLIENT_NEW,
391 392
          "client=%p sock=%p",
          client, client->sock);
393 394 395

    return client;

396
 error:
397
    virObjectUnref(client);
398 399 400 401
    return NULL;
}


402 403
virNetServerClientPtr virNetServerClientNew(unsigned long long id,
                                            virNetSocketPtr sock,
404 405 406
                                            int auth,
                                            bool readonly,
                                            size_t nrequests_max,
407
#ifdef WITH_GNUTLS
408
                                            virNetTLSContextPtr tls,
409
#endif
410
                                            virNetServerClientPrivNew privNew,
411
                                            virNetServerClientPrivPreExecRestart privPreExecRestart,
412 413 414 415 416
                                            virFreeCallback privFree,
                                            void *privOpaque)
{
    virNetServerClientPtr client;

417
    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth,
418
#ifdef WITH_GNUTLS
419 420 421 422 423
              tls
#else
              NULL
#endif
        );
424

425
    if (!(client = virNetServerClientNewInternal(id, sock, auth,
426
#ifdef WITH_GNUTLS
427 428 429
                                                 tls,
#endif
                                                 readonly, nrequests_max)))
430 431 432 433 434 435 436 437
        return NULL;

    if (privNew) {
        if (!(client->privateData = privNew(client, privOpaque))) {
            virObjectUnref(client);
            return NULL;
        }
        client->privateDataFreeFunc = privFree;
438
        client->privateDataPreExecRestart = privPreExecRestart;
439 440 441 442 443 444
    }

    return client;
}


445 446 447 448
/*
 * Recreate a client from a JSON state document such as the one produced
 * by virNetServerClientPreExecRestart.
 *
 * Required keys: "auth", "readonly", "nrequests_max" and "sock".
 * Optional key "id": when absent, a new ID is taken from the server
 * passed as @opaque.
 * When @privNew is given, "privateData" is also required and is passed
 * to privNew(client, child, privOpaque).
 *
 * Returns the new client, or NULL on failure.
 */
virNetServerClientPtr virNetServerClientNewPostExecRestart(virJSONValuePtr object,
                                                           virNetServerClientPrivNewPostExecRestart privNew,
                                                           virNetServerClientPrivPreExecRestart privPreExecRestart,
                                                           virFreeCallback privFree,
                                                           void *privOpaque,
                                                           void *opaque)
{
    virJSONValuePtr child;
    virNetServerClientPtr client = NULL;
    virNetSocketPtr sock;
    int auth;
    bool readonly;
    unsigned int nrequests_max;
    unsigned long long id;

    if (virJSONValueObjectGetNumberInt(object, "auth", &auth) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing auth field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetBoolean(object, "readonly", &readonly) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing readonly field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetNumberUint(object, "nrequests_max",
                                        &nrequests_max) < 0) {
        /* Name the key actually looked up */
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing nrequests_max field in JSON state document"));
        return NULL;
    }

    if (!(child = virJSONValueObjectGet(object, "sock"))) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing sock field in JSON state document"));
        return NULL;
    }

    if (!virJSONValueObjectHasKey(object, "id")) {
        /* No ID in the document: a new one must be generated */
        id = virNetServerNextClientID((virNetServerPtr) opaque);
    } else if (virJSONValueObjectGetNumberUlong(object, "id", &id) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Malformed id field in JSON state document"));
        return NULL;
    }

    if (!(sock = virNetSocketNewPostExecRestart(child)))
        return NULL;

    client = virNetServerClientNewInternal(id,
                                           sock,
                                           auth,
#ifdef WITH_GNUTLS
                                           NULL,
#endif
                                           readonly,
                                           nrequests_max);
    /* NewInternal takes its own reference on the socket, so ours is
     * dropped whether or not the client was created. */
    virObjectUnref(sock);
    if (!client)
        return NULL;

    if (privNew) {
        if (!(child = virJSONValueObjectGet(object, "privateData"))) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Missing privateData field in JSON state document"));
            goto error;
        }
        if (!(client->privateData = privNew(client, child, privOpaque)))
            goto error;
        client->privateDataFreeFunc = privFree;
        client->privateDataPreExecRestart = privPreExecRestart;
    }

    return client;

 error:
    virObjectUnref(client);
    return NULL;
}


/*
 * Serialize the client's state into a new JSON object:
 * "id", "auth", "readonly", "nrequests_max", the socket state under
 * "sock", and, when both privateData and privateDataPreExecRestart are
 * set, the private state under "privateData".
 *
 * Returns the object, or NULL on failure.
 */
virJSONValuePtr virNetServerClientPreExecRestart(virNetServerClientPtr client)
{
    virJSONValuePtr object;
    virJSONValuePtr child = NULL;

    if (!(object = virJSONValueNewObject()))
        return NULL;

    virObjectLock(client);

    if (virJSONValueObjectAppendNumberUlong(object, "id", client->id) < 0 ||
        virJSONValueObjectAppendNumberInt(object, "auth", client->auth) < 0 ||
        virJSONValueObjectAppendBoolean(object, "readonly", client->readonly) < 0 ||
        virJSONValueObjectAppendNumberUint(object, "nrequests_max", client->nrequests_max) < 0)
        goto error;

    if (!(child = virNetSocketPreExecRestart(client->sock)))
        goto error;
    if (virJSONValueObjectAppend(object, "sock", child) < 0)
        goto error_child;

    if (client->privateData && client->privateDataPreExecRestart) {
        if (!(child = client->privateDataPreExecRestart(client, client->privateData)))
            goto error;
        if (virJSONValueObjectAppend(object, "privateData", child) < 0)
            goto error_child;
    }

    virObjectUnlock(client);
    return object;

 error_child:
    /* child was not adopted by object, so free it separately */
    virJSONValueFree(child);
 error:
    virObjectUnlock(client);
    virJSONValueFree(object);
    return NULL;
}


583 584 585
/* Return the client's current auth value, read under the client lock. */
int
virNetServerClientGetAuth(virNetServerClientPtr client)
{
    int value;

    virObjectLock(client);
    value = client->auth;
    virObjectUnlock(client);

    return value;
}

592 593 594 595 596 597 598
/* Replace the client's auth value with @auth, under the client lock. */
void
virNetServerClientSetAuth(virNetServerClientPtr client,
                          int auth)
{
    virObjectLock(client);
    client->auth = auth;
    virObjectUnlock(client);
}

599 600 601
/* Report whether the client is flagged read-only, read under the client lock. */
bool
virNetServerClientGetReadonly(virNetServerClientPtr client)
{
    bool ro;

    virObjectLock(client);
    ro = client->readonly;
    virObjectUnlock(client);

    return ro;
}

608 609 610 611
/*
 * Return the client's ID. The lock is not taken; the ID is set when the
 * client is constructed.
 */
unsigned long long
virNetServerClientGetID(virNetServerClientPtr client)
{
    unsigned long long clientID = client->id;
    return clientID;
}
612

613
#ifdef WITH_GNUTLS
614 615 616
/* Report whether a TLS session is currently attached to the client. */
bool
virNetServerClientHasTLSSession(virNetServerClientPtr client)
{
    bool hasTLS;

    virObjectLock(client);
    hasTLS = client->tls != NULL;
    virObjectUnlock(client);

    return hasTLS;
}

623 624 625 626 627 628 629 630 631 632

/*
 * Return the client's TLS session (may be NULL). No reference is taken
 * for the caller; virNetServerClientClose drops the client's reference.
 */
virNetTLSSessionPtr
virNetServerClientGetTLSSession(virNetServerClientPtr client)
{
    virNetTLSSessionPtr session;

    virObjectLock(client);
    session = client->tls;
    virObjectUnlock(client);

    return session;
}

633 634 635
/* Return the TLS session's key size, or 0 when no TLS session is attached. */
int
virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
{
    int keySize;

    virObjectLock(client);
    keySize = client->tls ? virNetTLSSessionGetKeySize(client->tls) : 0;
    virObjectUnlock(client);

    return keySize;
}
642
#endif
643 644 645

/* Return the socket's file descriptor, or -1 once the client has no socket. */
int
virNetServerClientGetFD(virNetServerClientPtr client)
{
    int sockFD;

    virObjectLock(client);
    sockFD = client->sock ? virNetSocketGetFD(client->sock) : -1;
    virObjectUnlock(client);

    return sockFD;
}

654 655 656 657 658 659 660 661 662 663 664 665

/* Report whether the client's socket is local; false once it has no socket. */
bool
virNetServerClientIsLocal(virNetServerClientPtr client)
{
    bool isLocal;

    virObjectLock(client);
    isLocal = client->sock ? virNetSocketIsLocal(client->sock) : false;
    virObjectUnlock(client);

    return isLocal;
}


666
/*
 * Fetch the peer's UNIX credentials from the socket into
 * @uid/@gid/@pid/@timestamp.
 *
 * Returns virNetSocketGetUNIXIdentity's result. If the client has no
 * socket it returns -1 without touching the outputs and without
 * reporting an error here.
 */
int virNetServerClientGetUNIXIdentity(virNetServerClientPtr client,
                                      uid_t *uid, gid_t *gid, pid_t *pid,
                                      unsigned long long *timestamp)
{
    int rc;

    virObjectLock(client);
    rc = client->sock
        ? virNetSocketGetUNIXIdentity(client->sock, uid, gid, pid, timestamp)
        : -1;
    virObjectUnlock(client);

    return rc;
}

680

681 682 683 684 685 686 687 688
/*
 * Build a new identity object describing the peer. It includes:
 * - for a local socket: UNIX user name/ID, group name/ID, process ID
 *   and process start time;
 * - the SASL user name, when a SASL session is set;
 * - the x509 distinguished name, when a TLS session exists;
 * - the socket's SELinux context, when one is available.
 *
 * @client: a locked client object (virNetServerClientGetIdentity holds
 *          the lock when calling this)
 *
 * Returns the new identity, or NULL on failure.
 */
static virIdentityPtr
virNetServerClientCreateIdentity(virNetServerClientPtr client)
{
    char *user = NULL;
    char *group = NULL;
    char *context = NULL;
    virIdentityPtr identity = NULL;

    if (!(identity = virIdentityNew()))
        goto error;

    if (client->sock && virNetSocketIsLocal(client->sock)) {
        uid_t uid;
        gid_t gid;
        pid_t pid;
        unsigned long long timestamp;

        if (virNetSocketGetUNIXIdentity(client->sock,
                                        &uid, &gid, &pid,
                                        &timestamp) < 0)
            goto error;

        if (!(user = virGetUserName(uid)) ||
            virIdentitySetUNIXUserName(identity, user) < 0 ||
            virIdentitySetUNIXUserID(identity, uid) < 0)
            goto error;

        if (!(group = virGetGroupName(gid)) ||
            virIdentitySetUNIXGroupName(identity, group) < 0 ||
            virIdentitySetUNIXGroupID(identity, gid) < 0)
            goto error;

        if (virIdentitySetUNIXProcessID(identity, pid) < 0 ||
            virIdentitySetUNIXProcessTime(identity, timestamp) < 0)
            goto error;
    }

#if WITH_SASL
    if (client->sasl &&
        virIdentitySetSASLUserName(identity,
                                   virNetSASLSessionGetIdentity(client->sasl)) < 0)
        goto error;
#endif

#if WITH_GNUTLS
    if (client->tls &&
        virIdentitySetX509DName(identity,
                                virNetTLSSessionGetX509DName(client->tls)) < 0)
        goto error;
#endif

    if (client->sock &&
        virNetSocketGetSELinuxContext(client->sock, &context) < 0)
        goto error;
    if (context &&
        virIdentitySetSELinuxContext(identity, context) < 0)
        goto error;

 cleanup:
    VIR_FREE(user);
    VIR_FREE(group);
    VIR_FREE(context);
    return identity;

 error:
    virObjectUnref(identity);
    identity = NULL;
    goto cleanup;
}


/*
 * Return a new reference to the client's identity, or NULL if it could
 * not be created. The identity is created on first use and cached on
 * the client. The caller must virObjectUnref the result.
 */
virIdentityPtr
virNetServerClientGetIdentity(virNetServerClientPtr client)
{
    virIdentityPtr result = NULL;

    virObjectLock(client);

    if (client->identity == NULL)
        client->identity = virNetServerClientCreateIdentity(client);

    if (client->identity != NULL)
        result = virObjectRef(client->identity);

    virObjectUnlock(client);

    return result;
}


771 772
/*
 * Fetch the socket's SELinux context into @context (set to NULL first).
 * Returns 0 with *context NULL when the client has no socket; otherwise
 * returns virNetSocketGetSELinuxContext's result.
 */
int virNetServerClientGetSELinuxContext(virNetServerClientPtr client,
                                        char **context)
{
    int rc;

    *context = NULL;

    virObjectLock(client);
    rc = client->sock ? virNetSocketGetSELinuxContext(client->sock, context) : 0;
    virObjectUnlock(client);

    return rc;
}


784 785 786
/*
 * Report whether the connection counts as secure: it has a TLS session,
 * a SASL session, or a local socket.
 */
bool
virNetServerClientIsSecure(virNetServerClientPtr client)
{
    bool secureSession = false;
    bool isLocal;

    virObjectLock(client);
#if WITH_GNUTLS
    secureSession = secureSession || client->tls != NULL;
#endif
#if WITH_SASL
    secureSession = secureSession || client->sasl != NULL;
#endif
    isLocal = client->sock && virNetSocketIsLocal(client->sock);
    virObjectUnlock(client);

    return secureSession || isLocal;
}


803
#if WITH_SASL
804 805 806 807 808 809 810 811
/*
 * Store @sasl on the client, taking an extra reference on it.
 * NOTE(review): a previously stored session is overwritten without its
 * reference being dropped.
 */
void virNetServerClientSetSASLSession(virNetServerClientPtr client,
                                      virNetSASLSessionPtr sasl)
{
    virNetSASLSessionPtr session;

    /* The session is deliberately not set on the socket here: the auth
     * confirmation still has to be sent in the clear, and SASL mode is
     * only switched on once the next 'tx' operation completes. */
    session = virObjectRef(sasl);

    virObjectLock(client);
    client->sasl = session;
    virObjectUnlock(client);
}
816 817 818 819 820 821 822 823 824 825


/* Return the client's SASL session (may be NULL); no reference is taken for the caller. */
virNetSASLSessionPtr
virNetServerClientGetSASLSession(virNetServerClientPtr client)
{
    virNetSASLSessionPtr session;

    virObjectLock(client);
    session = client->sasl;
    virObjectUnlock(client);

    return session;
}
826 827 828 829 830 831
#endif


/* Return the private data pointer attached at construction (may be NULL). */
void *
virNetServerClientGetPrivateData(virNetServerClientPtr client)
{
    void *priv;

    virObjectLock(client);
    priv = client->privateData;
    virObjectUnlock(client);

    return priv;
}


839 840 841
/*
 * Install @cf as the hook that virNetServerClientClose invokes (with the
 * client unlocked) while tearing the connection down.
 */
void
virNetServerClientSetCloseHook(virNetServerClientPtr client,
                               virNetServerClientCloseFunc cf)
{
    virObjectLock(client);
    client->privateDataCloseFunc = cf;
    virObjectUnlock(client);
}


848 849 850 851
/* Record the dispatch callback @func and its @opaque data on the client. */
void
virNetServerClientSetDispatcher(virNetServerClientPtr client,
                                virNetServerClientDispatchFunc func,
                                void *opaque)
{
    virObjectLock(client);
    client->dispatchOpaque = opaque;
    client->dispatchFunc = func;
    virObjectUnlock(client);
}


/*
 * Return the socket's local address string, or NULL once the client has
 * no socket. The client lock is not taken here.
 */
const char *
virNetServerClientLocalAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketLocalAddrString(client->sock) : NULL;
}


/*
 * Return the socket's remote address string, or NULL once the client has
 * no socket. The client lock is not taken here.
 */
const char *
virNetServerClientRemoteAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketRemoteAddrString(client->sock) : NULL;
}


875
void virNetServerClientDispose(void *obj)
876
{
877
    virNetServerClientPtr client = obj;
878

879 880 881
    PROBE(RPC_SERVER_CLIENT_DISPOSE,
          "client=%p", client);

882 883 884 885
    if (client->privateData &&
        client->privateDataFreeFunc)
        client->privateDataFreeFunc(client->privateData);

886 887
    virObjectUnref(client->identity);

888
#if WITH_SASL
889
    virObjectUnref(client->sasl);
890
#endif
M
Michal Privoznik 已提交
891 892
    if (client->sockTimer > 0)
        virEventRemoveTimeout(client->sockTimer);
893
#if WITH_GNUTLS
894 895
    virObjectUnref(client->tls);
    virObjectUnref(client->tlsCtxt);
896
#endif
897
    virObjectUnref(client->sock);
898 899 900 901 902 903 904 905 906 907 908 909 910
}


/*
 * Disconnect the client from the network.
 *
 * We don't free stuff here, merely disconnect the client's
 * network socket & resources.
 *
 * Full free of the client is done later in a safe point
 * where it can be guaranteed it is no longer in use
 *
 * The steps, in order:
 * 1. stop and release the keepalive;
 * 2. run the close hook;
 * 3. remove the socket I/O callback and drop the TLS session;
 * 4. mark wantClose;
 * 5. free all queued rx/tx messages;
 * 6. release the socket.
 *
 * Calling it on an already closed client (no socket) does nothing.
 */
void virNetServerClientClose(virNetServerClientPtr client)
{
    virNetServerClientCloseFunc cf;
    virKeepAlivePtr ka;

    virObjectLock(client);
    VIR_DEBUG("client=%p", client);
    if (!client->sock) {
        virObjectUnlock(client);
        return;
    }

    if (client->keepalive) {
        virKeepAliveStop(client->keepalive);
        ka = client->keepalive;
        client->keepalive = NULL;
        /* The keepalive reference is dropped with the client unlocked;
         * a temporary reference pins the client across that window. */
        virObjectRef(client);
        virObjectUnlock(client);
        virObjectUnref(ka);
        virObjectLock(client);
        virObjectUnref(client);
    }

    if (client->privateDataCloseFunc) {
        cf = client->privateDataCloseFunc;
        /* The close hook likewise runs unlocked, with a temporary
         * reference held on the client. */
        virObjectRef(client);
        virObjectUnlock(client);
        (cf)(client);
        virObjectLock(client);
        virObjectUnref(client);
    }

    /* Do now, even though we don't close the socket
     * until end, to ensure we don't get invoked
     * again due to tls shutdown.
     * client->sock is re-checked because the lock was dropped above. */
    if (client->sock)
        virNetSocketRemoveIOCallback(client->sock);

#if WITH_GNUTLS
    if (client->tls) {
        virObjectUnref(client->tls);
        client->tls = NULL;
    }
#endif
    client->wantClose = true;

    while (client->rx) {
        virNetMessagePtr msg
            = virNetMessageQueueServe(&client->rx);
        virNetMessageFree(msg);
    }
    while (client->tx) {
        virNetMessagePtr msg
            = virNetMessageQueueServe(&client->tx);
        virNetMessageFree(msg);
    }

    if (client->sock) {
        virObjectUnref(client->sock);
        client->sock = NULL;
    }

    virObjectUnlock(client);
}


/* A client counts as closed once virNetServerClientClose has released its socket. */
bool
virNetServerClientIsClosed(virNetServerClientPtr client)
{
    bool isClosed;

    virObjectLock(client);
    isClosed = !client->sock;
    virObjectUnlock(client);

    return isClosed;
}

984 985
/*
 * Request a delayed close: sets delayedClose, which stops a plain (non-TLS)
 * client from being watched for further input.
 */
void
virNetServerClientDelayedClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    client->delayedClose = true;
    virObjectUnlock(client);
}

/*
 * Request an immediate close: sets wantClose, after which
 * virNetServerClientCalculateHandleMode watches no socket events.
 */
void
virNetServerClientImmediateClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    client->wantClose = true;
    virObjectUnlock(client);
}

/* Report whether a close of the client has been requested (wantClose). */
bool
virNetServerClientWantClose(virNetServerClientPtr client)
{
    bool closing;

    virObjectLock(client);
    closing = client->wantClose;
    virObjectUnlock(client);

    return closing;
}


/*
 * Start servicing the client.
 *
 * Plain socket: register the I/O callback so the first message can be
 * read.
 *
 * TLS context set: first create a TLS session, attach it to the socket
 * and begin the handshake.
 * - If the handshake completes at once, check the client's access
 *   (which queues the confirmation byte), then register the callback.
 * - If it needs more data, register the callback so it can continue.
 *
 * On any failure the client is marked wantClose.
 *
 * Returns 0 on success, -1 on failure.
 */
int virNetServerClientInit(virNetServerClientPtr client)
{
    virObjectLock(client);

#if WITH_GNUTLS
    if (client->tlsCtxt) {
        int handshake;

        if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt,
                                                NULL)))
            goto error;

        virNetSocketSetTLSSession(client->sock,
                                  client->tls);

        /* Begin the TLS handshake: <0 failed, 0 done, >0 needs more data */
        handshake = virNetTLSSessionHandshake(client->tls);
        if (handshake < 0)
            goto error;

        /* Unlikely, but ...  Next step is to check the certificate. */
        if (handshake == 0 &&
            virNetServerClientCheckAccess(client) < 0)
            goto error;
    }
#endif

    /* Plain socket, completed handshake, or handshake still in
     * progress: in every case start watching the socket. */
    if (virNetServerClientRegisterEvent(client) < 0)
        goto error;

    virObjectUnlock(client);
    return 0;

 error:
    client->wantClose = true;
    virObjectUnlock(client);
    return -1;
}



/*
 * Read data into client->rx using wire decoding (plain or TLS).
 *
 * Reads at most the number of bytes still missing from the current rx
 * buffer and advances its bufferOffset by the amount actually read.
 * A buffer that is already full is treated as a protocol error and
 * marks the client for closing.
 *
 * Returns:
 *   -1 on error or EOF
 *    0 on EAGAIN
 *    n number of bytes
 */
static ssize_t virNetServerClientRead(virNetServerClientPtr client)
{
    virNetMessagePtr rx = client->rx;
    size_t wanted;
    ssize_t got;

    if (rx->bufferOffset >= rx->bufferLength) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(rx->bufferLength - rx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    wanted = rx->bufferLength - rx->bufferOffset;
    got = virNetSocketRead(client->sock,
                           rx->buffer + rx->bufferOffset,
                           wanted);

    /* Only a positive count moves the offset; errors/EAGAIN pass through */
    if (got > 0)
        rx->bufferOffset += got;

    return got;
}


/*
 * Read data until we get a complete message to process
 *
 * Called with the client lock held (from virNetServerClientDispatchEvent).
 *
 * client->rx is filled in two phases:
 *  - first the length word, while bufferLength is VIR_NET_MESSAGE_LEN_MAX;
 *  - then the payload whose size that word announces.
 *
 * Once a payload is complete:
 *  - its header is decoded;
 *  - any file descriptors the call carries are received;
 *  - the message is offered to keepalive handling, then to the registered
 *    filters, and finally to client->dispatchFunc.
 *
 * A fresh receive buffer is only allocated while nrequests < nrequests_max.
 * Otherwise client->rx stays NULL (reading is throttled) until
 * virNetServerClientDispatchWrite frees a slot.
 *
 * Any fatal condition sets client->wantClose and returns.
 */
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
 readmore:
    /* While waiting for FDs (nfds != 0) the payload has already been
     * read completely, so there is nothing more to pull off the stream */
    if (client->rx->nfds == 0) {
        if (virNetServerClientRead(client) < 0) {
            client->wantClose = true;
            return; /* Error */
        }
    }

    if (client->rx->bufferOffset < client->rx->bufferLength)
        return; /* Still not read enough */

    /* Either done with length word header */
    if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
        if (virNetMessageDecodeLength(client->rx) < 0) {
            client->wantClose = true;
            return;
        }

        virNetServerClientUpdateEvent(client);

        /* Try and read payload immediately instead of going back
           into poll() because chances are the data is already
           waiting for us */
        goto readmore;
    } else {
        /* Grab the completed message */
        virNetMessagePtr msg = client->rx;
        virNetMessagePtr response = NULL;
        virNetServerClientFilterPtr filter;
        size_t i;

        /* Decode the header so we can use it for routing decisions */
        if (virNetMessageDecodeHeader(msg) < 0) {
            virNetMessageQueueServe(&client->rx);
            virNetMessageFree(msg);
            client->wantClose = true;
            return;
        }

        /* Now figure out if we need to read more data to get some
         * file descriptors */
        if (msg->header.type == VIR_NET_CALL_WITH_FDS) {
            if (msg->nfds == 0 &&
                virNetMessageDecodeNumFDs(msg) < 0) {
                virNetMessageQueueServe(&client->rx);
                virNetMessageFree(msg);
                client->wantClose = true;
                return; /* Error */
            }

            /* Try getting the file descriptors (may fail if blocking) */
            for (i = msg->donefds; i < msg->nfds; i++) {
                int rv;
                if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
                    virNetMessageQueueServe(&client->rx);
                    virNetMessageFree(msg);
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    break;
                msg->donefds++;
            }

            /* Need to poll() until FDs arrive */
            if (msg->donefds < msg->nfds) {
                /* Because DecodeHeader/NumFDs reset bufferOffset, we
                 * put it back to what it was, so everything works
                 * again next time we run this method
                 */
                client->rx->bufferOffset = client->rx->bufferLength;
                return;
            }
        }

        /* Definitely finished reading, so remove from queue */
        virNetMessageQueueServe(&client->rx);
        PROBE(RPC_SERVER_CLIENT_MSG_RX,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);

        /* Messages consumed by the keepalive code never reach the
         * filters or workers; they give back their request slot and
         * may produce a reply that is queued directly */
        if (virKeepAliveCheckMessage(client->keepalive, msg, &response)) {
            virNetMessageFree(msg);
            client->nrequests--;
            msg = NULL;

            if (response &&
                virNetServerClientSendMessageLocked(client, response) < 0)
                virNetMessageFree(response);
        }

        /* Maybe send off for queue against a filter */
        if (msg) {
            filter = client->filters;
            while (filter) {
                int ret = filter->func(client, msg, filter->opaque);
                if (ret < 0) {
                    /* Filter failed: drop the message and the client */
                    virNetMessageFree(msg);
                    msg = NULL;
                    client->wantClose = true;
                    break;
                }
                if (ret > 0) {
                    /* Filter accepted the message; it is no longer ours */
                    msg = NULL;
                    break;
                }

                filter = filter->next;
            }
        }

        /* Send off for normal dispatch to workers */
        if (msg) {
            /* Reference taken on behalf of the dispatched message and
             * dropped here if dispatch fails (presumably released by the
             * worker otherwise -- confirm against dispatchFunc) */
            virObjectRef(client);
            if (!client->dispatchFunc ||
                client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) {
                virNetMessageFree(msg);
                client->wantClose = true;
                virObjectUnref(client);
                return;
            }
        }

        /* Possibly need to create another receive buffer */
        if (client->nrequests < client->nrequests_max) {
            if (!(client->rx = virNetMessageNew(true))) {
                client->wantClose = true;
            } else {
                client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                if (VIR_ALLOC_N(client->rx->buffer,
                                client->rx->bufferLength) < 0) {
                    /* Don't leave a buffer-less message installed as
                     * client->rx: a later readable event would make
                     * virNetServerClientRead target a NULL buffer */
                    virNetMessageFree(client->rx);
                    client->rx = NULL;
                    client->wantClose = true;
                } else {
                    client->nrequests++;
                }
            }
        }
        virNetServerClientUpdateEvent(client);
    }
}


/*
 * Send the pending part of client->tx using no encoding.
 *
 * Writes whatever remains between bufferOffset and bufferLength and
 * advances bufferOffset by the number of bytes actually written.
 * An offset beyond the length is a protocol error and marks the client
 * for closing.
 *
 * Returns:
 *   -1 on error or EOF
 *    0 on EAGAIN
 *    1 if nothing was left to write
 *    n number of bytes written
 */
static ssize_t virNetServerClientWrite(virNetServerClientPtr client)
{
    virNetMessagePtr tx = client->tx;
    ssize_t written;

    if (tx->bufferOffset > tx->bufferLength) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(tx->bufferLength - tx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    /* Whole buffer already sent */
    if (tx->bufferOffset == tx->bufferLength)
        return 1;

    written = virNetSocketWrite(client->sock,
                                tx->buffer + tx->bufferOffset,
                                tx->bufferLength - tx->bufferOffset);

    /* -1 error, 0 EAGAIN: returned unchanged without touching the offset */
    if (written > 0)
        tx->bufferOffset += written;

    return written;
}


/*
 * Process all queued client->tx messages until
 * we would block on I/O
 *
 * Called with the client lock held (from virNetServerClientDispatchEvent).
 *
 * For the message at the head of the tx queue, this:
 *  - writes out the remaining payload bytes;
 *  - sends any attached file descriptors;
 *  - in SASL builds, switches the socket to a pending SASL session;
 *  - dequeues the message.
 *
 * A completed tracked message gives back its request slot. If reading was
 * throttled (client->rx is NULL), its buffer is recycled as the new rx
 * buffer. A pending delayed close is honoured after each completed message.
 *
 * Any fatal condition sets client->wantClose and returns.
 */
static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
    while (client->tx) {
        if (client->tx->bufferOffset < client->tx->bufferLength) {
            ssize_t ret;
            ret = virNetServerClientWrite(client);
            if (ret < 0) {
                client->wantClose = true;
                return;
            }
            if (ret == 0)
                return; /* Would block on write EAGAIN */
        }

        if (client->tx->bufferOffset == client->tx->bufferLength) {
            virNetMessagePtr msg;
            size_t i;

            for (i = client->tx->donefds; i < client->tx->nfds; i++) {
                int rv;
                if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    return;
                client->tx->donefds++;
            }

#if WITH_SASL
            /* Completed this 'tx' operation, so now read for all
             * future rx/tx to be under a SASL SSF layer
             */
            if (client->sasl) {
                virNetSocketSetSASLSession(client->sock, client->sasl);
                virObjectUnref(client->sasl);
                client->sasl = NULL;
            }
#endif

            /* Get finished msg from head of tx queue */
            msg = virNetMessageQueueServe(&client->tx);

            if (msg->tracked) {
                client->nrequests--;
                /* See if the recv queue is currently throttled */
                if (!client->rx &&
                    client->nrequests < client->nrequests_max) {
                    /* Ready to recv more messages */
                    virNetMessageClear(msg);
                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
                        /* Without an rx buffer the client can never read
                         * again; close it like every other fatal path */
                        virNetMessageFree(msg);
                        client->wantClose = true;
                        return;
                    }
                    client->rx = msg;
                    msg = NULL;
                    client->nrequests++;
                }
            }

            virNetMessageFree(msg);

            virNetServerClientUpdateEvent(client);

            if (client->delayedClose)
                client->wantClose = true;
        }
    }
}

#if WITH_GNUTLS
/*
 * Continue a TLS handshake that has not yet completed.
 *
 * A fatal handshake error, or a failed access check once the handshake
 * finishes, marks the client for closing. Otherwise the event mask is
 * refreshed, because the handshake's data flow direction may have changed.
 */
static void
virNetServerClientDispatchHandshake(virNetServerClientPtr client)
{
    int status = virNetTLSSessionHandshake(client->tls);

    if (status < 0) {
        /* Fatal error in handshake */
        client->wantClose = true;
        return;
    }

    if (status == 0 && virNetServerClientCheckAccess(client) < 0) {
        /* Handshake finished, but the certificate/access check failed */
        client->wantClose = true;
        return;
    }

    /* Handshake finished and accepted, or still waiting for more
     * handshake data: either way the wanted events may differ now */
    virNetServerClientUpdateEvent(client);
}
#endif

/*
 * I/O event callback registered for the client's socket.
 *
 * Takes the client lock. Callbacks for a socket the client no longer owns
 * are unregistered and ignored.
 *
 * On readable/writable events the work depends on the TLS state:
 *  - while a TLS handshake is in progress, the handshake is continued;
 *  - otherwise pending output is written first, then input is read
 *    (the read is skipped if reading is currently throttled, i.e.
 *    client->rx is NULL).
 *
 * Error or hangup events mark the client for closing.
 */
static void
virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
{
    virNetServerClientPtr client = opaque;
    bool handshaking = false;

    virObjectLock(client);

    /* Stale callback for a socket this client has already let go of */
    if (client->sock != sock) {
        virNetSocketRemoveIOCallback(sock);
        virObjectUnlock(client);
        return;
    }

    if (events & (VIR_EVENT_HANDLE_WRITABLE | VIR_EVENT_HANDLE_READABLE)) {
#if WITH_GNUTLS
        handshaking = client->tls &&
            virNetTLSSessionGetHandshakeStatus(client->tls) !=
            VIR_NET_TLS_HANDSHAKE_COMPLETE;
        if (handshaking)
            virNetServerClientDispatchHandshake(client);
#endif
        if (!handshaking) {
            if (events & VIR_EVENT_HANDLE_WRITABLE)
                virNetServerClientDispatchWrite(client);

            if ((events & VIR_EVENT_HANDLE_READABLE) && client->rx)
                virNetServerClientDispatchRead(client);
        }
    }

    /* NB, will get HANGUP + READABLE at same time upon
     * disconnect */
    if (events & (VIR_EVENT_HANDLE_ERROR | VIR_EVENT_HANDLE_HANGUP))
        client->wantClose = true;

    virObjectUnlock(client);
}

/*
 * Queue @msg on the client's transmit queue. The caller must hold the
 * client lock.
 *
 * The message's FD-passing progress (donefds) is always reset. The
 * message is only queued while the client still has a socket and is not
 * marked for closing. On failure, ownership of @msg stays with the
 * caller, which is expected to free it.
 *
 * Returns 0 if queued, -1 otherwise.
 */
static int
virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                    virNetMessagePtr msg)
{
    VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
              msg, msg->header.proc,
              msg->bufferLength, msg->bufferOffset);

    /* Start FD passing from the first descriptor again */
    msg->donefds = 0;

    if (!client->sock || client->wantClose)
        return -1;

    PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
          "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
          client, msg->bufferLength,
          msg->header.prog, msg->header.vers, msg->header.proc,
          msg->header.type, msg->header.status, msg->header.serial);

    virNetMessageQueuePush(&client->tx, msg);
    virNetServerClientUpdateEvent(client);

    return 0;
}

/*
 * virNetServerClientSendMessage:
 * @client: the client
 * @msg: message to transmit
 *
 * Locking wrapper around virNetServerClientSendMessageLocked.
 *
 * Returns 0 if @msg was queued, -1 otherwise (the caller keeps @msg).
 */
int virNetServerClientSendMessage(virNetServerClientPtr client,
                                  virNetMessagePtr msg)
{
    int rc;

    virObjectLock(client);
    rc = virNetServerClientSendMessageLocked(client, msg);
    virObjectUnlock(client);

    return rc;
}


/*
 * virNetServerClientNeedAuth:
 * @client: the client to query
 *
 * Returns true if client->auth is non-zero, false otherwise,
 * read under the client lock.
 */
bool virNetServerClientNeedAuth(virNetServerClientPtr client)
{
    bool need;

    virObjectLock(client);
    need = !!client->auth;
    virObjectUnlock(client);

    return need;
}


static void
virNetServerClientKeepAliveDeadCB(void *opaque)
{
    virNetServerClientImmediateClose(opaque);
}

/*
 * Keepalive send callback: queue @msg on the client passed as @opaque.
 * Returns the result of virNetServerClientSendMessage.
 */
static int
virNetServerClientKeepAliveSendCB(void *opaque,
                                  virNetMessagePtr msg)
{
    virNetServerClientPtr client = opaque;

    return virNetServerClientSendMessage(client, msg);
}


/*
 * virNetServerClientInitKeepAlive:
 * @client: the client
 * @interval: passed through to virKeepAliveNew
 * @count: passed through to virKeepAliveNew
 *
 * Creates a keepalive object for @client and stores it in
 * client->keepalive. The keepalive object is given an extra reference to
 * the client, together with virObjectFreeCallback as its free callback
 * (presumably used to drop that reference -- confirm in virkeepalive.c).
 *
 * Returns 0 on success, -1 on failure.
 */
int
virNetServerClientInitKeepAlive(virNetServerClientPtr client,
                                int interval,
                                unsigned int count)
{
    virKeepAlivePtr ka;
    int ret = -1;

    virObjectLock(client);

    if (!(ka = virKeepAliveNew(interval, count, client,
                               virNetServerClientKeepAliveSendCB,
                               virNetServerClientKeepAliveDeadCB,
                               virObjectFreeCallback)))
        goto cleanup;
    /* keepalive object has a reference to client */
    virObjectRef(client);

    client->keepalive = ka;
    /* Previously missing: without this the function reported failure
     * even though the keepalive had been installed */
    ret = 0;

 cleanup:
    virObjectUnlock(client);

    return ret;
}

/*
 * virNetServerClientStartKeepAlive:
 * @client: the client
 *
 * Starts the keepalive previously set up by
 * virNetServerClientInitKeepAlive, passing 0, 0 to virKeepAliveStart.
 * The connection may have been closed before this is called, in which
 * case client->keepalive is gone and an error is reported.
 *
 * Returns the result of virKeepAliveStart, or -1 if there is no keepalive.
 */
int
virNetServerClientStartKeepAlive(virNetServerClientPtr client)
{
    int ret;

    virObjectLock(client);

    if (client->keepalive) {
        ret = virKeepAliveStart(client->keepalive, 0, 0);
    } else {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("connection not open"));
        ret = -1;
    }

    virObjectUnlock(client);

    return ret;
}