virnetserverclient.c 44.5 KB
Newer Older
1 2 3
/*
 * virnetserverclient.c: generic network RPC server client
 *
 * Copyright (C) 2006-2014 Red Hat, Inc.
 * Copyright (C) 2006 Daniel P. Berrange
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */

#include <config.h>

26
#include "internal.h"
27
#if WITH_SASL
28 29 30
# include <sasl/sasl.h>
#endif

31
#include "virnetserver.h"
32 33
#include "virnetserverclient.h"

34
#include "virlog.h"
35
#include "virerror.h"
36
#include "viralloc.h"
37
#include "virthread.h"
38
#include "virkeepalive.h"
39
#include "virprobe.h"
40 41
#include "virstring.h"
#include "virutil.h"
42 43 44

#define VIR_FROM_THIS VIR_FROM_RPC

45 46
VIR_LOG_INIT("rpc.netserverclient");

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
/* Allow for filtering of incoming messages to a custom
 * dispatch processing queue, instead of the workers.
 * This allows for certain types of messages to be handled
 * strictly "in order".
 *
 * Filters are kept per client as a singly linked list; each entry
 * is identified by the integer handed back from
 * virNetServerClientAddFilter().
 */

typedef struct _virNetServerClientFilter virNetServerClientFilter;
typedef virNetServerClientFilter *virNetServerClientFilterPtr;

struct _virNetServerClientFilter {
    /* Identifier assigned from the owning client's nextFilterID counter */
    int id;
    /* Callback supplied by the caller that registered the filter */
    virNetServerClientFilterFunc func;
    /* Caller-supplied pointer stored alongside @func */
    void *opaque;

    /* Next filter in the client's list, or NULL for the last entry */
    virNetServerClientFilterPtr next;
};


struct _virNetServerClient
{
67
    virObjectLockable parent;
68

69
    unsigned long long id;
70
    bool wantClose;
71
    bool delayedClose;
72 73 74
    virNetSocketPtr sock;
    int auth;
    bool readonly;
75
#if WITH_GNUTLS
76 77
    virNetTLSContextPtr tlsCtxt;
    virNetTLSSessionPtr tls;
78
#endif
79
#if WITH_SASL
80 81
    virNetSASLSessionPtr sasl;
#endif
M
Michal Privoznik 已提交
82 83
    int sockTimer; /* Timer to be fired upon cached data,
                    * so we jump out from poll() immediately */
84

85 86 87

    virIdentityPtr identity;

88 89 90 91 92
    /* Connection timestamp, i.e. when a client connected to the daemon (UTC).
     * For old clients restored by post-exec-restart, which did not have this
     * attribute, value of 0 (epoch time) is used to indicate we have no
     * information about their connection time.
     */
93
    long long conn_time;
94

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
    /* Count of messages in the 'tx' queue,
     * and the server worker pool queue
     * ie RPC calls in progress. Does not count
     * async events which are not used for
     * throttling calculations */
    size_t nrequests;
    size_t nrequests_max;
    /* Zero or one messages being received. Zero if
     * nrequests >= max_clients and throttling */
    virNetMessagePtr rx;
    /* Zero or many messages waiting for transmit
     * back to client, including async events */
    virNetMessagePtr tx;

    /* Filters to capture messages that would otherwise
     * end up on the 'dx' queue */
    virNetServerClientFilterPtr filters;
    int nextFilterID;

    virNetServerClientDispatchFunc dispatchFunc;
    void *dispatchOpaque;

    void *privateData;
118
    virFreeCallback privateDataFreeFunc;
119
    virNetServerClientPrivPreExecRestart privateDataPreExecRestart;
120
    virNetServerClientCloseFunc privateDataCloseFunc;
121 122

    virKeepAlivePtr keepalive;
123 124 125
};


126 127 128 129 130
static virClassPtr virNetServerClientClass;
static void virNetServerClientDispose(void *obj);

static int virNetServerClientOnceInit(void)
{
131
    if (!(virNetServerClientClass = virClassNew(virClassForObjectLockable(),
132
                                                "virNetServerClient",
133 134 135 136 137 138 139 140 141 142
                                                sizeof(virNetServerClient),
                                                virNetServerClientDispose)))
        return -1;

    return 0;
}

VIR_ONCE_GLOBAL_INIT(virNetServerClient)


143 144
static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque);
static void virNetServerClientUpdateEvent(virNetServerClientPtr client);
M
Michal Privoznik 已提交
145
static void virNetServerClientDispatchRead(virNetServerClientPtr client);
146 147
static int virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                               virNetMessagePtr msg);
148 149 150 151 152

/*
 * @client: a locked client object
 */
static int
153 154
virNetServerClientCalculateHandleMode(virNetServerClientPtr client)
{
155 156 157 158
    int mode = 0;


    VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
159
#ifdef WITH_GNUTLS
160 161
              client->tls,
              client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
162 163 164
#else
              NULL, -1,
#endif
165 166 167 168 169
              client->rx,
              client->tx);
    if (!client->sock || client->wantClose)
        return 0;

170
#if WITH_GNUTLS
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
    if (client->tls) {
        switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
        case VIR_NET_TLS_HANDSHAKE_RECVING:
            mode |= VIR_EVENT_HANDLE_READABLE;
            break;
        case VIR_NET_TLS_HANDSHAKE_SENDING:
            mode |= VIR_EVENT_HANDLE_WRITABLE;
            break;
        default:
        case VIR_NET_TLS_HANDSHAKE_COMPLETE:
            if (client->rx)
                mode |= VIR_EVENT_HANDLE_READABLE;
            if (client->tx)
                mode |= VIR_EVENT_HANDLE_WRITABLE;
        }
    } else {
187
#endif
188 189
        /* If there is a message on the rx queue, and
         * we're not in middle of a delayedClose, then
190
         * we're wanting more input */
191
        if (client->rx && !client->delayedClose)
192 193 194 195 196 197
            mode |= VIR_EVENT_HANDLE_READABLE;

        /* If there are one or more messages to send back to client,
           then monitor for writability on socket */
        if (client->tx)
            mode |= VIR_EVENT_HANDLE_WRITABLE;
198
#if WITH_GNUTLS
199
    }
200
#endif
201
    VIR_DEBUG("mode=%o", mode);
202 203 204 205 206 207 208 209 210 211 212
    return mode;
}

/*
 * @client: a locked client object
 *
 * Registers virNetServerClientDispatchEvent as the I/O callback on the
 * client's socket, using the event mask computed for the current
 * state. A reference on @client is taken for the callback and
 * virObjectFreeCallback is passed as its free function; the reference
 * is dropped again here if registration fails.
 *
 * Returns 0 on success, -1 if the client has no socket or the
 * callback could not be added.
 */
static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
{
    int events = virNetServerClientCalculateHandleMode(client);
    int rc;

    if (client->sock == NULL)
        return -1;

    virObjectRef(client);
    VIR_DEBUG("Registering client event callback %d", events);
    rc = virNetSocketAddIOCallback(client->sock,
                                   events,
                                   virNetServerClientDispatchEvent,
                                   client,
                                   virObjectFreeCallback);
    if (rc >= 0)
        return 0;

    virObjectUnref(client);
    return -1;
}

/*
 * @client: a locked client object
 *
 * Recomputes the event mask for the client's socket and applies it.
 * When a receive message is pending and the socket already holds
 * cached data, the socket timer is armed with a zero timeout.
 * Does nothing if the client has no socket.
 */
static void virNetServerClientUpdateEvent(virNetServerClientPtr client)
{
    if (client->sock) {
        int events = virNetServerClientCalculateHandleMode(client);

        virNetSocketUpdateIOCallback(client->sock, events);

        if (client->rx && virNetSocketHasCachedData(client->sock))
            virEventUpdateTimeout(client->sockTimer, 0);
    }
}


249 250 251
/*
 * Appends a new message filter to the end of the client's filter list.
 *
 * @client: client to attach the filter to
 * @func: filter callback
 * @opaque: data stored with the filter for @func
 *
 * Returns the identifier assigned to the filter, or -1 if the filter
 * could not be allocated.
 */
int virNetServerClientAddFilter(virNetServerClientPtr client,
                                virNetServerClientFilterFunc func,
                                void *opaque)
{
    virNetServerClientFilterPtr entry;
    virNetServerClientFilterPtr *tail;
    int id;

    if (VIR_ALLOC(entry) < 0)
        return -1;

    virObjectLock(client);

    id = client->nextFilterID++;
    entry->id = id;
    entry->func = func;
    entry->opaque = opaque;

    /* Walk to the terminating NULL link and hook the new entry in */
    for (tail = &client->filters; *tail != NULL; tail = &(*tail)->next) {
        /* nothing: just advancing */
    }
    *tail = entry;

    virObjectUnlock(client);

    return id;
}

278 279
/*
 * Unlinks and frees the first filter whose identifier equals
 * @filterID. Nothing happens if no filter matches.
 */
void virNetServerClientRemoveFilter(virNetServerClientPtr client,
                                    int filterID)
{
    virNetServerClientFilterPtr *link;

    virObjectLock(client);

    /* @link always points at the pointer that refers to the entry
     * under inspection, so unlinking needs no separate 'prev' */
    for (link = &client->filters; *link; link = &(*link)->next) {
        virNetServerClientFilterPtr victim = *link;

        if (victim->id != filterID)
            continue;

        *link = victim->next;
        VIR_FREE(victim);
        break;
    }

    virObjectUnlock(client);
}


305
#ifdef WITH_GNUTLS
306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
/* Check the client's access.
 *
 * Verifies the client certificate against the TLS context and, on
 * success, queues a one byte confirmation message as the client's
 * pending transmit data.
 *
 * Returns 0 on success, -1 if the certificate check fails, if transmit
 * data is already pending, or on allocation failure.
 */
static int
virNetServerClientCheckAccess(virNetServerClientPtr client)
{
    virNetMessagePtr reply;

    /* Verify client certificate. */
    if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
        return -1;

    if (client->tx != NULL) {
        VIR_DEBUG("client had unexpected data pending tx after access check");
        return -1;
    }

    reply = virNetMessageNew(false);
    if (!reply)
        return -1;

    /* Checks have succeeded.  Write a '\1' byte back to the client to
     * indicate this (otherwise the socket is abruptly closed).
     * (NB. The '\1' byte is sent in an encrypted record).
     */
    reply->bufferLength = 1;
    if (VIR_ALLOC_N(reply->buffer, reply->bufferLength) < 0) {
        virNetMessageFree(reply);
        return -1;
    }
    reply->buffer[0] = '\1';
    reply->bufferOffset = 0;

    client->tx = reply;

    return 0;
}
340 341
#endif

342

M
Michal Privoznik 已提交
343 344 345 346
/* Timer callback for the client's socket timer.
 *
 * @timer: identifier of the timer that fired
 * @opaque: the virNetServerClientPtr the timer was registered with
 *
 * Disarms the timer and, if a receive message is still pending,
 * dispatches a read on the client.
 */
static void virNetServerClientSockTimerFunc(int timer,
                                            void *opaque)
{
    virNetServerClientPtr self = opaque;

    virObjectLock(self);

    virEventUpdateTimeout(timer, -1);

    /* Although client->rx != NULL when this timer is enabled, it might have
     * changed since the client was unlocked in the meantime. */
    if (self->rx != NULL)
        virNetServerClientDispatchRead(self);

    virObjectUnlock(self);
}

356

357
static virNetServerClientPtr
358 359
virNetServerClientNewInternal(unsigned long long id,
                              virNetSocketPtr sock,
360
                              int auth,
361
#ifdef WITH_GNUTLS
362 363
                              virNetTLSContextPtr tls,
#endif
364
                              bool readonly,
365
                              size_t nrequests_max,
366
                              long long timestamp)
367 368 369
{
    virNetServerClientPtr client;

370
    if (virNetServerClientInitialize() < 0)
371 372
        return NULL;

373
    if (!(client = virObjectLockableNew(virNetServerClientClass)))
374 375
        return NULL;

376
    client->id = id;
377
    client->sock = virObjectRef(sock);
378 379
    client->auth = auth;
    client->readonly = readonly;
380
#ifdef WITH_GNUTLS
381
    client->tlsCtxt = virObjectRef(tls);
382
#endif
383
    client->nrequests_max = nrequests_max;
384
    client->conn_time = timestamp;
385

M
Michal Privoznik 已提交
386 387 388 389 390
    client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
                                           client, NULL);
    if (client->sockTimer < 0)
        goto error;

391
    /* Prepare one for packet receive */
392
    if (!(client->rx = virNetMessageNew(true)))
393 394
        goto error;
    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
395
    if (VIR_ALLOC_N(client->rx->buffer, client->rx->bufferLength) < 0)
396
        goto error;
397 398
    client->nrequests = 1;

399
    PROBE(RPC_SERVER_CLIENT_NEW,
400 401
          "client=%p sock=%p",
          client, client->sock);
402 403 404

    return client;

405
 error:
406
    virObjectUnref(client);
407 408 409 410
    return NULL;
}


411 412
virNetServerClientPtr virNetServerClientNew(unsigned long long id,
                                            virNetSocketPtr sock,
413 414 415
                                            int auth,
                                            bool readonly,
                                            size_t nrequests_max,
416
#ifdef WITH_GNUTLS
417
                                            virNetTLSContextPtr tls,
418
#endif
419
                                            virNetServerClientPrivNew privNew,
420
                                            virNetServerClientPrivPreExecRestart privPreExecRestart,
421 422 423 424
                                            virFreeCallback privFree,
                                            void *privOpaque)
{
    virNetServerClientPtr client;
425
    time_t now;
426

427
    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth,
428
#ifdef WITH_GNUTLS
429 430 431 432 433
              tls
#else
              NULL
#endif
        );
434

435 436 437 438 439
    if ((now = time(NULL)) == (time_t) - 1) {
        virReportSystemError(errno, "%s", _("failed to get current time"));
        return NULL;
    }

440
    if (!(client = virNetServerClientNewInternal(id, sock, auth,
441
#ifdef WITH_GNUTLS
442 443
                                                 tls,
#endif
444 445
                                                 readonly, nrequests_max,
                                                 now)))
446 447 448 449 450 451 452 453
        return NULL;

    if (privNew) {
        if (!(client->privateData = privNew(client, privOpaque))) {
            virObjectUnref(client);
            return NULL;
        }
        client->privateDataFreeFunc = privFree;
454
        client->privateDataPreExecRestart = privPreExecRestart;
455 456 457 458 459 460
    }

    return client;
}


461 462 463 464
/*
 * virNetServerClientNewPostExecRestart:
 * @object: JSON state document describing the client
 * @privNew: optional constructor rebuilding private data from JSON
 * @privPreExecRestart: serializer stored for the private data
 * @privFree: destructor stored for the private data
 * @privOpaque: pointer passed through to @privNew
 * @opaque: cast to virNetServerPtr and used to generate a client ID
 *          when the document carries none
 *
 * Rebuilds a client from the fields "auth", "readonly",
 * "nrequests_max", "sock" and the optional "id" and "conn_time".
 * A missing "conn_time" is recorded as 0. When @privNew is given the
 * document must also contain "privateData".
 *
 * Returns the restored client, or NULL on error.
 */
virNetServerClientPtr virNetServerClientNewPostExecRestart(virJSONValuePtr object,
                                                           virNetServerClientPrivNewPostExecRestart privNew,
                                                           virNetServerClientPrivPreExecRestart privPreExecRestart,
                                                           virFreeCallback privFree,
                                                           void *privOpaque,
                                                           void *opaque)
{
    virJSONValuePtr child;
    virNetServerClientPtr client = NULL;
    virNetSocketPtr sock;
    int auth;
    bool readonly;
    unsigned int nrequests_max;
    unsigned long long id;
    long long timestamp;

    if (virJSONValueObjectGetNumberInt(object, "auth", &auth) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing auth field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetBoolean(object, "readonly", &readonly) < 0) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing readonly field in JSON state document"));
        return NULL;
    }
    if (virJSONValueObjectGetNumberUint(object, "nrequests_max",
                                        &nrequests_max) < 0) {
        /* The message names the key that is actually looked up */
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing nrequests_max field in JSON state document"));
        return NULL;
    }

    if (!(child = virJSONValueObjectGet(object, "sock"))) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("Missing sock field in JSON state document"));
        return NULL;
    }

    if (!virJSONValueObjectHasKey(object, "id")) {
        /* No ID found in the document, a new one must be generated */
        id = virNetServerNextClientID((virNetServerPtr) opaque);
    } else {
        if (virJSONValueObjectGetNumberUlong(object, "id", &id) < 0) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Malformed id field in JSON state document"));
            return NULL;
        }
    }

    if (!virJSONValueObjectHasKey(object, "conn_time")) {
        timestamp = 0;
    } else {
        if (virJSONValueObjectGetNumberLong(object, "conn_time", &timestamp) < 0) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Malformed conn_time field in JSON "
                             "state document"));
            return NULL;
        }
    }

    /* Nothing to release on failure: sock is NULL in that case */
    if (!(sock = virNetSocketNewPostExecRestart(child)))
        return NULL;

    client = virNetServerClientNewInternal(id,
                                           sock,
                                           auth,
#ifdef WITH_GNUTLS
                                           NULL,
#endif
                                           readonly,
                                           nrequests_max,
                                           timestamp);
    /* The client holds its own reference on the socket when creation
     * succeeded; either way the local reference is no longer needed. */
    virObjectUnref(sock);
    if (!client)
        return NULL;

    if (privNew) {
        if (!(child = virJSONValueObjectGet(object, "privateData"))) {
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                           _("Missing privateData field in JSON state document"));
            goto error;
        }
        if (!(client->privateData = privNew(client, child, privOpaque)))
            goto error;
        client->privateDataFreeFunc = privFree;
        client->privateDataPreExecRestart = privPreExecRestart;
    }

    return client;

 error:
    virObjectUnref(client);
    return NULL;
}


/*
 * Serializes the client into a JSON object with the keys "id", "auth",
 * "readonly", "nrequests_max", "conn_time" (only when non-zero),
 * "sock" and, when private data and its serializer are both present,
 * "privateData".
 *
 * Returns the JSON object, or NULL on failure.
 */
virJSONValuePtr virNetServerClientPreExecRestart(virNetServerClientPtr client)
{
    virJSONValuePtr state = virJSONValueNewObject();
    virJSONValuePtr sub;

    if (state == NULL)
        return NULL;

    virObjectLock(client);

    /* Scalar members; evaluation stops at the first failure */
    if (virJSONValueObjectAppendNumberUlong(state, "id", client->id) < 0 ||
        virJSONValueObjectAppendNumberInt(state, "auth", client->auth) < 0 ||
        virJSONValueObjectAppendBoolean(state, "readonly", client->readonly) < 0 ||
        virJSONValueObjectAppendNumberUint(state, "nrequests_max", client->nrequests_max) < 0 ||
        (client->conn_time &&
         virJSONValueObjectAppendNumberLong(state, "conn_time",
                                            client->conn_time) < 0))
        goto error;

    sub = virNetSocketPreExecRestart(client->sock);
    if (sub == NULL)
        goto error;

    if (virJSONValueObjectAppend(state, "sock", sub) < 0) {
        virJSONValueFree(sub);
        goto error;
    }

    if (client->privateData && client->privateDataPreExecRestart) {
        sub = client->privateDataPreExecRestart(client, client->privateData);
        if (sub == NULL)
            goto error;

        if (virJSONValueObjectAppend(state, "privateData", sub) < 0) {
            virJSONValueFree(sub);
            goto error;
        }
    }

    virObjectUnlock(client);
    return state;

 error:
    virObjectUnlock(client);
    virJSONValueFree(state);
    return NULL;
}


616 617 618
/* Returns the client's stored auth value, read under the object lock. */
int virNetServerClientGetAuth(virNetServerClientPtr client)
{
    int current;

    virObjectLock(client);
    current = client->auth;
    virObjectUnlock(client);

    return current;
}

625 626 627 628 629 630 631
/* Stores @auth as the client's auth value, under the object lock. */
void virNetServerClientSetAuth(virNetServerClientPtr client, int auth)
{
    virObjectLock(client);

    client->auth = auth;

    virObjectUnlock(client);
}

632 633 634
/* Returns the client's read-only flag, read under the object lock. */
bool virNetServerClientGetReadonly(virNetServerClientPtr client)
{
    bool flag;

    virObjectLock(client);
    flag = client->readonly;
    virObjectUnlock(client);

    return flag;
}

641 642 643 644
/* Returns the client's identifier. The value is read without taking
 * the object lock. */
unsigned long long virNetServerClientGetID(virNetServerClientPtr client)
{
    unsigned long long ident = client->id;

    return ident;
}
645

646 647 648 649 650
/* Returns the recorded connection time. The value is read without
 * taking the object lock. */
long long virNetServerClientGetTimestamp(virNetServerClientPtr client)
{
    long long when = client->conn_time;

    return when;
}

651
#ifdef WITH_GNUTLS
652 653 654
/* Returns true if the client currently has a TLS session attached. */
bool virNetServerClientHasTLSSession(virNetServerClientPtr client)
{
    bool present;

    virObjectLock(client);
    present = client->tls != NULL;
    virObjectUnlock(client);

    return present;
}

661 662 663 664 665 666 667 668 669 670

/* Returns the client's TLS session pointer (possibly NULL). No
 * additional reference is taken on the returned object. */
virNetTLSSessionPtr virNetServerClientGetTLSSession(virNetServerClientPtr client)
{
    virNetTLSSessionPtr session;

    virObjectLock(client);
    session = client->tls;
    virObjectUnlock(client);

    return session;
}

671 672 673
/* Returns the key size reported by the client's TLS session, or 0
 * when the client has no TLS session. */
int virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
{
    int keySize = 0;

    virObjectLock(client);
    if (client->tls != NULL)
        keySize = virNetTLSSessionGetKeySize(client->tls);
    virObjectUnlock(client);

    return keySize;
}
680
#endif
681 682 683

/* Returns the file descriptor of the client's socket, or -1 when the
 * client has no socket. */
int virNetServerClientGetFD(virNetServerClientPtr client)
{
    int descriptor = -1;

    virObjectLock(client);
    if (client->sock != NULL)
        descriptor = virNetSocketGetFD(client->sock);
    virObjectUnlock(client);

    return descriptor;
}

692 693 694 695 696 697 698 699 700 701 702 703

/* Returns whether the client's socket is local; false when the client
 * has no socket. */
bool virNetServerClientIsLocal(virNetServerClientPtr client)
{
    bool isLocal = false;

    virObjectLock(client);
    if (client->sock != NULL)
        isLocal = virNetSocketIsLocal(client->sock);
    virObjectUnlock(client);

    return isLocal;
}


704
/*
 * Queries the UNIX identity of the peer via the client's socket.
 *
 * @uid, @gid, @pid, @timestamp: output locations forwarded unchanged
 * to virNetSocketGetUNIXIdentity()
 *
 * Returns the result of virNetSocketGetUNIXIdentity(), or -1 when the
 * client has no socket.
 */
int virNetServerClientGetUNIXIdentity(virNetServerClientPtr client,
                                      uid_t *uid, gid_t *gid, pid_t *pid,
                                      unsigned long long *timestamp)
{
    int rc = -1;

    virObjectLock(client);
    if (client->sock != NULL) {
        rc = virNetSocketGetUNIXIdentity(client->sock,
                                         uid, gid, pid,
                                         timestamp);
    }
    virObjectUnlock(client);

    return rc;
}

718

719 720 721 722 723 724 725 726
/*
 * Builds an identity object describing the peer of @client.
 *
 * For local sockets the UNIX user/group names and IDs, process ID and
 * process time are recorded. When present, the SASL user name and the
 * X.509 distinguished name of the TLS session are added, as is the
 * SELinux context reported by the socket.
 *
 * The caller is expected to hold the client lock (the only visible
 * caller, virNetServerClientGetIdentity(), does).
 *
 * Returns the new identity, or NULL if any step fails.
 */
static virIdentityPtr
virNetServerClientCreateIdentity(virNetServerClientPtr client)
{
    virIdentityPtr ident = NULL;
    char *user = NULL;
    char *group = NULL;
    char *selinux = NULL;
    bool ok = false;

    if (!(ident = virIdentityNew()))
        goto done;

    if (client->sock && virNetSocketIsLocal(client->sock)) {
        uid_t uid;
        gid_t gid;
        pid_t pid;
        unsigned long long timestamp;

        if (virNetSocketGetUNIXIdentity(client->sock,
                                        &uid, &gid, &pid,
                                        &timestamp) < 0)
            goto done;

        if (!(user = virGetUserName(uid)) ||
            virIdentitySetUNIXUserName(ident, user) < 0 ||
            virIdentitySetUNIXUserID(ident, uid) < 0)
            goto done;

        if (!(group = virGetGroupName(gid)) ||
            virIdentitySetUNIXGroupName(ident, group) < 0 ||
            virIdentitySetUNIXGroupID(ident, gid) < 0)
            goto done;

        if (virIdentitySetUNIXProcessID(ident, pid) < 0 ||
            virIdentitySetUNIXProcessTime(ident, timestamp) < 0)
            goto done;
    }

#if WITH_SASL
    if (client->sasl) {
        const char *saslName = virNetSASLSessionGetIdentity(client->sasl);

        if (virIdentitySetSASLUserName(ident, saslName) < 0)
            goto done;
    }
#endif

#if WITH_GNUTLS
    if (client->tls) {
        const char *dname = virNetTLSSessionGetX509DName(client->tls);

        if (virIdentitySetX509DName(ident, dname) < 0)
            goto done;
    }
#endif

    if (client->sock &&
        virNetSocketGetSELinuxContext(client->sock, &selinux) < 0)
        goto done;
    if (selinux &&
        virIdentitySetSELinuxContext(ident, selinux) < 0)
        goto done;

    ok = true;

 done:
    /* On any failure the partially built identity is dropped */
    if (!ok) {
        virObjectUnref(ident);
        ident = NULL;
    }
    VIR_FREE(user);
    VIR_FREE(group);
    VIR_FREE(selinux);
    return ident;
}


/*
 * Returns a new reference to the client's identity, creating and
 * caching the identity on first use. Returns NULL if the identity
 * could not be created.
 */
virIdentityPtr virNetServerClientGetIdentity(virNetServerClientPtr client)
{
    virIdentityPtr ident = NULL;

    virObjectLock(client);

    if (client->identity == NULL)
        client->identity = virNetServerClientCreateIdentity(client);

    if (client->identity != NULL)
        ident = virObjectRef(client->identity);

    virObjectUnlock(client);

    return ident;
}


809 810
/*
 * Fetches the SELinux context of the peer via the client's socket.
 *
 * @context: output location; reset to NULL before the lookup
 *
 * Returns the result of virNetSocketGetSELinuxContext(), or 0 with
 * *@context left NULL when the client has no socket.
 */
int virNetServerClientGetSELinuxContext(virNetServerClientPtr client,
                                        char **context)
{
    int rc = 0;

    *context = NULL;

    virObjectLock(client);
    if (client->sock != NULL)
        rc = virNetSocketGetSELinuxContext(client->sock, context);
    virObjectUnlock(client);

    return rc;
}


822 823 824
/*
 * Reports whether the connection counts as secure: it does when a TLS
 * session or a SASL session is attached (in builds that support
 * them), or when the socket is local.
 */
bool virNetServerClientIsSecure(virNetServerClientPtr client)
{
    bool result = false;

    virObjectLock(client);

#if WITH_GNUTLS
    if (client->tls != NULL)
        result = true;
#endif
#if WITH_SASL
    if (client->sasl != NULL)
        result = true;
#endif
    if (client->sock != NULL && virNetSocketIsLocal(client->sock))
        result = true;

    virObjectUnlock(client);

    return result;
}


841
#if WITH_SASL
842 843 844 845 846 847 848 849
/*
 * Records @sasl as the client's SASL session, taking a reference on it.
 */
void virNetServerClientSetSASLSession(virNetServerClientPtr client,
                                      virNetSASLSessionPtr sasl)
{
    virObjectLock(client);

    /* We don't set the sasl session on the socket here
     * because we need to send out the auth confirmation
     * in the clear. Only once we complete the next 'tx'
     * operation do we switch to SASL mode
     */
    client->sasl = virObjectRef(sasl);

    virObjectUnlock(client);
}
854 855 856 857 858 859 860 861 862 863


/* Returns the client's SASL session pointer (possibly NULL). No
 * additional reference is taken on the returned object. */
virNetSASLSessionPtr virNetServerClientGetSASLSession(virNetServerClientPtr client)
{
    virNetSASLSessionPtr session;

    virObjectLock(client);
    session = client->sasl;
    virObjectUnlock(client);

    return session;
}
864 865 866 867 868 869 870 871 872

/* Returns true if the client currently has a SASL session attached. */
bool virNetServerClientHasSASLSession(virNetServerClientPtr client)
{
    bool present = false;

    virObjectLock(client);
    present = client->sasl != NULL;
    virObjectUnlock(client);

    return present;
}
873 874 875 876 877 878
#endif


/* Returns the opaque private data pointer stored in the client. */
void *virNetServerClientGetPrivateData(virNetServerClientPtr client)
{
    void *priv;

    virObjectLock(client);
    priv = client->privateData;
    virObjectUnlock(client);

    return priv;
}


886 887 888
/* Installs @cf as the hook that virNetServerClientClose() invokes
 * while closing the client. */
void virNetServerClientSetCloseHook(virNetServerClientPtr client,
                                    virNetServerClientCloseFunc cf)
{
    virObjectLock(client);

    client->privateDataCloseFunc = cf;

    virObjectUnlock(client);
}


895 896 897 898
/* Stores the dispatch callback and its opaque data in the client;
 * both members are updated under a single lock acquisition. */
void virNetServerClientSetDispatcher(virNetServerClientPtr client,
                                     virNetServerClientDispatchFunc func,
                                     void *opaque)
{
    virObjectLock(client);

    client->dispatchOpaque = opaque;
    client->dispatchFunc = func;

    virObjectUnlock(client);
}


/* Returns the local address string of the client's socket, or NULL
 * when the client has no socket. The object lock is not taken. */
const char *virNetServerClientLocalAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketLocalAddrString(client->sock) : NULL;
}


/* Returns the remote address string of the client's socket, or NULL
 * when the client has no socket. The object lock is not taken. */
const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketRemoteAddrString(client->sock) : NULL;
}

921 922 923 924 925 926
/* Returns the URI-formatted remote address string of the client's
 * socket, or NULL when the client has no socket. The object lock is
 * not taken. */
const char *virNetServerClientRemoteAddrStringURI(virNetServerClientPtr client)
{
    return client->sock ? virNetSocketRemoteAddrStringURI(client->sock) : NULL;
}
927

928
void virNetServerClientDispose(void *obj)
929
{
930
    virNetServerClientPtr client = obj;
931

932 933 934
    PROBE(RPC_SERVER_CLIENT_DISPOSE,
          "client=%p", client);

935 936 937 938
    if (client->privateData &&
        client->privateDataFreeFunc)
        client->privateDataFreeFunc(client->privateData);

939 940
    virObjectUnref(client->identity);

941
#if WITH_SASL
942
    virObjectUnref(client->sasl);
943
#endif
M
Michal Privoznik 已提交
944 945
    if (client->sockTimer > 0)
        virEventRemoveTimeout(client->sockTimer);
946
#if WITH_GNUTLS
947 948
    virObjectUnref(client->tls);
    virObjectUnref(client->tlsCtxt);
949
#endif
950
    virObjectUnref(client->sock);
951 952 953 954 955 956 957 958 959 960 961 962 963
}


/*
 * Disconnects the client's network socket and associated resources
 * without freeing the client object itself.
 *
 * Full free of the client is done later in a safe point
 * where it can be guaranteed it is no longer in use.
 *
 * Calling this on a client that has no socket is a no-op.
 */
void virNetServerClientClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    VIR_DEBUG("client=%p", client);

    if (client->sock == NULL) {
        virObjectUnlock(client);
        return;
    }

    if (client->keepalive != NULL) {
        virKeepAlivePtr keepalive;

        virKeepAliveStop(client->keepalive);
        keepalive = client->keepalive;
        client->keepalive = NULL;

        /* The keepalive is released with the client unlocked; an
         * extra reference keeps the client alive across that window */
        virObjectRef(client);
        virObjectUnlock(client);
        virObjectUnref(keepalive);
        virObjectLock(client);
        virObjectUnref(client);
    }

    if (client->privateDataCloseFunc != NULL) {
        virNetServerClientCloseFunc closeHook = client->privateDataCloseFunc;

        /* Same pattern: the close hook runs with the client unlocked */
        virObjectRef(client);
        virObjectUnlock(client);
        (closeHook)(client);
        virObjectLock(client);
        virObjectUnref(client);
    }

    /* Do now, even though we don't close the socket
     * until end, to ensure we don't get invoked
     * again due to tls shutdown */
    if (client->sock != NULL)
        virNetSocketRemoveIOCallback(client->sock);

#if WITH_GNUTLS
    if (client->tls != NULL) {
        virObjectUnref(client->tls);
        client->tls = NULL;
    }
#endif
    client->wantClose = true;

    /* Drain and free whatever is still queued in either direction */
    while (client->rx != NULL) {
        virNetMessagePtr pending = virNetMessageQueueServe(&client->rx);

        virNetMessageFree(pending);
    }
    while (client->tx != NULL) {
        virNetMessagePtr pending = virNetMessageQueueServe(&client->tx);

        virNetMessageFree(pending);
    }

    if (client->sock != NULL) {
        virObjectUnref(client->sock);
        client->sock = NULL;
    }

    virObjectUnlock(client);
}


/*
 * virNetServerClientIsClosed:
 * @client: the client to query
 *
 * Returns true when the client no longer has a socket attached
 * (client->sock is NULL), false otherwise. The check is made
 * with the client lock held.
 */
bool virNetServerClientIsClosed(virNetServerClientPtr client)
{
    bool closed;

    virObjectLock(client);
    closed = !client->sock;
    virObjectUnlock(client);
    return closed;
}

1037 1038
/*
 * virNetServerClientDelayedClose:
 * @client: the client to mark
 *
 * Sets the delayedClose flag on @client under the client lock.
 */
void virNetServerClientDelayedClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    client->delayedClose = true;
    virObjectUnlock(client);
}

/*
 * virNetServerClientImmediateClose:
 * @client: the client to mark
 *
 * Sets the wantClose flag on @client under the client lock.
 * The socket itself is not touched here.
 */
void virNetServerClientImmediateClose(virNetServerClientPtr client)
{
    virObjectLock(client);
    client->wantClose = true;
    virObjectUnlock(client);
}

/*
 * virNetServerClientWantClose:
 * @client: the client to query
 *
 * Returns the current value of the client's wantClose flag,
 * read with the client lock held.
 */
bool virNetServerClientWantClose(virNetServerClientPtr client)
{
    bool wantClose;

    virObjectLock(client);
    wantClose = client->wantClose;
    virObjectUnlock(client);
    return wantClose;
}


/*
 * virNetServerClientInit:
 * @client: the client to initialise
 *
 * Registers the client's socket for events. When built with
 * GnuTLS and the client has a TLS context, a TLS session is
 * created and attached to the socket, and the handshake is
 * started first.
 *
 * Returns 0 on success. On any failure the client's wantClose
 * flag is set and -1 is returned.
 */
int virNetServerClientInit(virNetServerClientPtr client)
{
    virObjectLock(client);

#if WITH_GNUTLS
    if (!client->tlsCtxt) {
#endif
        /* Plain socket, so prepare to read first message */
        if (virNetServerClientRegisterEvent(client) < 0)
            goto error;
#if WITH_GNUTLS
    } else {
        int ret;

        if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt,
                                                NULL)))
            goto error;

        virNetSocketSetTLSSession(client->sock,
                                  client->tls);

        /* Begin the TLS handshake. */
        ret = virNetTLSSessionHandshake(client->tls);
        if (ret == 0) {
            /* Unlikely, but ...  Next step is to check the certificate. */
            if (virNetServerClientCheckAccess(client) < 0)
                goto error;

            /* Handshake & cert check OK,  so prepare to read first message */
            if (virNetServerClientRegisterEvent(client) < 0)
                goto error;
        } else if (ret > 0) {
            /* Most likely, need to do more handshake data */
            if (virNetServerClientRegisterEvent(client) < 0)
                goto error;
        } else {
            goto error;
        }
    }
#endif

    virObjectUnlock(client);
    return 0;

 error:
    client->wantClose = true;
    virObjectUnlock(client);
    return -1;
}



/*
 * Read data into buffer using wire decoding (plain or TLS)
 *
 * Data is appended to client->rx->buffer at bufferOffset, and
 * bufferOffset is advanced by the number of bytes read. If the
 * buffer has no room left (bufferLength <= bufferOffset) an
 * error is reported and the client is flagged for close.
 *
 * Returns:
 *   -1 on error or EOF
 *    0 on EAGAIN
 *    n number of bytes
 */
static ssize_t virNetServerClientRead(virNetServerClientPtr client)
{
    ssize_t ret;

    if (client->rx->bufferLength <= client->rx->bufferOffset) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(client->rx->bufferLength - client->rx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    ret = virNetSocketRead(client->sock,
                           client->rx->buffer + client->rx->bufferOffset,
                           client->rx->bufferLength - client->rx->bufferOffset);

    /* Pass error / would-block results straight back to the caller */
    if (ret <= 0)
        return ret;

    client->rx->bufferOffset += ret;
    return ret;
}


/*
 * Read data until we get a complete message to process
 *
 * The receive buffer is filled in two stages: first the length
 * word, then (after decoding the length) the payload. Once a
 * whole message and any file descriptors it carries have
 * arrived, it is taken off client->rx and offered, in order, to
 * the keepalive handler, the registered filters and finally the
 * dispatch function. Any failure flags the client for close.
 *
 * Must be called with the client lock held and client->rx set.
 */
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
 readmore:
    /* Once the FD count is known the data has all been read, and
     * only the file descriptors remain to be received */
    if (client->rx->nfds == 0) {
        if (virNetServerClientRead(client) < 0) {
            client->wantClose = true;
            return; /* Error */
        }
    }

    if (client->rx->bufferOffset < client->rx->bufferLength)
        return; /* Still not read enough */

    /* Either done with length word header, or done with the payload */
    if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
        if (virNetMessageDecodeLength(client->rx) < 0) {
            client->wantClose = true;
            return;
        }

        virNetServerClientUpdateEvent(client);

        /* Try and read payload immediately instead of going back
           into poll() because chances are the data is already
           waiting for us */
        goto readmore;
    } else {
        /* Grab the completed message */
        virNetMessagePtr msg = client->rx;
        virNetMessagePtr response = NULL;
        virNetServerClientFilterPtr filter;
        size_t i;

        /* Decode the header so we can use it for routing decisions */
        if (virNetMessageDecodeHeader(msg) < 0) {
            virNetMessageQueueServe(&client->rx);
            virNetMessageFree(msg);
            client->wantClose = true;
            return;
        }

        /* Now figure out if we need to read more data to get some
         * file descriptors */
        if (msg->header.type == VIR_NET_CALL_WITH_FDS) {
            if (msg->nfds == 0 &&
                virNetMessageDecodeNumFDs(msg) < 0) {
                virNetMessageQueueServe(&client->rx);
                virNetMessageFree(msg);
                client->wantClose = true;
                return; /* Error */
            }

            /* Try getting the file descriptors (may fail if blocking) */
            for (i = msg->donefds; i < msg->nfds; i++) {
                int rv;
                if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
                    virNetMessageQueueServe(&client->rx);
                    virNetMessageFree(msg);
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    break;
                msg->donefds++;
            }

            /* Need to poll() until FDs arrive */
            if (msg->donefds < msg->nfds) {
                /* Because DecodeHeader/NumFDs reset bufferOffset, we
                 * put it back to what it was, so everything works
                 * again next time we run this method
                 */
                client->rx->bufferOffset = client->rx->bufferLength;
                return;
            }
        }

        /* Definitely finished reading, so remove from queue */
        virNetMessageQueueServe(&client->rx);
        PROBE(RPC_SERVER_CLIENT_MSG_RX,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);

        /* Keepalive traffic is consumed here and never reaches the
         * filters or the dispatcher */
        if (virKeepAliveCheckMessage(client->keepalive, msg, &response)) {
            virNetMessageFree(msg);
            client->nrequests--;
            msg = NULL;

            if (response &&
                virNetServerClientSendMessageLocked(client, response) < 0)
                virNetMessageFree(response);
        }

        /* Maybe send off for queue against a filter */
        if (msg) {
            filter = client->filters;
            while (filter) {
                int ret = filter->func(client, msg, filter->opaque);
                if (ret < 0) {
                    /* Filter failed: drop the message and the client */
                    virNetMessageFree(msg);
                    msg = NULL;
                    client->wantClose = true;
                    break;
                }
                if (ret > 0) {
                    /* Filter took the message */
                    msg = NULL;
                    break;
                }

                filter = filter->next;
            }
        }

        /* Send off to for normal dispatch to workers */
        if (msg) {
            /* This reference is only dropped here on failure; on
             * success it is left in place.
             * NOTE(review): presumably released once the dispatched
             * message has been handled - confirm against dispatchFunc. */
            virObjectRef(client);
            if (!client->dispatchFunc ||
                client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) {
                virNetMessageFree(msg);
                client->wantClose = true;
                virObjectUnref(client);
                return;
            }
        }

        /* Possibly need to create another receive buffer */
        if (client->nrequests < client->nrequests_max) {
            if (!(client->rx = virNetMessageNew(true))) {
                client->wantClose = true;
            } else {
                client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                if (VIR_ALLOC_N(client->rx->buffer,
                                client->rx->bufferLength) < 0) {
                    client->wantClose = true;
                } else {
                    client->nrequests++;
                }
            }
        }
        virNetServerClientUpdateEvent(client);
    }
}


/*
 * Send client->tx using no encoding
 *
 * Writes the unsent remainder of client->tx->buffer to the
 * socket and advances bufferOffset by the number of bytes
 * written. An offset beyond the buffer length is reported as
 * an error and flags the client for close.
 *
 * Returns:
 *   -1 on error or EOF
 *    0 on EAGAIN
 *    1 if the buffer had already been sent in full
 *    n number of bytes
 */
static ssize_t virNetServerClientWrite(virNetServerClientPtr client)
{
    ssize_t ret;

    if (client->tx->bufferLength < client->tx->bufferOffset) {
        virReportError(VIR_ERR_RPC,
                       _("unexpected zero/negative length request %lld"),
                       (long long int)(client->tx->bufferLength - client->tx->bufferOffset));
        client->wantClose = true;
        return -1;
    }

    /* Nothing left to send; report progress so the caller carries on */
    if (client->tx->bufferLength == client->tx->bufferOffset)
        return 1;

    ret = virNetSocketWrite(client->sock,
                            client->tx->buffer + client->tx->bufferOffset,
                            client->tx->bufferLength - client->tx->bufferOffset);
    if (ret <= 0)
        return ret; /* -1 error, 0 = EAGAIN */

    client->tx->bufferOffset += ret;
    return ret;
}


/*
 * Process all queued client->tx messages until
 * we would block on I/O
 *
 * For each message the data buffer is written first, followed
 * by any file descriptors attached to it. A finished message
 * is removed from the queue; if it was a tracked request and
 * the receive side is currently throttled (no client->rx), the
 * message object is recycled as the next receive buffer.
 *
 * Must be called with the client lock held.
 */
static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
    while (client->tx) {
        if (client->tx->bufferOffset < client->tx->bufferLength) {
            ssize_t ret;
            ret = virNetServerClientWrite(client);
            if (ret < 0) {
                client->wantClose = true;
                return;
            }
            if (ret == 0)
                return; /* Would block on write EAGAIN */
        }

        if (client->tx->bufferOffset == client->tx->bufferLength) {
            virNetMessagePtr msg;
            size_t i;

            /* Data is out; now pass any file descriptors not yet sent */
            for (i = client->tx->donefds; i < client->tx->nfds; i++) {
                int rv;
                if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    return;
                client->tx->donefds++;
            }

#if WITH_SASL
            /* Completed this 'tx' operation, so now read for all
             * future rx/tx to be under a SASL SSF layer
             */
            if (client->sasl) {
                virNetSocketSetSASLSession(client->sock, client->sasl);
                virObjectUnref(client->sasl);
                client->sasl = NULL;
            }
#endif

            /* Get finished msg from head of tx queue */
            msg = virNetMessageQueueServe(&client->tx);

            if (msg->tracked) {
                client->nrequests--;
                /* See if the recv queue is currently throttled */
                if (!client->rx &&
                    client->nrequests < client->nrequests_max) {
                    /* Ready to recv more messages */
                    virNetMessageClear(msg);
                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
                        /* NOTE(review): this path returns without
                         * setting wantClose or updating the event
                         * mask - confirm that is intended. */
                        virNetMessageFree(msg);
                        return;
                    }
                    client->rx = msg;
                    msg = NULL;
                    client->nrequests++;
                }
            }

            /* msg is NULL here if it was recycled as the rx buffer */
            virNetMessageFree(msg);

            virNetServerClientUpdateEvent(client);

            if (client->delayedClose)
                client->wantClose = true;
        }
    }
}

1405

1406
#if WITH_GNUTLS
/*
 * Drive the TLS handshake one step further.
 *
 * When the handshake completes the client's access is checked;
 * a handshake error or failed access check flags the client
 * for close. Otherwise the socket event mask is refreshed.
 *
 * Must be called with the client lock held.
 */
static void
virNetServerClientDispatchHandshake(virNetServerClientPtr client)
{
    int ret;
    /* Continue the handshake. */
    ret = virNetTLSSessionHandshake(client->tls);
    if (ret == 0) {
        /* Finished.  Next step is to check the certificate. */
        if (virNetServerClientCheckAccess(client) < 0)
            client->wantClose = true;
        else
            virNetServerClientUpdateEvent(client);
    } else if (ret > 0) {
        /* Carry on waiting for more handshake. Update
           the events just in case handshake data flow
           direction has changed */
        virNetServerClientUpdateEvent(client);
    } else {
        /* Fatal error in handshake */
        client->wantClose = true;
    }
}
#endif
1430 1431 1432 1433 1434 1435

/*
 * Socket I/O callback.
 *
 * @sock: socket the event fired on
 * @events: bitmask of VIR_EVENT_HANDLE_* flags
 * @opaque: the virNetServerClientPtr owning the socket
 *
 * If @sock is not the client's current socket, its callback is
 * removed and nothing else happens. Otherwise readable/writable
 * events drive either the TLS handshake (while it is still
 * incomplete) or the normal write/read dispatchers, and
 * error/hangup events flag the client for close.
 */
static void
virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
{
    virNetServerClientPtr client = opaque;

    virObjectLock(client);

    if (client->sock != sock) {
        virNetSocketRemoveIOCallback(sock);
        virObjectUnlock(client);
        return;
    }

    if (events & (VIR_EVENT_HANDLE_WRITABLE |
                  VIR_EVENT_HANDLE_READABLE)) {
#if WITH_GNUTLS
        if (client->tls &&
            virNetTLSSessionGetHandshakeStatus(client->tls) !=
            VIR_NET_TLS_HANDSHAKE_COMPLETE) {
            virNetServerClientDispatchHandshake(client);
        } else {
#endif
            if (events & VIR_EVENT_HANDLE_WRITABLE)
                virNetServerClientDispatchWrite(client);
            /* Only read when a receive buffer is available */
            if (events & VIR_EVENT_HANDLE_READABLE &&
                client->rx)
                virNetServerClientDispatchRead(client);
#if WITH_GNUTLS
        }
#endif
    }

    /* NB, will get HANGUP + READABLE at same time upon
     * disconnect */
    if (events & (VIR_EVENT_HANDLE_ERROR |
                  VIR_EVENT_HANDLE_HANGUP))
        client->wantClose = true;

    virObjectUnlock(client);
}


1473 1474 1475
/*
 * Queue @msg for transmission to @client.
 *
 * The message's donefds counter is always reset. The message is
 * pushed onto client->tx only if the client still has a socket
 * and is not flagged for close.
 *
 * Must be called with the client lock held.
 *
 * Returns 0 if the message was queued, -1 if it was not (in
 * which case the message is left untouched apart from donefds).
 */
static int
virNetServerClientSendMessageLocked(virNetServerClientPtr client,
                                    virNetMessagePtr msg)
{
    int ret = -1;
    VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
              msg, msg->header.proc,
              msg->bufferLength, msg->bufferOffset);

    msg->donefds = 0;
    if (client->sock && !client->wantClose) {
        PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);
        virNetMessageQueuePush(&client->tx, msg);

        virNetServerClientUpdateEvent(client);
        ret = 0;
    }

    return ret;
}

/*
 * virNetServerClientSendMessage:
 * @client: the client to send to
 * @msg: the message to queue
 *
 * Locking wrapper around virNetServerClientSendMessageLocked().
 *
 * Returns that function's result: 0 if @msg was queued,
 * -1 otherwise.
 */
int virNetServerClientSendMessage(virNetServerClientPtr client,
                                  virNetMessagePtr msg)
{
    int ret;

    virObjectLock(client);
    ret = virNetServerClientSendMessageLocked(client, msg);
    virObjectUnlock(client);

    return ret;
}


/*
 * virNetServerClientNeedAuth:
 * @client: the client to query
 *
 * Returns true if the client's auth field is non-zero, false
 * otherwise. The field is read with the client lock held.
 */
bool virNetServerClientNeedAuth(virNetServerClientPtr client)
{
    bool need;

    virObjectLock(client);
    need = !!client->auth;
    virObjectUnlock(client);
    return need;
}
1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543


static void
virNetServerClientKeepAliveDeadCB(void *opaque)
{
    virNetServerClientImmediateClose(opaque);
}

/*
 * Keepalive "send" callback: @opaque is the client that @msg is
 * queued on. Returns the result of virNetServerClientSendMessage().
 */
static int
virNetServerClientKeepAliveSendCB(void *opaque,
                                  virNetMessagePtr msg)
{
    virNetServerClientPtr client = opaque;
    int rc = virNetServerClientSendMessage(client, msg);

    return rc;
}


/*
 * virNetServerClientInitKeepAlive:
 * @client: the client to attach a keepalive object to
 * @interval: passed through to virKeepAliveNew()
 * @count: passed through to virKeepAliveNew()
 *
 * Creates a keepalive object wired to this client's send and
 * dead callbacks and stores it in client->keepalive. The
 * keepalive holds its own reference on @client.
 *
 * Returns 0 on success, -1 if the keepalive object could not
 * be created.
 */
int
virNetServerClientInitKeepAlive(virNetServerClientPtr client,
                                int interval,
                                unsigned int count)
{
    virKeepAlivePtr ka;
    int ret = -1;

    virObjectLock(client);

    if (!(ka = virKeepAliveNew(interval, count, client,
                               virNetServerClientKeepAliveSendCB,
                               virNetServerClientKeepAliveDeadCB,
                               virObjectFreeCallback)))
        goto cleanup;
    /* keepalive object has a reference to client */
    virObjectRef(client);

    client->keepalive = ka;
    /* Previously ret was never updated, so success was reported as -1 */
    ret = 0;

 cleanup:
    virObjectUnlock(client);

    return ret;
}

/*
 * virNetServerClientStartKeepAlive:
 * @client: the client whose keepalive should be started
 *
 * Returns the result of virKeepAliveStart(), or -1 with an
 * error reported if the client has no keepalive object.
 */
int
virNetServerClientStartKeepAlive(virNetServerClientPtr client)
{
    int ret = -1;

    virObjectLock(client);

    /* The connection might have been closed before we got here and thus the
     * keepalive object could have been removed too.
     */
    if (!client->keepalive) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("connection not open"));
        goto cleanup;
    }

    ret = virKeepAliveStart(client->keepalive, 0, 0);

 cleanup:
    virObjectUnlock(client);
    return ret;
}
1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605

/*
 * virNetServerClientGetTransport:
 * @client: the client to query
 *
 * Returns VIR_CLIENT_TRANS_TLS if the client has a TLS session
 * (GnuTLS builds only), VIR_CLIENT_TRANS_UNIX if it has a local
 * socket, and VIR_CLIENT_TRANS_TCP in every other case,
 * including a client without a socket.
 */
int
virNetServerClientGetTransport(virNetServerClientPtr client)
{
    int ret;

    virObjectLock(client);

    if (client->sock && virNetSocketIsLocal(client->sock))
        ret = VIR_CLIENT_TRANS_UNIX;
    else
        ret = VIR_CLIENT_TRANS_TCP;

    /* Use #if, as everywhere else in this file that tests WITH_GNUTLS */
#if WITH_GNUTLS
    if (client->tls)
        ret = VIR_CLIENT_TRANS_TLS;
#endif

    virObjectUnlock(client);

    return ret;
}
1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616

/*
 * virNetServerClientGetInfo:
 * @client: the client to query
 * @readonly: filled with the client's readonly flag (always set)
 * @sock_addr: filled with the client's remote address URI string
 * @identity: filled with a new reference to the client's identity
 *
 * Returns 0 on success. Returns -1 with an error reported if
 * the remote address string or the identity is unavailable; in
 * that case @identity is not written.
 */
int
virNetServerClientGetInfo(virNetServerClientPtr client,
                          bool *readonly, const char **sock_addr,
                          virIdentityPtr *identity)
{
    int ret = -1;

    virObjectLock(client);
    *readonly = client->readonly;

    if (!(*sock_addr = virNetServerClientRemoteAddrStringURI(client))) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("No network socket associated with client"));
        goto cleanup;
    }

    if (!client->identity) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       _("No identity information available for client"));
        goto cleanup;
    }

    /* The caller receives its own reference on the identity */
    *identity = virObjectRef(client->identity);

    ret = 0;
 cleanup:
    virObjectUnlock(client);
    return ret;
}