/*
 * virnetserverclient.c: generic network RPC server client
 *
 * Copyright (C) 2006-2014 Red Hat, Inc.
 * Copyright (C) 2006 Daniel P. Berrange
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */

#include <config.h>

26
#include "internal.h"
27
#if WITH_SASL
28 29 30
# include <sasl/sasl.h>
#endif

31
#include "virnetserver.h"
32 33
#include "virnetserverclient.h"

34
#include "virlog.h"
35
#include "virerror.h"
36
#include "viralloc.h"
37
#include "virthread.h"
38
#include "virkeepalive.h"
39
#include "virprobe.h"
40 41
#include "virstring.h"
#include "virutil.h"
42 43 44

#define VIR_FROM_THIS VIR_FROM_RPC

45 46
VIR_LOG_INIT("rpc.netserverclient");

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
/* Allow for filtering of incoming messages to a custom
 * dispatch processing queue, instead of the workers.
 * This allows for certain types of messages to be handled
 * strictly "in order"
 */

typedef struct _virNetServerClientFilter virNetServerClientFilter;
typedef virNetServerClientFilter *virNetServerClientFilterPtr;

/* One node of a client's singly linked list of message filters. */
struct _virNetServerClientFilter {
    int id; /* handle returned by virNetServerClientAddFilter() and
             * matched by virNetServerClientRemoveFilter() */
    virNetServerClientFilterFunc func; /* filter callback */
    void *opaque; /* caller-supplied data stored with the callback */

    virNetServerClientFilterPtr next; /* following filter, NULL at the tail */
};


struct _virNetServerClient
{
67
    virObjectLockable parent;
68

69
    unsigned long long id;
70
    bool wantClose;
71
    bool delayedClose;
72 73 74
    virNetSocketPtr sock;
    int auth;
    bool readonly;
75
#if WITH_GNUTLS
76 77
    virNetTLSContextPtr tlsCtxt;
    virNetTLSSessionPtr tls;
78
#endif
79
#if WITH_SASL
80 81
    virNetSASLSessionPtr sasl;
#endif
M
Michal Privoznik 已提交
82 83
    int sockTimer; /* Timer to be fired upon cached data,
                    * so we jump out from poll() immediately */
84

85 86 87

    virIdentityPtr identity;

88 89 90 91 92 93 94
    /* Connection timestamp, i.e. when a client connected to the daemon (UTC).
     * For old clients restored by post-exec-restart, which did not have this
     * attribute, value of 0 (epoch time) is used to indicate we have no
     * information about their connection time.
     */
    time_t conn_time;

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
    /* Count of messages in the 'tx' queue,
     * and the server worker pool queue
     * ie RPC calls in progress. Does not count
     * async events which are not used for
     * throttling calculations */
    size_t nrequests;
    size_t nrequests_max;
    /* Zero or one messages being received. Zero if
     * nrequests >= max_clients and throttling */
    virNetMessagePtr rx;
    /* Zero or many messages waiting for transmit
     * back to client, including async events */
    virNetMessagePtr tx;

    /* Filters to capture messages that would otherwise
     * end up on the 'dx' queue */
    virNetServerClientFilterPtr filters;
    int nextFilterID;

    virNetServerClientDispatchFunc dispatchFunc;
    void *dispatchOpaque;

    void *privateData;
118
    virFreeCallback privateDataFreeFunc;
119
    virNetServerClientPrivPreExecRestart privateDataPreExecRestart;
120
    virNetServerClientCloseFunc privateDataCloseFunc;
121 122

    virKeepAlivePtr keepalive;
123 124 125
};


126 127 128 129 130
static virClassPtr virNetServerClientClass;
static void virNetServerClientDispose(void *obj);

static int virNetServerClientOnceInit(void)
{
131
    if (!(virNetServerClientClass = virClassNew(virClassForObjectLockable(),
132
                                                "virNetServerClient",
133 134 135 136 137 138 139 140 141 142
                                                sizeof(virNetServerClient),
                                                virNetServerClientDispose)))
        return -1;

    return 0;
}

VIR_ONCE_GLOBAL_INIT(virNetServerClient)


143 144
static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque);
static void virNetServerClientUpdateEvent(virNetServerClientPtr client);
M
Michal Privoznik 已提交
145
static void virNetServerClientDispatchRead(virNetServerClientPtr client);
146 147
static int virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                               virNetMessagePtr msg);
148 149 150 151 152

/*
 * @client: a locked client object
 */
static int
153 154
virNetServerClientCalculateHandleMode(virNetServerClientPtr client)
{
155 156 157 158
    int mode = 0;


    VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
159
#ifdef WITH_GNUTLS
160 161
              client->tls,
              client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
162 163 164
#else
              NULL, -1,
#endif
165 166 167 168 169
              client->rx,
              client->tx);
    if (!client->sock || client->wantClose)
        return 0;

170
#if WITH_GNUTLS
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
    if (client->tls) {
        switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
        case VIR_NET_TLS_HANDSHAKE_RECVING:
            mode |= VIR_EVENT_HANDLE_READABLE;
            break;
        case VIR_NET_TLS_HANDSHAKE_SENDING:
            mode |= VIR_EVENT_HANDLE_WRITABLE;
            break;
        default:
        case VIR_NET_TLS_HANDSHAKE_COMPLETE:
            if (client->rx)
                mode |= VIR_EVENT_HANDLE_READABLE;
            if (client->tx)
                mode |= VIR_EVENT_HANDLE_WRITABLE;
        }
    } else {
187
#endif
188 189
        /* If there is a message on the rx queue, and
         * we're not in middle of a delayedClose, then
190
         * we're wanting more input */
191
        if (client->rx && !client->delayedClose)
192 193 194 195 196 197
            mode |= VIR_EVENT_HANDLE_READABLE;

        /* If there are one or more messages to send back to client,
           then monitor for writability on socket */
        if (client->tx)
            mode |= VIR_EVENT_HANDLE_WRITABLE;
198
#if WITH_GNUTLS
199
    }
200
#endif
201
    VIR_DEBUG("mode=%o", mode);
202 203 204 205 206 207 208 209 210 211 212
    return mode;
}

/*
 * @client: a locked client object
 *
 * Install virNetServerClientDispatchEvent as the I/O callback on the
 * client's socket, watching the events currently wanted.
 *
 * Returns 0 on success, -1 if there is no socket or registration fails.
 */
static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
{
    int mode = virNetServerClientCalculateHandleMode(client);
    int rc;

    if (client->sock == NULL)
        return -1;

    /* The callback gets its own reference on the client; it is handed
     * to virObjectFreeCallback, or dropped here if registration fails. */
    virObjectRef(client);
    VIR_DEBUG("Registering client event callback %d", mode);
    rc = virNetSocketAddIOCallback(client->sock,
                                   mode,
                                   virNetServerClientDispatchEvent,
                                   client,
                                   virObjectFreeCallback);
    if (rc < 0) {
        virObjectUnref(client);
        return -1;
    }

    return 0;
}

/*
 * @client: a locked client object
 *
 * Re-evaluate which events are wanted on the client's socket and push
 * the result to the registered I/O callback. Does nothing when the
 * client has no socket.
 */
static void virNetServerClientUpdateEvent(virNetServerClientPtr client)
{
    if (client->sock == NULL)
        return;

    virNetSocketUpdateIOCallback(client->sock,
                                 virNetServerClientCalculateHandleMode(client));

    /* Data already cached in the socket: arm the timer with a zero
     * timeout so the pending read is serviced straight away. */
    if (client->rx != NULL && virNetSocketHasCachedData(client->sock))
        virEventUpdateTimeout(client->sockTimer, 0);
}


249 250 251
/*
 * Append a message filter to the end of the client's filter list.
 *
 * @client: client to attach the filter to (locked internally)
 * @func: filter callback
 * @opaque: data stored alongside the callback
 *
 * Returns the id assigned to the new filter, or -1 if allocation fails.
 */
int virNetServerClientAddFilter(virNetServerClientPtr client,
                                virNetServerClientFilterFunc func,
                                void *opaque)
{
    virNetServerClientFilterPtr filter;
    virNetServerClientFilterPtr *tail;
    int id;

    if (VIR_ALLOC(filter) < 0)
        return -1;

    filter->func = func;
    filter->opaque = opaque;

    virObjectLock(client);

    id = client->nextFilterID++;
    filter->id = id;

    /* Walk to the terminating link so that list order matches
     * the order in which filters were added. */
    for (tail = &client->filters; *tail != NULL; tail = &(*tail)->next)
        ;
    *tail = filter;

    virObjectUnlock(client);

    return id;
}

278 279
/*
 * Unlink and free the first filter whose id equals @filterID.
 * Nothing happens if no filter on @client carries that id.
 */
void virNetServerClientRemoveFilter(virNetServerClientPtr client,
                                    int filterID)
{
    virNetServerClientFilterPtr *link;

    virObjectLock(client);

    /* 'link' always addresses the pointer that leads to the current
     * node, so head and interior removals are handled the same way. */
    for (link = &client->filters; *link != NULL; link = &(*link)->next) {
        virNetServerClientFilterPtr victim = *link;

        if (victim->id != filterID)
            continue;

        *link = victim->next;
        VIR_FREE(victim);
        break;
    }

    virObjectUnlock(client);
}


305
#ifdef WITH_GNUTLS
306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
/*
 * Check the client's access.
 *
 * Verifies the client certificate against the TLS context and, on
 * success, queues a single '\1' byte on the tx queue as confirmation.
 *
 * Returns 0 on success, -1 on verification failure, unexpected
 * pending tx data, or allocation failure.
 */
static int
virNetServerClientCheckAccess(virNetServerClientPtr client)
{
    virNetMessagePtr reply;

    /* Verify client certificate. */
    if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
        return -1;

    if (client->tx != NULL) {
        VIR_DEBUG("client had unexpected data pending tx after access check");
        return -1;
    }

    reply = virNetMessageNew(false);
    if (reply == NULL)
        return -1;

    /* Checks have succeeded.  Write a '\1' byte back to the client to
     * indicate this (otherwise the socket is abruptly closed).
     * (NB. The '\1' byte is sent in an encrypted record).
     */
    reply->bufferLength = 1;
    if (VIR_ALLOC_N(reply->buffer, reply->bufferLength) < 0) {
        virNetMessageFree(reply);
        return -1;
    }
    reply->buffer[0] = '\1';
    reply->bufferOffset = 0;

    client->tx = reply;

    return 0;
}
340 341
#endif

342

M
Michal Privoznik 已提交
343 344 345 346
/*
 * Timer callback registered for a client's sockTimer.
 *
 * @timer: id of the timer that fired
 * @opaque: the virNetServerClientPtr registered with the timer
 *
 * Disables the timer again and, if a receive message is still in
 * place, runs the read dispatcher under the client lock.
 */
static void virNetServerClientSockTimerFunc(int timer,
                                            void *opaque)
{
    virNetServerClientPtr client = opaque;

    virObjectLock(client);

    virEventUpdateTimeout(timer, -1);

    /* Although client->rx != NULL when this timer is enabled, it might have
     * changed since the client was unlocked in the meantime. */
    if (client->rx != NULL)
        virNetServerClientDispatchRead(client);

    virObjectUnlock(client);
}

356

357
static virNetServerClientPtr
358 359
virNetServerClientNewInternal(unsigned long long id,
                              virNetSocketPtr sock,
360
                              int auth,
361
#ifdef WITH_GNUTLS
362 363
                              virNetTLSContextPtr tls,
#endif
364
                              bool readonly,
365 366
                              size_t nrequests_max,
                              time_t timestamp)
367 368 369
{
    virNetServerClientPtr client;

370
    if (virNetServerClientInitialize() < 0)
371 372
        return NULL;

373
    if (!(client = virObjectLockableNew(virNetServerClientClass)))
374 375
        return NULL;

376
    client->id = id;
377
    client->sock = virObjectRef(sock);
378 379
    client->auth = auth;
    client->readonly = readonly;
380
#ifdef WITH_GNUTLS
381
    client->tlsCtxt = virObjectRef(tls);
382
#endif
383
    client->nrequests_max = nrequests_max;
384
    client->conn_time = timestamp;
385

M
Michal Privoznik 已提交
386 387 388 389 390
    client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
                                           client, NULL);
    if (client->sockTimer < 0)
        goto error;

391
    /* Prepare one for packet receive */
392
    if (!(client->rx = virNetMessageNew(true)))
393 394
        goto error;
    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
395
    if (VIR_ALLOC_N(client->rx->buffer, client->rx->bufferLength) < 0)
396
        goto error;
397 398
    client->nrequests = 1;

399
    PROBE(RPC_SERVER_CLIENT_NEW,
400 401
          "client=%p sock=%p",
          client, client->sock);
402 403 404

    return client;

405
 error:
406
    virObjectUnref(client);
407 408 409 410
    return NULL;
}


411 412
virNetServerClientPtr virNetServerClientNew(unsigned long long id,
                                            virNetSocketPtr sock,
413 414 415
                                            int auth,
                                            bool readonly,
                                            size_t nrequests_max,
416
#ifdef WITH_GNUTLS
417
                                            virNetTLSContextPtr tls,
418
#endif
419
                                            virNetServerClientPrivNew privNew,
420
                                            virNetServerClientPrivPreExecRestart privPreExecRestart,
421 422 423 424
                                            virFreeCallback privFree,
                                            void *privOpaque)
{
    virNetServerClientPtr client;
425
    time_t now;
426

427
    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth,
428
#ifdef WITH_GNUTLS
429 430 431 432 433
              tls
#else
              NULL
#endif
        );
434

435 436 437 438 439
    if ((now = time(NULL)) == (time_t) - 1) {
        virReportSystemError(errno, "%s", _("failed to get current time"));
        return NULL;
    }

440
    if (!(client = virNetServerClientNewInternal(id, sock, auth,
441
#ifdef WITH_GNUTLS
442 443
                                                 tls,
#endif
444 445
                                                 readonly, nrequests_max,
                                                 now)))
446 447 448 449 450 451 452 453
        return NULL;

    if (privNew) {
        if (!(client->privateData = privNew(client, privOpaque))) {
            virObjectUnref(client);
            return NULL;
        }
        client->privateDataFreeFunc = privFree;
454
        client->privateDataPreExecRestart = privPreExecRestart;
455 456 457 458 459 460
    }

    return client;
}


461 462 463 464
/*
 * Rebuild a client from the JSON state document written by
 * virNetServerClientPreExecRestart().
 *
 * @object: JSON object describing the client
 * @privNew: optional constructor restoring the private data from JSON
 * @privPreExecRestart: serializer stored for the private data
 * @privFree: free function stored for the private data
 * @privOpaque: passed through to @privNew
 * @opaque: cast to virNetServerPtr to generate an id when the
 *          document has no "id" field
 *
 * A missing "conn_time" field yields a timestamp of 0.
 *
 * Returns the restored client, or NULL on error.
 */
virNetServerClientPtr virNetServerClientNewPostExecRestart(virJSONValuePtr object,
                                                           virNetServerClientPrivNewPostExecRestart privNew,
                                                           virNetServerClientPrivPreExecRestart privPreExecRestart,
                                                           virFreeCallback privFree,
                                                           void *privOpaque,
                                                           void *opaque)
{
    virJSONValuePtr child;
    virNetServerClientPtr client = NULL;
    virNetSocketPtr sock;
    int auth;
    bool readonly;
    unsigned int nrequests_max;
    unsigned long long id;
    long long conn_time = 0;
    time_t timestamp;

    if (virJSONValueObjectGetNumberInt(object, "auth", &auth) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing auth field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetBoolean(object, "readonly", &readonly) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing readonly field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetNumberUint(object, "nrequests_max",
                                        &nrequests_max) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing nrequests_max field in JSON state document"));
        return NULL;
    }

    if (!(child = virJSONValueObjectGet(object, "sock"))) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing sock field in JSON state document"));
        return NULL;
    }

    if (!virJSONValueObjectHasKey(object, "id")) {
        /* no ID found in, a new one must be generated */
        id = virNetServerNextClientID((virNetServerPtr) opaque);
    } else {
        if (virJSONValueObjectGetNumberUlong(object, "id", &id) < 0) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Malformed id field in JSON state document"));
            return NULL;
        }
    }

    if (virJSONValueObjectHasKey(object, "conn_time")) {
        /* Parse into a real long long and convert afterwards: time_t
         * may be narrower than long long, so writing through a cast
         * pointer to it could overrun the variable. */
        if (virJSONValueObjectGetNumberLong(object, "conn_time",
                                            &conn_time) < 0) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Malformed conn_time field in JSON "
                             "state document"));
            return NULL;
        }
    }
    timestamp = (time_t) conn_time;

    if (!(sock = virNetSocketNewPostExecRestart(child)))
        return NULL;

    if (!(client = virNetServerClientNewInternal(id,
                                                 sock,
                                                 auth,
#ifdef WITH_GNUTLS
                                                 NULL,
#endif
                                                 readonly,
                                                 nrequests_max,
                                                 timestamp))) {
        virObjectUnref(sock);
        return NULL;
    }
    /* The client holds its own reference on the socket now */
    virObjectUnref(sock);

    if (privNew) {
        if (!(child = virJSONValueObjectGet(object, "privateData"))) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Missing privateData field in JSON state document"));
            goto error;
        }
        if (!(client->privateData = privNew(client, child, privOpaque)))
            goto error;
        client->privateDataFreeFunc = privFree;
        client->privateDataPreExecRestart = privPreExecRestart;
    }

    return client;

 error:
    virObjectUnref(client);
    return NULL;
}


/*
 * Serialize the client into a JSON object.
 *
 * Writes "id", "auth", "readonly", "nrequests_max", "conn_time" (only
 * when non-zero), "sock" and, when private data and its serializer are
 * both set, "privateData".
 *
 * Returns the new JSON object, or NULL on failure.
 */
virJSONValuePtr virNetServerClientPreExecRestart(virNetServerClientPtr client)
{
    virJSONValuePtr object = virJSONValueNewObject();
    virJSONValuePtr child;

    if (!object)
        return NULL;

    virObjectLock(client);

    if (virJSONValueObjectAppendNumberUlong(object, "id",
                                            client->id) < 0 ||
        virJSONValueObjectAppendNumberInt(object, "auth",
                                          client->auth) < 0 ||
        virJSONValueObjectAppendBoolean(object, "readonly",
                                        client->readonly) < 0 ||
        virJSONValueObjectAppendNumberUint(object, "nrequests_max",
                                           client->nrequests_max) < 0)
        goto error;

    if (client->conn_time != 0 &&
        virJSONValueObjectAppendNumberLong(object, "conn_time",
                                           client->conn_time) < 0)
        goto error;

    child = virNetSocketPreExecRestart(client->sock);
    if (child == NULL)
        goto error;

    /* The child is freed here only when appending it fails */
    if (virJSONValueObjectAppend(object, "sock", child) < 0) {
        virJSONValueFree(child);
        goto error;
    }

    if (client->privateData && client->privateDataPreExecRestart) {
        child = client->privateDataPreExecRestart(client, client->privateData);
        if (child == NULL)
            goto error;

        if (virJSONValueObjectAppend(object, "privateData", child) < 0) {
            virJSONValueFree(child);
            goto error;
        }
    }

    virObjectUnlock(client);
    return object;

 error:
    virObjectUnlock(client);
    virJSONValueFree(object);
    return NULL;
}


618 619 620
/* Return the client's auth value, read under the client lock. */
int virNetServerClientGetAuth(virNetServerClientPtr client)
{
    int ret;

    virObjectLock(client);
    ret = client->auth;
    virObjectUnlock(client);

    return ret;
}

627 628 629 630 631 632 633
/* Store @auth as the client's auth value, under the client lock. */
void
virNetServerClientSetAuth(virNetServerClientPtr client,
                          int auth)
{
    virObjectLock(client);

    client->auth = auth;

    virObjectUnlock(client);
}

634 635 636
/* Return the client's readonly flag, read under the client lock. */
bool virNetServerClientGetReadonly(virNetServerClientPtr client)
{
    bool ret;

    virObjectLock(client);
    ret = client->readonly;
    virObjectUnlock(client);

    return ret;
}

643 644 645 646
/* Return the client's id. The field is read without taking the lock. */
unsigned long long
virNetServerClientGetID(virNetServerClientPtr client)
{
    return client->id;
}
647

648 649 650 651 652
/* Return the client's connection timestamp (conn_time); 0 means the
 * connection time is unknown. The field is read without the lock. */
long long
virNetServerClientGetTimestamp(virNetServerClientPtr client)
{
    return client->conn_time;
}

653
#ifdef WITH_GNUTLS
654 655 656
/* Report whether the client currently has a TLS session attached. */
bool virNetServerClientHasTLSSession(virNetServerClientPtr client)
{
    bool present;

    virObjectLock(client);
    present = client->tls != NULL;
    virObjectUnlock(client);

    return present;
}

663 664 665 666 667 668 669 670 671 672

/*
 * Return the client's TLS session pointer, which may be NULL.
 * No extra reference is taken on the returned object.
 */
virNetTLSSessionPtr virNetServerClientGetTLSSession(virNetServerClientPtr client)
{
    virNetTLSSessionPtr session;

    virObjectLock(client);
    session = client->tls;
    virObjectUnlock(client);

    return session;
}

673 674 675
/*
 * Return the key size reported by the client's TLS session,
 * or 0 when the client has no TLS session.
 */
int virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
{
    int keysize = 0;

    virObjectLock(client);
    if (client->tls != NULL)
        keysize = virNetTLSSessionGetKeySize(client->tls);
    virObjectUnlock(client);

    return keysize;
}
682
#endif
683 684 685

/*
 * Return the file descriptor of the client's socket,
 * or -1 when the client has no socket.
 */
int virNetServerClientGetFD(virNetServerClientPtr client)
{
    int ret = -1;

    virObjectLock(client);
    if (client->sock != NULL)
        ret = virNetSocketGetFD(client->sock);
    virObjectUnlock(client);

    return ret;
}

694 695 696 697 698 699 700 701 702 703 704 705

/*
 * Report whether the client's socket is local, as judged by
 * virNetSocketIsLocal(). Returns false when there is no socket.
 */
bool virNetServerClientIsLocal(virNetServerClientPtr client)
{
    bool ret = false;

    virObjectLock(client);
    if (client->sock != NULL)
        ret = virNetSocketIsLocal(client->sock);
    virObjectUnlock(client);

    return ret;
}


706
/*
 * Fetch the peer's UNIX identity from the client's socket.
 *
 * @uid, @gid, @pid, @timestamp: outputs passed straight through to
 * virNetSocketGetUNIXIdentity()
 *
 * Returns the socket call's result, or -1 when there is no socket.
 */
int virNetServerClientGetUNIXIdentity(virNetServerClientPtr client,
                                      uid_t *uid, gid_t *gid, pid_t *pid,
                                      unsigned long long *timestamp)
{
    int rc = -1;

    virObjectLock(client);
    if (client->sock != NULL) {
        rc = virNetSocketGetUNIXIdentity(client->sock,
                                         uid, gid, pid,
                                         timestamp);
    }
    virObjectUnlock(client);

    return rc;
}

720

721 722 723 724 725 726 727 728
/*
 * Build a virIdentity describing the client.
 *
 * Fills in UNIX user/group/process attributes for local sockets, the
 * SASL user name and X.509 distinguished name when those sessions
 * exist, and the SELinux context when the socket reports one.
 *
 * Reads client fields without locking; the visible caller holds the
 * client lock.
 *
 * Returns the new identity, or NULL if any step fails.
 */
static virIdentityPtr
virNetServerClientCreateIdentity(virNetServerClientPtr client)
{
    char *username = NULL;
    char *groupname = NULL;
    char *seccontext = NULL;
    virIdentityPtr ident = NULL;

    if (!(ident = virIdentityNew()))
        goto error;

    if (client->sock && virNetSocketIsLocal(client->sock)) {
        gid_t gid;
        uid_t uid;
        pid_t pid;
        unsigned long long timestamp;

        if (virNetSocketGetUNIXIdentity(client->sock,
                                        &uid, &gid, &pid,
                                        &timestamp) < 0)
            goto error;

        /* User attributes */
        if (!(username = virGetUserName(uid)) ||
            virIdentitySetUNIXUserName(ident, username) < 0 ||
            virIdentitySetUNIXUserID(ident, uid) < 0)
            goto error;

        /* Group attributes */
        if (!(groupname = virGetGroupName(gid)) ||
            virIdentitySetUNIXGroupName(ident, groupname) < 0 ||
            virIdentitySetUNIXGroupID(ident, gid) < 0)
            goto error;

        /* Process attributes */
        if (virIdentitySetUNIXProcessID(ident, pid) < 0 ||
            virIdentitySetUNIXProcessTime(ident, timestamp) < 0)
            goto error;
    }

#if WITH_SASL
    if (client->sasl) {
        const char *saslname = virNetSASLSessionGetIdentity(client->sasl);

        if (virIdentitySetSASLUserName(ident, saslname) < 0)
            goto error;
    }
#endif

#if WITH_GNUTLS
    if (client->tls) {
        const char *dname = virNetTLSSessionGetX509DName(client->tls);

        if (virIdentitySetX509DName(ident, dname) < 0)
            goto error;
    }
#endif

    if (client->sock &&
        virNetSocketGetSELinuxContext(client->sock, &seccontext) < 0)
        goto error;
    if (seccontext &&
        virIdentitySetSELinuxContext(ident, seccontext) < 0)
        goto error;

 cleanup:
    VIR_FREE(username);
    VIR_FREE(groupname);
    VIR_FREE(seccontext);
    return ident;

 error:
    /* Drop the partial identity, then share the string cleanup */
    virObjectUnref(ident);
    ident = NULL;
    goto cleanup;
}


/*
 * Return a new reference to the client's identity, creating and
 * caching it on first use. Returns NULL if it cannot be created.
 */
virIdentityPtr virNetServerClientGetIdentity(virNetServerClientPtr client)
{
    virIdentityPtr ident = NULL;

    virObjectLock(client);

    if (client->identity == NULL)
        client->identity = virNetServerClientCreateIdentity(client);

    if (client->identity != NULL)
        ident = virObjectRef(client->identity);

    virObjectUnlock(client);

    return ident;
}


811 812
/*
 * Fetch the SELinux context of the client's socket into @context.
 *
 * @context is reset to NULL first. Returns the result of
 * virNetSocketGetSELinuxContext(), or 0 with @context left NULL when
 * the client has no socket.
 */
int virNetServerClientGetSELinuxContext(virNetServerClientPtr client,
                                        char **context)
{
    int rc = 0;

    *context = NULL;

    virObjectLock(client);
    if (client->sock != NULL)
        rc = virNetSocketGetSELinuxContext(client->sock, context);
    virObjectUnlock(client);

    return rc;
}


824 825 826
/*
 * Report whether the connection counts as secure: it has a TLS
 * session, has a SASL session, or uses a local socket.
 */
bool virNetServerClientIsSecure(virNetServerClientPtr client)
{
    bool ret = false;

    virObjectLock(client);

#if WITH_GNUTLS
    if (client->tls != NULL)
        ret = true;
#endif
#if WITH_SASL
    if (client->sasl != NULL)
        ret = true;
#endif
    if (client->sock != NULL && virNetSocketIsLocal(client->sock))
        ret = true;

    virObjectUnlock(client);

    return ret;
}


843
#if WITH_SASL
844 845 846 847 848 849 850 851
/*
 * Record @sasl on the client, taking a new reference on it.
 */
void virNetServerClientSetSASLSession(virNetServerClientPtr client,
                                      virNetSASLSessionPtr sasl)
{
    virObjectLock(client);

    /* We don't set the sasl session on the socket here
     * because we need to send out the auth confirmation
     * in the clear. Only once we complete the next 'tx'
     * operation do we switch to SASL mode
     */
    client->sasl = virObjectRef(sasl);

    virObjectUnlock(client);
}
856 857 858 859 860 861 862 863 864 865


/*
 * Return the client's SASL session pointer, which may be NULL.
 * No extra reference is taken on the returned object.
 */
virNetSASLSessionPtr virNetServerClientGetSASLSession(virNetServerClientPtr client)
{
    virNetSASLSessionPtr session;

    virObjectLock(client);
    session = client->sasl;
    virObjectUnlock(client);

    return session;
}
866 867 868 869 870 871 872 873 874

/* Report whether a SASL session has been set on the client. */
bool virNetServerClientHasSASLSession(virNetServerClientPtr client)
{
    bool present;

    virObjectLock(client);
    present = client->sasl != NULL;
    virObjectUnlock(client);

    return present;
}
875 876 877 878 879 880
#endif


/* Return the client's private data pointer, read under the lock. */
void *virNetServerClientGetPrivateData(virNetServerClientPtr client)
{
    void *priv;

    virObjectLock(client);
    priv = client->privateData;
    virObjectUnlock(client);

    return priv;
}


888 889 890
/*
 * Install @cf as the hook that virNetServerClientClose() invokes
 * before tearing down the socket.
 */
void
virNetServerClientSetCloseHook(virNetServerClientPtr client,
                               virNetServerClientCloseFunc cf)
{
    virObjectLock(client);

    client->privateDataCloseFunc = cf;

    virObjectUnlock(client);
}


897 898 899 900
/*
 * Store the dispatch callback @func and its @opaque data on the
 * client, both under the client lock.
 */
void
virNetServerClientSetDispatcher(virNetServerClientPtr client,
                                virNetServerClientDispatchFunc func,
                                void *opaque)
{
    virObjectLock(client);

    client->dispatchOpaque = opaque;
    client->dispatchFunc = func;

    virObjectUnlock(client);
}


/*
 * Return the socket's local address string, or NULL when the client
 * has no socket. The client lock is not taken.
 */
const char *virNetServerClientLocalAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketLocalAddrString(client->sock) : NULL;
}


/*
 * Return the socket's remote address string, or NULL when the client
 * has no socket. The client lock is not taken.
 */
const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketRemoteAddrString(client->sock) : NULL;
}

923 924 925 926 927 928 929 930 931 932 933 934 935
/*
 * Return the socket's local address in SASL format, or NULL when the
 * client has no socket. The client lock is not taken.
 */
char *virNetServerClientLocalAddrFormatSASL(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketLocalAddrFormatSASL(client->sock) : NULL;
}

/*
 * Return the socket's remote address in SASL format, or NULL when the
 * client has no socket. The client lock is not taken.
 */
char *virNetServerClientRemoteAddrFormatSASL(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketRemoteAddrFormatSASL(client->sock) : NULL;
}
936

937
void virNetServerClientDispose(void *obj)
938
{
939
    virNetServerClientPtr client = obj;
940

941 942 943
    PROBE(RPC_SERVER_CLIENT_DISPOSE,
          "client=%p", client);

944 945 946 947
    if (client->privateData &&
        client->privateDataFreeFunc)
        client->privateDataFreeFunc(client->privateData);

948 949
    virObjectUnref(client->identity);

950
#if WITH_SASL
951
    virObjectUnref(client->sasl);
952
#endif
M
Michal Privoznik 已提交
953 954
    if (client->sockTimer > 0)
        virEventRemoveTimeout(client->sockTimer);
955
#if WITH_GNUTLS
956 957
    virObjectUnref(client->tls);
    virObjectUnref(client->tlsCtxt);
958
#endif
959
    virObjectUnref(client->sock);
960 961 962 963 964 965 966 967 968 969 970 971 972
}


/*
 * Disconnect the client's network socket and release the resources
 * tied to the connection. Nothing is freed here beyond that.
 *
 * Full free of the client is done later in a safe point
 * where it can be guaranteed it is no longer in use.
 *
 * Returns immediately if the client has no socket.
 */
void virNetServerClientClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    VIR_DEBUG("client=%p", client);

    if (client->sock == NULL) {
        virObjectUnlock(client);
        return;
    }

    if (client->keepalive != NULL) {
        virKeepAlivePtr ka;

        virKeepAliveStop(client->keepalive);
        ka = client->keepalive;
        client->keepalive = NULL;

        /* Unref the keepalive with the client unlocked; a temporary
         * reference on the client is held across the unlocked section. */
        virObjectRef(client);
        virObjectUnlock(client);
        virObjectUnref(ka);
        virObjectLock(client);
        virObjectUnref(client);
    }

    if (client->privateDataCloseFunc != NULL) {
        virNetServerClientCloseFunc cf = client->privateDataCloseFunc;

        /* The close hook also runs with the client unlocked */
        virObjectRef(client);
        virObjectUnlock(client);
        (cf)(client);
        virObjectLock(client);
        virObjectUnref(client);
    }

    /* Do now, even though we don't close the socket
     * until end, to ensure we don't get invoked
     * again due to tls shutdown */
    if (client->sock != NULL)
        virNetSocketRemoveIOCallback(client->sock);

#if WITH_GNUTLS
    if (client->tls != NULL) {
        virObjectUnref(client->tls);
        client->tls = NULL;
    }
#endif
    client->wantClose = true;

    /* Drain and free whatever is still queued in either direction */
    while (client->rx != NULL)
        virNetMessageFree(virNetMessageQueueServe(&client->rx));
    while (client->tx != NULL)
        virNetMessageFree(virNetMessageQueueServe(&client->tx));

    if (client->sock != NULL) {
        virObjectUnref(client->sock);
        client->sock = NULL;
    }

    virObjectUnlock(client);
}


/*
 * virNetServerClientIsClosed:
 * @client: the client to query
 *
 * Returns true if the client no longer has a socket (i.e.
 * virNetServerClientClose() has released it), false otherwise.
 */
bool virNetServerClientIsClosed(virNetServerClientPtr client)
{
    bool closed;

    virObjectLock(client);
    closed = client->sock == NULL;
    virObjectUnlock(client);
    return closed;
}

1046 1047
/*
 * virNetServerClientDelayedClose:
 * @client: the client to flag
 *
 * Sets the delayedClose flag under the client lock. The tx dispatch
 * code turns this into wantClose once it has finished sending a
 * queued message.
 */
void virNetServerClientDelayedClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    client->delayedClose = true;
    virObjectUnlock(client);
}

/*
 * virNetServerClientImmediateClose:
 * @client: the client to flag
 *
 * Sets the wantClose flag under the client lock. This only marks
 * the client; the socket itself is not touched here.
 */
void virNetServerClientImmediateClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    client->wantClose = true;
    virObjectUnlock(client);
}

/*
 * virNetServerClientWantClose:
 * @client: the client to query
 *
 * Returns the current value of the client's wantClose flag,
 * read under the client lock.
 */
bool virNetServerClientWantClose(virNetServerClientPtr client)
{
    bool wantClose;

    virObjectLock(client);
    wantClose = client->wantClose;
    virObjectUnlock(client);
    return wantClose;
}


/*
 * virNetServerClientInit:
 * @client: the client to initialise
 *
 * Gets a client ready for I/O. For a plain socket (no TLS context)
 * the socket event callback is registered straight away. When built
 * WITH_GNUTLS and the client has a TLS context, a TLS session is
 * created, attached to the socket and the handshake is started; the
 * event callback is then registered either to read the first message
 * (handshake already complete and access check passed) or to carry
 * on with the handshake.
 *
 * Returns 0 on success, -1 on failure, in which case the client
 * is also flagged with wantClose.
 */
int virNetServerClientInit(virNetServerClientPtr client)
{
    virObjectLock(client);

#if WITH_GNUTLS
    if (!client->tlsCtxt) {
#endif
        /* Plain socket, so prepare to read first message */
        if (virNetServerClientRegisterEvent(client) < 0)
            goto error;
#if WITH_GNUTLS
    } else {
        int ret;

        if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt,
                                                NULL)))
            goto error;

        virNetSocketSetTLSSession(client->sock,
                                  client->tls);

        /* Begin the TLS handshake. */
        ret = virNetTLSSessionHandshake(client->tls);
        if (ret == 0) {
            /* Unlikely, but ...  Next step is to check the certificate. */
            if (virNetServerClientCheckAccess(client) < 0)
                goto error;

            /* Handshake & cert check OK,  so prepare to read first message */
            if (virNetServerClientRegisterEvent(client) < 0)
                goto error;
        } else if (ret > 0) {
            /* Most likely, need to do more handshake data */
            if (virNetServerClientRegisterEvent(client) < 0)
                goto error;
        } else {
            goto error;
        }
    }
#endif

    virObjectUnlock(client);
    return 0;

 error:
    client->wantClose = true;
    virObjectUnlock(client);
    return -1;
}



/*
 * Read data into buffer using wire decoding (plain or TLS)
 *
 * Data is appended to client->rx at its current bufferOffset, and
 * bufferOffset is advanced by the number of bytes read. If there
 * is no room left in the rx buffer an error is reported and the
 * client is flagged with wantClose.
 *
 * Returns:
 *   -1 on error or EOF
 *    0 on EAGAIN
 *    n number of bytes
 */
static ssize_t virNetServerClientRead(virNetServerClientPtr client)
{
    ssize_t ret;

    if (client->rx->bufferLength <= client->rx->bufferOffset) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(client->rx->bufferLength - client->rx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    ret = virNetSocketRead(client->sock,
                           client->rx->buffer + client->rx->bufferOffset,
                           client->rx->bufferLength - client->rx->bufferOffset);

    if (ret <= 0)
        return ret;

    client->rx->bufferOffset += ret;
    return ret;
}


/*
 * Read data until we get a complete message to process
 *
 * The rx buffer is filled in two stages: first the length word
 * (while bufferLength is still VIR_NET_MESSAGE_LEN_MAX), which is
 * decoded with virNetMessageDecodeLength(), then the rest of the
 * message. A completed message has its header decoded and, for
 * VIR_NET_CALL_WITH_FDS, its file descriptors received. It is then
 * taken off the rx queue and offered in turn to the keepalive
 * handler, the registered filters and finally the dispatch function.
 *
 * Every failure path flags the client with wantClose. The caller
 * must hold the client lock and client->rx must be non-NULL.
 */
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
 readmore:
    /* A non-zero nfds means we were here before and are only
     * waiting for file descriptors, so no more bytes are read */
    if (client->rx->nfds == 0) {
        if (virNetServerClientRead(client) < 0) {
            client->wantClose = true;
            return; /* Error */
        }
    }

    if (client->rx->bufferOffset < client->rx->bufferLength)
        return; /* Still not read enough */

    /* Either done with length word header */
    if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
        if (virNetMessageDecodeLength(client->rx) < 0) {
            client->wantClose = true;
            return;
        }

        virNetServerClientUpdateEvent(client);

        /* Try and read payload immediately instead of going back
           into poll() because chances are the data is already
           waiting for us */
        goto readmore;
    } else {
        /* Grab the completed message */
        virNetMessagePtr msg = client->rx;
        virNetMessagePtr response = NULL;
        virNetServerClientFilterPtr filter;
        size_t i;

        /* Decode the header so we can use it for routing decisions */
        if (virNetMessageDecodeHeader(msg) < 0) {
            virNetMessageQueueServe(&client->rx);
            virNetMessageFree(msg);
            client->wantClose = true;
            return;
        }

        /* Now figure out if we need to read more data to get some
         * file descriptors */
        if (msg->header.type == VIR_NET_CALL_WITH_FDS) {
            if (msg->nfds == 0 &&
                virNetMessageDecodeNumFDs(msg) < 0) {
                virNetMessageQueueServe(&client->rx);
                virNetMessageFree(msg);
                client->wantClose = true;
                return; /* Error */
            }

            /* Try getting the file descriptors (may fail if blocking) */
            for (i = msg->donefds; i < msg->nfds; i++) {
                int rv;
                if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
                    virNetMessageQueueServe(&client->rx);
                    virNetMessageFree(msg);
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    break;
                msg->donefds++;
            }

            /* Need to poll() until FDs arrive */
            if (msg->donefds < msg->nfds) {
                /* Because DecodeHeader/NumFDs reset bufferOffset, we
                 * put it back to what it was, so everything works
                 * again next time we run this method
                 */
                client->rx->bufferOffset = client->rx->bufferLength;
                return;
            }
        }

        /* Definitely finished reading, so remove from queue */
        virNetMessageQueueServe(&client->rx);
        PROBE(RPC_SERVER_CLIENT_MSG_RX,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);

        /* Keepalive traffic is consumed here and never dispatched */
        if (virKeepAliveCheckMessage(client->keepalive, msg, &response)) {
            virNetMessageFree(msg);
            client->nrequests--;
            msg = NULL;

            /* On failure the response was not queued, so free it */
            if (response &&
                virNetServerClientSendMessageLocked(client, response) < 0)
                virNetMessageFree(response);
        }

        /* Maybe send off for queue against a filter */
        if (msg) {
            filter = client->filters;
            while (filter) {
                int ret = filter->func(client, msg, filter->opaque);
                if (ret < 0) {
                    /* Filter failed: drop the message and the client */
                    virNetMessageFree(msg);
                    msg = NULL;
                    client->wantClose = true;
                    break;
                }
                if (ret > 0) {
                    /* Filter took the message */
                    msg = NULL;
                    break;
                }

                filter = filter->next;
            }
        }

        /* Send off to for normal dispatch to workers */
        if (msg) {
            /* This reference is only dropped here on failure; on
             * success it is left in place for the dispatched message */
            virObjectRef(client);
            if (!client->dispatchFunc ||
                client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) {
                virNetMessageFree(msg);
                client->wantClose = true;
                virObjectUnref(client);
                return;
            }
        }

        /* Possibly need to create another receive buffer */
        if (client->nrequests < client->nrequests_max) {
            if (!(client->rx = virNetMessageNew(true))) {
                client->wantClose = true;
            } else {
                client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                if (VIR_ALLOC_N(client->rx->buffer,
                                client->rx->bufferLength) < 0) {
                    client->wantClose = true;
                } else {
                    client->nrequests++;
                }
            }
        }
        virNetServerClientUpdateEvent(client);
    }
}


/*
 * Send client->tx using no encoding
 *
 * Writes the unsent remainder of the message at the head of the
 * tx queue to the socket and advances its bufferOffset by the
 * number of bytes written. If bufferOffset has run past
 * bufferLength an error is reported and the client is flagged
 * with wantClose.
 *
 * Returns:
 *   -1 on error or EOF
 *    0 on EAGAIN
 *    n number of bytes
 *    1 if there was nothing left to send
 */
static ssize_t virNetServerClientWrite(virNetServerClientPtr client)
{
    ssize_t ret;

    if (client->tx->bufferLength < client->tx->bufferOffset) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(client->tx->bufferLength - client->tx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    if (client->tx->bufferLength == client->tx->bufferOffset)
        return 1;

    ret = virNetSocketWrite(client->sock,
                            client->tx->buffer + client->tx->bufferOffset,
                            client->tx->bufferLength - client->tx->bufferOffset);
    if (ret <= 0)
        return ret; /* -1 error, 0 = egain */

    client->tx->bufferOffset += ret;
    return ret;
}


/*
 * Process all queued client->tx messages until
 * we would block on I/O
 *
 * For each message the data bytes are written first, then any file
 * descriptors attached to it. A fully sent message is removed from
 * the tx queue; if it was a tracked message and the client currently
 * has no rx buffer while being below nrequests_max, the message is
 * recycled as the new rx buffer, otherwise it is freed.
 *
 * Write, FD-passing and allocation failures flag the client with
 * wantClose. The caller must hold the client lock.
 */
static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
    while (client->tx) {
        if (client->tx->bufferOffset < client->tx->bufferLength) {
            ssize_t ret;
            ret = virNetServerClientWrite(client);
            if (ret < 0) {
                client->wantClose = true;
                return;
            }
            if (ret == 0)
                return; /* Would block on write EAGAIN */
        }

        if (client->tx->bufferOffset == client->tx->bufferLength) {
            virNetMessagePtr msg;
            size_t i;

            for (i = client->tx->donefds; i < client->tx->nfds; i++) {
                int rv;
                if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    return;
                client->tx->donefds++;
            }

#if WITH_SASL
            /* Completed this 'tx' operation, so now read for all
             * future rx/tx to be under a SASL SSF layer
             */
            if (client->sasl) {
                virNetSocketSetSASLSession(client->sock, client->sasl);
                virObjectUnref(client->sasl);
                client->sasl = NULL;
            }
#endif

            /* Get finished msg from head of tx queue */
            msg = virNetMessageQueueServe(&client->tx);

            if (msg->tracked) {
                client->nrequests--;
                /* See if the recv queue is currently throttled */
                if (!client->rx &&
                    client->nrequests < client->nrequests_max) {
                    /* Ready to recv more messages */
                    virNetMessageClear(msg);
                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
                        /* Without an rx buffer nothing more can be
                         * read from this client, so give up on it,
                         * same as the read path does when its rx
                         * buffer allocation fails */
                        virNetMessageFree(msg);
                        client->wantClose = true;
                        return;
                    }
                    client->rx = msg;
                    msg = NULL;
                    client->nrequests++;
                }
            }

            /* msg is NULL here if it was recycled as the rx buffer */
            virNetMessageFree(msg);

            virNetServerClientUpdateEvent(client);

            if (client->delayedClose)
                client->wantClose = true;
        }
    }
}

1414

1415
#if WITH_GNUTLS
/*
 * Advance an in-progress TLS handshake by one step.
 *
 * When the handshake finishes the client's access is checked; a
 * failed check, or a fatal handshake error, flags the client with
 * wantClose. Otherwise the socket event registration is refreshed.
 * The caller must hold the client lock.
 */
static void
virNetServerClientDispatchHandshake(virNetServerClientPtr client)
{
    int ret;
    /* Continue the handshake. */
    ret = virNetTLSSessionHandshake(client->tls);
    if (ret == 0) {
        /* Finished.  Next step is to check the certificate. */
        if (virNetServerClientCheckAccess(client) < 0)
            client->wantClose = true;
        else
            virNetServerClientUpdateEvent(client);
    } else if (ret > 0) {
        /* Carry on waiting for more handshake. Update
           the events just in case handshake data flow
           direction has changed */
        virNetServerClientUpdateEvent(client);
    } else {
        /* Fatal error in handshake */
        client->wantClose = true;
    }
}
#endif
1439 1440 1441 1442 1443 1444

/*
 * Socket I/O callback for a client.
 *
 * @sock: the socket the event fired on
 * @events: bitmask of VIR_EVENT_HANDLE_* flags
 * @opaque: the virNetServerClientPtr the callback was registered for
 *
 * If @sock is no longer the client's socket, the callback is removed
 * from @sock and nothing else happens. Otherwise readable/writable
 * events drive either the TLS handshake (while it is incomplete) or
 * the normal tx/rx dispatch, and error/hangup events flag the client
 * with wantClose. The client lock is held for the duration.
 */
static void
virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
{
    virNetServerClientPtr client = opaque;

    virObjectLock(client);

    if (client->sock != sock) {
        virNetSocketRemoveIOCallback(sock);
        virObjectUnlock(client);
        return;
    }

    if (events & (VIR_EVENT_HANDLE_WRITABLE |
                  VIR_EVENT_HANDLE_READABLE)) {
#if WITH_GNUTLS
        if (client->tls &&
            virNetTLSSessionGetHandshakeStatus(client->tls) !=
            VIR_NET_TLS_HANDSHAKE_COMPLETE) {
            virNetServerClientDispatchHandshake(client);
        } else {
#endif
            if (events & VIR_EVENT_HANDLE_WRITABLE)
                virNetServerClientDispatchWrite(client);
            /* rx is NULL when there is no receive buffer to fill */
            if ((events & VIR_EVENT_HANDLE_READABLE) &&
                client->rx)
                virNetServerClientDispatchRead(client);
#if WITH_GNUTLS
        }
#endif
    }

    /* NB, will get HANGUP + READABLE at same time upon
     * disconnect */
    if (events & (VIR_EVENT_HANDLE_ERROR |
                  VIR_EVENT_HANDLE_HANGUP))
        client->wantClose = true;

    virObjectUnlock(client);
}


1482 1483 1484
/*
 * Queue @msg for transmission to the client.
 *
 * The caller must already hold the client lock. The message's
 * donefds counter is reset, and if the client still has a socket
 * and is not flagged with wantClose the message is appended to the
 * tx queue and the socket event registration is refreshed.
 *
 * Returns 0 if the message was queued, -1 if it was not; in the
 * latter case @msg is left untouched for the caller to dispose of.
 */
static int
virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                    virNetMessagePtr msg)
{
    int ret = -1;
    VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
              msg, msg->header.proc,
              msg->bufferLength, msg->bufferOffset);

    msg->donefds = 0;
    if (client->sock && !client->wantClose) {
        PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);
        virNetMessageQueuePush(&client->tx, msg);

        virNetServerClientUpdateEvent(client);
        ret = 0;
    }

    return ret;
}

/*
 * virNetServerClientSendMessage:
 * @client: the client to send to
 * @msg: the message to queue
 *
 * Locking wrapper around virNetServerClientSendMessageLocked().
 *
 * Returns 0 if the message was queued for sending, -1 otherwise.
 */
int virNetServerClientSendMessage(virNetServerClientPtr client,
                                  virNetMessagePtr msg)
{
    int ret;

    virObjectLock(client);
    ret = virNetServerClientSendMessageLocked(client, msg);
    virObjectUnlock(client);

    return ret;
}


/*
 * virNetServerClientNeedAuth:
 * @client: the client to query
 *
 * Returns true if the client's auth field is non-zero,
 * false otherwise. The field is read under the client lock.
 */
bool virNetServerClientNeedAuth(virNetServerClientPtr client)
{
    bool need = false;

    virObjectLock(client);
    if (client->auth)
        need = true;
    virObjectUnlock(client);
    return need;
}
1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552


static void
virNetServerClientKeepAliveDeadCB(void *opaque)
{
    virNetServerClientImmediateClose(opaque);
}

/* Callback passed to virKeepAliveNew(): queues @msg on the client
 * given as @opaque and returns the result of doing so. */
static int
virNetServerClientKeepAliveSendCB(void *opaque,
                                  virNetMessagePtr msg)
{
    virNetServerClientPtr client = opaque;
    int rc = virNetServerClientSendMessage(client, msg);

    return rc;
}


/*
 * virNetServerClientInitKeepAlive:
 * @client: the client to attach a keepalive object to
 * @interval: passed through to virKeepAliveNew()
 * @count: passed through to virKeepAliveNew()
 *
 * Creates a keepalive object wired to this client's send and
 * dead callbacks and stores it in client->keepalive. The keepalive
 * object is given its own reference on the client.
 *
 * Returns 0 on success, -1 if the keepalive object could not
 * be created.
 */
int
virNetServerClientInitKeepAlive(virNetServerClientPtr client,
                                int interval,
                                unsigned int count)
{
    virKeepAlivePtr ka;
    int ret = -1;

    virObjectLock(client);

    if (!(ka = virKeepAliveNew(interval, count, client,
                               virNetServerClientKeepAliveSendCB,
                               virNetServerClientKeepAliveDeadCB,
                               virObjectFreeCallback)))
        goto cleanup;
    /* keepalive object has a reference to client */
    virObjectRef(client);

    client->keepalive = ka;
    ret = 0;

 cleanup:
    virObjectUnlock(client);

    return ret;
}

/*
 * virNetServerClientStartKeepAlive:
 * @client: the client whose keepalive should be started
 *
 * Starts the client's keepalive object via virKeepAliveStart().
 *
 * Returns the result of virKeepAliveStart(), or -1 with an error
 * reported if the client has no keepalive object.
 */
int
virNetServerClientStartKeepAlive(virNetServerClientPtr client)
{
    int ret = -1;

    virObjectLock(client);

    /* The connection might have been closed before we got here and thus the
     * keepalive object could have been removed too.
     */
    if (!client->keepalive) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("connection not open"));
        goto cleanup;
    }

    ret = virKeepAliveStart(client->keepalive, 0, 0);

 cleanup:
    virObjectUnlock(client);
    return ret;
}
1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614

/*
 * virNetServerClientGetTransport:
 * @client: the client to query
 *
 * Returns VIR_CLIENT_TRANS_TLS if the client has a TLS session
 * (only when built WITH_GNUTLS), VIR_CLIENT_TRANS_UNIX if it has a
 * socket that virNetSocketIsLocal() reports as local, and
 * VIR_CLIENT_TRANS_TCP otherwise (including when the client has
 * no socket).
 */
int
virNetServerClientGetTransport(virNetServerClientPtr client)
{
    int ret;

    virObjectLock(client);

    if (client->sock && virNetSocketIsLocal(client->sock))
        ret = VIR_CLIENT_TRANS_UNIX;
    else
        ret = VIR_CLIENT_TRANS_TCP;

    /* '#if', not '#ifdef', to match every other WITH_GNUTLS
     * check in this file */
#if WITH_GNUTLS
    if (client->tls)
        ret = VIR_CLIENT_TRANS_TLS;
#endif

    virObjectUnlock(client);

    return ret;
}