/*
 * Copyright (C) 2011, 2014 Red Hat, Inc.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library.  If not, see
 * <http://www.gnu.org/licenses/>.
 */

#include <config.h>

#include <signal.h>
#ifdef HAVE_IFADDRS_H
# include <ifaddrs.h>
#endif
#include <netdb.h>

#include "testutils.h"
#include "virutil.h"
#include "virerror.h"
#include "viralloc.h"
#include "virlog.h"
#include "virfile.h"
#include "virstring.h"

#include "rpc/virnetsocket.h"

#define VIR_FROM_THIS VIR_FROM_RPC

39 40
VIR_LOG_INIT("tests.netsockettest");

#if HAVE_IFADDRS_H
# define BASE_PORT 5672

static int
checkProtocols(bool *hasIPv4, bool *hasIPv6,
               int *freePort)
{
    struct sockaddr_in in4;
    struct sockaddr_in6 in6;
    int s4 = -1, s6 = -1;
51
    size_t i;
52 53 54
    int ret = -1;

    *freePort = 0;
55 56
    if (virNetSocketCheckProtocols(hasIPv4, hasIPv6) < 0)
        return -1;
57

58
    for (i = 0; i < 50; i++) {
59 60 61 62
        int only = 1;
        if ((s4 = socket(AF_INET, SOCK_STREAM, 0)) < 0)
            goto cleanup;

63 64 65
        if (*hasIPv6) {
            if ((s6 = socket(AF_INET6, SOCK_STREAM, 0)) < 0)
                goto cleanup;
66

67 68 69
            if (setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, &only, sizeof(only)) < 0)
                goto cleanup;
        }
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88

        memset(&in4, 0, sizeof(in4));
        memset(&in6, 0, sizeof(in6));

        in4.sin_family = AF_INET;
        in4.sin_port = htons(BASE_PORT + i);
        in4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
        in6.sin6_family = AF_INET6;
        in6.sin6_port = htons(BASE_PORT + i);
        in6.sin6_addr = in6addr_loopback;

        if (bind(s4, (struct sockaddr *)&in4, sizeof(in4)) < 0) {
            if (errno == EADDRINUSE) {
                VIR_FORCE_CLOSE(s4);
                VIR_FORCE_CLOSE(s6);
                continue;
            }
            goto cleanup;
        }
89 90 91 92 93 94 95 96 97

        if (*hasIPv6) {
            if (bind(s6, (struct sockaddr *)&in6, sizeof(in6)) < 0) {
                if (errno == EADDRINUSE) {
                    VIR_FORCE_CLOSE(s4);
                    VIR_FORCE_CLOSE(s6);
                    continue;
                }
                goto cleanup;
98 99 100 101 102 103 104
            }
        }

        *freePort = BASE_PORT + i;
        break;
    }

J
Jiri Denemark 已提交
105
    VIR_DEBUG("Choose port %d", *freePort);
106 107 108

    ret = 0;

109
 cleanup:
110 111 112 113 114
    VIR_FORCE_CLOSE(s4);
    VIR_FORCE_CLOSE(s6);
    return ret;
}

/* Connection parameters handed to the client thread (testSocketClient)
 * by testSocketAccept().  When @path is set the client connects to that
 * UNIX socket; otherwise it connects over TCP to @cnode / @portstr. */
struct testClientData {
    const char *path;     /* UNIX socket path, or NULL for TCP */
    const char *cnode;    /* TCP node to connect to */
    const char *portstr;  /* TCP port, as a decimal string */
};

/*
 * testSocketClient:
 * @opaque: a struct testClientData saying where to connect
 *
 * Thread body for the client half of testSocketAccept(): connect to the
 * UNIX path or TCP node/port described by @opaque, read one byte from
 * the server and write the same byte straight back.  Failures are only
 * logged with VIR_DEBUG; nothing is reported back to the caller.
 */
static void
testSocketClient(void *opaque)
{
    struct testClientData *cdata = opaque;
    virNetSocketPtr sock = NULL;
    char byte;
    int rc;

    if (cdata->path)
        rc = virNetSocketNewConnectUNIX(cdata->path, false, NULL, &sock);
    else
        rc = virNetSocketNewConnectTCP(cdata->cnode, cdata->portstr,
                                       AF_UNSPEC, &sock);
    if (rc < 0)
        return;

    virNetSocketSetBlocking(sock, true);

    /* Echo exactly one byte. */
    if (virNetSocketRead(sock, &byte, 1) != 1)
        VIR_DEBUG("Cannot read from server");
    else if (virNetSocketWrite(sock, &byte, 1) != 1)
        VIR_DEBUG("Cannot write to server");

    virObjectUnref(sock);
}


/*
 * testSocketIncoming:
 * @sock: the listening socket whose watch fired
 * @events: event mask passed in by the event loop (only logged here)
 * @opaque: a virNetSocketPtr * that receives @sock
 *
 * I/O callback that testSocketAccept() registers on each listening socket
 * for VIR_EVENT_HANDLE_READABLE.  It records which socket fired, which is
 * what ends the caller's event-loop wait.
 */
static void
testSocketIncoming(virNetSocketPtr sock,
                   int events ATTRIBUTE_UNUSED,
                   void *opaque)
{
    virNetSocketPtr *retsock = opaque;

    VIR_DEBUG("Incoming sock=%p events=%d", sock, events);
    *retsock = sock;
}

/* Parameters for a TCP run of testSocketAccept(); passing NULL instead of
 * a pointer to this struct selects the UNIX socket variant.  Field order
 * matters: callers use positional initializers. */
struct testSocketData {
    const char *lnode;  /* node to listen on; may be NULL */
    int port;           /* TCP port to listen on and connect to */
    const char *cnode;  /* node the client thread connects to */
};

static int
testSocketAccept(const void *opaque)
175 176 177 178
{
    virNetSocketPtr *lsock = NULL; /* Listen socket */
    size_t nlsock = 0, i;
    virNetSocketPtr ssock = NULL; /* Server socket */
179
    virNetSocketPtr rsock = NULL; /* Incoming client socket */
180
    const struct testSocketData *data = opaque;
181 182
    int ret = -1;
    char portstr[100];
183 184 185
    char *tmpdir = NULL;
    char *path = NULL;
    char template[] = "/tmp/libvirt_XXXXXX";
186 187 188 189 190
    virThread th;
    struct testClientData cdata = { 0 };
    bool goodsock = false;
    char a = 'a';
    char b = '\0';
191

192 193 194 195 196 197 198 199 200
    if (!data) {
        virNetSocketPtr usock;
        tmpdir = mkdtemp(template);
        if (tmpdir == NULL) {
            VIR_WARN("Failed to create temporary directory");
            goto cleanup;
        }
        if (virAsprintf(&path, "%s/test.sock", tmpdir) < 0)
            goto cleanup;
201

202 203 204 205 206 207 208 209 210 211
        if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &usock) < 0)
            goto cleanup;

        if (VIR_ALLOC_N(lsock, 1) < 0) {
            virObjectUnref(usock);
            goto cleanup;
        }

        lsock[0] = usock;
        nlsock = 1;
212 213

        cdata.path = path;
214 215 216 217 218 219
    } else {
        snprintf(portstr, sizeof(portstr), "%d", data->port);
        if (virNetSocketNewListenTCP(data->lnode, portstr,
                                     AF_UNSPEC,
                                     &lsock, &nlsock) < 0)
            goto cleanup;
220 221 222

        cdata.cnode = data->cnode;
        cdata.portstr = portstr;
223
    }
224

225
    for (i = 0; i < nlsock; i++) {
226
        if (virNetSocketListen(lsock[i], 0) < 0)
227 228
            goto cleanup;

229 230 231 232 233
        if (virNetSocketAddIOCallback(lsock[i],
                                      VIR_EVENT_HANDLE_READABLE,
                                      testSocketIncoming,
                                      &rsock,
                                      NULL) < 0) {
234
            goto cleanup;
235
        }
236
    }
237

238 239 240 241 242
    if (virThreadCreate(&th, true,
                        testSocketClient,
                        &cdata) < 0)
        goto cleanup;

243 244 245 246
    while (rsock == NULL) {
        if (virEventRunDefaultImpl() < 0)
            break;
    }
247

248
    for (i = 0; i < nlsock; i++) {
249 250 251
        if (lsock[i] == rsock) {
            goodsock = true;
            break;
252 253 254
        }
    }

255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
    if (!goodsock) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       "Unexpected server socket seen");
        goto join;
    }

    if (virNetSocketAccept(rsock, &ssock) < 0)
        goto join;

    if (!ssock) {
        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                       "Client went away unexpectedly");
        goto join;
    }

    virNetSocketSetBlocking(ssock, true);

    if (virNetSocketWrite(ssock, &a, 1) < 0 ||
        virNetSocketRead(ssock, &b, 1) < 0) {
        goto join;
    }

    if (a != b) {
        virReportError(VIR_ERR_INTERNAL_ERROR,
                       "Bad data received '%x' != '%x'", a, b);
        goto join;
    }

    virObjectUnref(ssock);
    ssock = NULL;

286 287
    ret = 0;

288 289 290
 join:
    virThreadJoin(&th);

291
 cleanup:
292
    virObjectUnref(ssock);
293 294 295
    for (i = 0; i < nlsock; i++) {
        virNetSocketRemoveIOCallback(lsock[i]);
        virNetSocketClose(lsock[i]);
296
        virObjectUnref(lsock[i]);
297
    }
298 299
    VIR_FREE(lsock);
    VIR_FREE(path);
300 301
    if (tmpdir)
        rmdir(tmpdir);
302 303
    return ret;
}
#endif


#ifndef WIN32
/*
 * testSocketAddrMismatch:
 * @actual: string returned by one of the virNetSocket*AddrString*() getters
 * @expected: the string it should equal
 *
 * Returns true if @actual is NULL or differs from @expected, so an
 * unexpected NULL fails the check instead of being dereferenced.
 */
static bool
testSocketAddrMismatch(const char *actual, const char *expected)
{
    return !actual || STRNEQ(actual, expected);
}

/*
 * testSocketUNIXAddrs:
 *
 * Create a listening UNIX socket in a temporary directory, connect a
 * client to it and accept the connection.  Check the SASL-style and
 * URI-style address strings reported for the listener, client and
 * server sockets.  The test expects UNIX sockets to report 127.0.0.1
 * with port 0, and a listening socket to have no remote address.
 *
 * Returns 0 on success, -1 on failure.
 */
static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr lsock = NULL; /* Listen socket */
    virNetSocketPtr ssock = NULL; /* Server socket */
    virNetSocketPtr csock = NULL; /* Client socket */
    int ret = -1;

    char *path = NULL;
    char *tmpdir;
    char template[] = "/tmp/libvirt_XXXXXX";

    tmpdir = mkdtemp(template);
    if (tmpdir == NULL) {
        VIR_WARN("Failed to create temporary directory");
        goto cleanup;
    }
    if (virAsprintf(&path, "%s/test.sock", tmpdir) < 0)
        goto cleanup;

    if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &lsock) < 0)
        goto cleanup;

    if (testSocketAddrMismatch(virNetSocketLocalAddrStringSASL(lsock),
                               "127.0.0.1;0")) {
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

    if (virNetSocketRemoteAddrStringSASL(lsock) != NULL) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

    if (virNetSocketListen(lsock, 0) < 0)
        goto cleanup;

    if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
        goto cleanup;

    if (testSocketAddrMismatch(virNetSocketLocalAddrStringSASL(csock),
                               "127.0.0.1;0")) {
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

    if (testSocketAddrMismatch(virNetSocketRemoteAddrStringSASL(csock),
                               "127.0.0.1;0")) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

    if (testSocketAddrMismatch(virNetSocketRemoteAddrStringURI(csock),
                               "127.0.0.1:0")) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

    /* virNetSocketAccept() can succeed without handing back a socket
     * (testSocketAccept() checks for that case too).  Treat that as a
     * failure rather than passing NULL to the getters below. */
    if (virNetSocketAccept(lsock, &ssock) < 0 || !ssock) {
        VIR_DEBUG("Unexpected client socket missing");
        goto cleanup;
    }

    if (testSocketAddrMismatch(virNetSocketLocalAddrStringSASL(ssock),
                               "127.0.0.1;0")) {
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

    if (testSocketAddrMismatch(virNetSocketRemoteAddrStringSASL(ssock),
                               "127.0.0.1;0")) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

    if (testSocketAddrMismatch(virNetSocketRemoteAddrStringURI(ssock),
                               "127.0.0.1:0")) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

    ret = 0;

 cleanup:
    VIR_FREE(path);
    virObjectUnref(lsock);
    virObjectUnref(ssock);
    virObjectUnref(csock);
    if (tmpdir)
        rmdir(tmpdir);
    return ret;
}

/*
 * testSocketCommandNormal:
 *
 * Connect to the command "/bin/cat /dev/zero" through
 * virNetSocketNewConnectCommand() and check that the data read back
 * consists only of NUL bytes.
 *
 * Returns 0 on success, -1 on failure.
 */
static int testSocketCommandNormal(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr csock = NULL; /* Client socket */
    char buf[100];
    ssize_t nread;
    size_t i;
    int ret = -1;
    virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/zero", NULL);
    virCommandAddEnvPassCommon(cmd);

    if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
        goto cleanup;

    virNetSocketSetBlocking(csock, true);

    /* A single read may return fewer bytes than requested.  Only the bytes
     * actually received are defined, so check exactly those, and treat
     * "nothing read" as a failure rather than inspecting garbage. */
    if ((nread = virNetSocketRead(csock, buf, sizeof(buf))) <= 0)
        goto cleanup;

    for (i = 0; i < (size_t)nread; i++) {
        if (buf[i] != '\0')
            goto cleanup;
    }

    ret = 0;

 cleanup:
    virObjectUnref(csock);
    return ret;
}

/*
 * testSocketCommandFail:
 *
 * Connect to the command "/bin/cat /dev/does-not-exist" through
 * virNetSocketNewConnectCommand().  cat writes nothing to stdout for a
 * missing file, so the first read must fail.  Receiving data is treated
 * as a test failure, as is a zero return.
 *
 * Returns 0 on success, -1 on failure.
 */
static int testSocketCommandFail(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr csock = NULL; /* Client socket */
    char buf[100];
    int ret = -1;
    virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/does-not-exist", NULL);
    virCommandAddEnvPassCommon(cmd);

    if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
        goto cleanup;

    virNetSocketSetBlocking(csock, true);

    /* Same expectation as the failConnect case in testSocketSSH(). */
    if (virNetSocketRead(csock, buf, sizeof(buf)) >= 0)
        goto cleanup;

    ret = 0;

 cleanup:
    virObjectUnref(csock);
    return ret;
}

/* Inputs and expectations for one testSocketSSH() run.  The fields from
 * @nodename to @path are passed straight to virNetSocketNewConnectSSH(). */
struct testSSHData {
    const char *nodename;
    const char *service;
    const char *binary;
    const char *username;
    bool noTTY;
    bool noVerify;
    const char *netcat;
    const char *keyfile;
    const char *path;

    const char *expectOut;  /* exact text the first read must return */
    bool failConnect;       /* the first read is expected to fail */
    bool dieEarly;          /* after @expectOut, another read must fail */
};

/*
 * testSocketSSH:
 * @opaque: a struct testSSHData
 *
 * Spawn an SSH tunnel with virNetSocketNewConnectSSH() and check what the
 * command writes back:
 *  - with @failConnect, the first read must fail;
 *  - otherwise the first read must equal @expectOut exactly (a diff is
 *    printed on mismatch), and with @dieEarly a second read must fail.
 *
 * NOTE(review): the expected strings look like the ssh argument list;
 * presumably the test environment substitutes an "ssh" that echoes its
 * arguments — confirm against the test harness.
 *
 * Returns 0 on success, -1 on failure.
 */
static int testSocketSSH(const void *opaque)
{
    const struct testSSHData *data = opaque;
    virNetSocketPtr csock = NULL; /* Client socket */
    int ret = -1;
    char buf[1024];

    if (virNetSocketNewConnectSSH(data->nodename,
                                  data->service,
                                  data->binary,
                                  data->username,
                                  data->noTTY,
                                  data->noVerify,
                                  data->netcat,
                                  data->keyfile,
                                  data->path,
                                  &csock) < 0)
        goto cleanup;

    virNetSocketSetBlocking(csock, true);

    if (data->failConnect) {
        if (virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
            VIR_DEBUG("Expected connect failure, but got some socket data");
            goto cleanup;
        }
    } else {
        ssize_t rv;
        /* Leave room for the terminating NUL added below. */
        if ((rv = virNetSocketRead(csock, buf, sizeof(buf)-1)) < 0) {
            VIR_DEBUG("Didn't get any socket data");
            goto cleanup;
        }
        buf[rv] = '\0';

        if (STRNEQ(buf, data->expectOut)) {
            virTestDifference(stderr, data->expectOut, buf);
            goto cleanup;
        }

        if (data->dieEarly &&
            virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
            VIR_DEBUG("Got too much socket data");
            goto cleanup;
        }
    }

    ret = 0;

 cleanup:
    virObjectUnref(csock);
    return ret;
}

#endif


static int
mymain(void)
{
    int ret = 0;
#ifdef HAVE_IFADDRS_H
    bool hasIPv4, hasIPv6;
    int freePort;
#endif

    signal(SIGPIPE, SIG_IGN);

530 531
    virEventRegisterDefaultImpl();

532 533 534
#ifdef HAVE_IFADDRS_H
    if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) {
        fprintf(stderr, "Cannot identify IPv4/6 availability\n");
535
        return EXIT_FAILURE;
536 537 538
    }

    if (hasIPv4) {
539 540
        struct testSocketData tcpData = { "127.0.0.1", freePort, "127.0.0.1" };
        if (virTestRun("Socket TCP/IPv4 Accept", testSocketAccept, &tcpData) < 0)
541 542 543
            ret = -1;
    }
    if (hasIPv6) {
544 545
        struct testSocketData tcpData = { "::1", freePort, "::1" };
        if (virTestRun("Socket TCP/IPv6 Accept", testSocketAccept, &tcpData) < 0)
546 547 548
            ret = -1;
    }
    if (hasIPv6 && hasIPv4) {
549 550
        struct testSocketData tcpData = { NULL, freePort, "127.0.0.1" };
        if (virTestRun("Socket TCP/IPv4+IPv6 Accept", testSocketAccept, &tcpData) < 0)
551 552 553
            ret = -1;

        tcpData.cnode = "::1";
554
        if (virTestRun("Socket TCP/IPv4+IPv6 Accept", testSocketAccept, &tcpData) < 0)
555 556 557 558 559
            ret = -1;
    }
#endif

#ifndef WIN32
560
    if (virTestRun("Socket UNIX Accept", testSocketAccept, NULL) < 0)
561 562
        ret = -1;

563
    if (virTestRun("Socket UNIX Addrs", testSocketUNIXAddrs, NULL) < 0)
564 565
        ret = -1;

566
    if (virTestRun("Socket External Command /dev/zero", testSocketCommandNormal, NULL) < 0)
567
        ret = -1;
568
    if (virTestRun("Socket External Command /dev/does-not-exist", testSocketCommandFail, NULL) < 0)
569 570 571 572 573
        ret = -1;

    struct testSSHData sshData1 = {
        .nodename = "somehost",
        .path = "/tmp/socket",
574 575 576 577 578 579 580
        .expectOut = "-T -e none -- somehost sh -c '"
                     "if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
                     "'nc' $ARG -U /tmp/socket'\n",
581
    };
582
    if (virTestRun("SSH test 1", testSocketSSH, &sshData1) < 0)
583 584 585 586 587 588 589 590
        ret = -1;

    struct testSSHData sshData2 = {
        .nodename = "somehost",
        .service = "9000",
        .username = "fred",
        .netcat = "netcat",
        .noTTY = true,
591
        .noVerify = false,
592
        .path = "/tmp/socket",
593
        .expectOut = "-p 9000 -l fred -T -e none -o BatchMode=yes -- somehost sh -c '"
594
                     "if 'netcat' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
595 596 597 598
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
599
                     "'netcat' $ARG -U /tmp/socket'\n",
600
    };
601
    if (virTestRun("SSH test 2", testSocketSSH, &sshData2) < 0)
602 603 604
        ret = -1;

    struct testSSHData sshData3 = {
605 606 607 608 609 610
        .nodename = "somehost",
        .service = "9000",
        .username = "fred",
        .netcat = "netcat",
        .noTTY = false,
        .noVerify = true,
611
        .path = "/tmp/socket",
612
        .expectOut = "-p 9000 -l fred -T -e none -o StrictHostKeyChecking=no -- somehost sh -c '"
613
                     "if 'netcat' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
614 615 616 617
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
618
                     "'netcat' $ARG -U /tmp/socket'\n",
619
    };
620
    if (virTestRun("SSH test 3", testSocketSSH, &sshData3) < 0)
621 622 623
        ret = -1;

    struct testSSHData sshData4 = {
624 625 626 627
        .nodename = "nosuchhost",
        .path = "/tmp/socket",
        .failConnect = true,
    };
628
    if (virTestRun("SSH test 4", testSocketSSH, &sshData4) < 0)
629 630 631
        ret = -1;

    struct testSSHData sshData5 = {
632 633
        .nodename = "crashyhost",
        .path = "/tmp/socket",
634
        .expectOut = "-T -e none -- crashyhost sh -c "
635
                     "'if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
636 637 638 639
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
640
                     "'nc' $ARG -U /tmp/socket'\n",
641 642
        .dieEarly = true,
    };
643
    if (virTestRun("SSH test 5", testSocketSSH, &sshData5) < 0)
644 645
        ret = -1;

646 647 648 649 650
    struct testSSHData sshData6 = {
        .nodename = "example.com",
        .path = "/tmp/socket",
        .keyfile = "/root/.ssh/example_key",
        .noVerify = true,
651
        .expectOut = "-i /root/.ssh/example_key -T -e none -o StrictHostKeyChecking=no -- example.com sh -c '"
652
                     "if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
653 654 655 656
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
657
                     "'nc' $ARG -U /tmp/socket'\n",
658
    };
659
    if (virTestRun("SSH test 6", testSocketSSH, &sshData6) < 0)
660 661
        ret = -1;

662 663 664 665
    struct testSSHData sshData7 = {
        .nodename = "somehost",
        .netcat = "nc -4",
        .path = "/tmp/socket",
666 667 668 669 670 671 672
        .expectOut = "-T -e none -- somehost sh -c '"
                     "if ''nc -4'' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
                     "''nc -4'' $ARG -U /tmp/socket'\n",
673
    };
674
    if (virTestRun("SSH test 7", testSocketSSH, &sshData7) < 0)
675 676
        ret = -1;

677 678
#endif

679
    return ret == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
680 681
}

/* NOTE(review): VIR_TEST_MAIN presumably generates main() and dispatches
 * to mymain(); confirm in testutils.h. */
VIR_TEST_MAIN(mymain)