virnetsockettest.c 16.3 KB
Newer Older
1
/*
E
Eric Blake 已提交
2
 * Copyright (C) 2011, 2014 Red Hat, Inc.
3 4 5 6 7 8 9 10 11 12 13 14
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
15
 * License along with this library.  If not, see
O
Osier Yang 已提交
16
 * <http://www.gnu.org/licenses/>.
17 18 19 20 21 22 23 24 25 26 27
 *
 * Author: Daniel P. Berrange <berrange@redhat.com>
 */

#include <config.h>

#include <stdlib.h>
#include <signal.h>
#ifdef HAVE_IFADDRS_H
# include <ifaddrs.h>
#endif
28
#include <netdb.h>
29 30

#include "testutils.h"
31
#include "virutil.h"
32
#include "virerror.h"
33
#include "viralloc.h"
34
#include "virlog.h"
E
Eric Blake 已提交
35
#include "virfile.h"
36
#include "virstring.h"
37 38 39 40 41

#include "rpc/virnetsocket.h"

#define VIR_FROM_THIS VIR_FROM_RPC

42 43
VIR_LOG_INIT("tests.netsockettest");

44 45 46 47 48 49 50 51 52 53
#if HAVE_IFADDRS_H
# define BASE_PORT 5672

static int
checkProtocols(bool *hasIPv4, bool *hasIPv6,
               int *freePort)
{
    struct sockaddr_in in4;
    struct sockaddr_in6 in6;
    int s4 = -1, s6 = -1;
54
    size_t i;
55 56 57
    int ret = -1;

    *freePort = 0;
58 59
    if (virNetSocketCheckProtocols(hasIPv4, hasIPv6) < 0)
        return -1;
60

61
    for (i = 0; i < 50; i++) {
62 63 64 65
        int only = 1;
        if ((s4 = socket(AF_INET, SOCK_STREAM, 0)) < 0)
            goto cleanup;

66 67 68
        if (*hasIPv6) {
            if ((s6 = socket(AF_INET6, SOCK_STREAM, 0)) < 0)
                goto cleanup;
69

70 71 72
            if (setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, &only, sizeof(only)) < 0)
                goto cleanup;
        }
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

        memset(&in4, 0, sizeof(in4));
        memset(&in6, 0, sizeof(in6));

        in4.sin_family = AF_INET;
        in4.sin_port = htons(BASE_PORT + i);
        in4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
        in6.sin6_family = AF_INET6;
        in6.sin6_port = htons(BASE_PORT + i);
        in6.sin6_addr = in6addr_loopback;

        if (bind(s4, (struct sockaddr *)&in4, sizeof(in4)) < 0) {
            if (errno == EADDRINUSE) {
                VIR_FORCE_CLOSE(s4);
                VIR_FORCE_CLOSE(s6);
                continue;
            }
            goto cleanup;
        }
92 93 94 95 96 97 98 99 100

        if (*hasIPv6) {
            if (bind(s6, (struct sockaddr *)&in6, sizeof(in6)) < 0) {
                if (errno == EADDRINUSE) {
                    VIR_FORCE_CLOSE(s4);
                    VIR_FORCE_CLOSE(s6);
                    continue;
                }
                goto cleanup;
101 102 103 104 105 106 107
            }
        }

        *freePort = BASE_PORT + i;
        break;
    }

J
Jiri Denemark 已提交
108
    VIR_DEBUG("Choose port %d", *freePort);
109 110 111

    ret = 0;

112
 cleanup:
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
    VIR_FORCE_CLOSE(s4);
    VIR_FORCE_CLOSE(s6);
    return ret;
}


struct testTCPData {
    const char *lnode;
    int port;
    const char *cnode;
};

static int testSocketTCPAccept(const void *opaque)
{
    virNetSocketPtr *lsock = NULL; /* Listen socket */
    size_t nlsock = 0, i;
    virNetSocketPtr ssock = NULL; /* Server socket */
    virNetSocketPtr csock = NULL; /* Client socket */
    const struct testTCPData *data = opaque;
    int ret = -1;
    char portstr[100];

    snprintf(portstr, sizeof(portstr), "%d", data->port);

137 138 139
    if (virNetSocketNewListenTCP(data->lnode, portstr,
                                 AF_UNSPEC,
                                 &lsock, &nlsock) < 0)
140 141
        goto cleanup;

142
    for (i = 0; i < nlsock; i++) {
143
        if (virNetSocketListen(lsock[i], 0) < 0)
144 145 146
            goto cleanup;
    }

147 148 149
    if (virNetSocketNewConnectTCP(data->cnode, portstr,
                                  AF_UNSPEC,
                                  &csock) < 0)
150 151
        goto cleanup;

152
    virObjectUnref(csock);
153

154
    for (i = 0; i < nlsock; i++) {
155 156 157 158 159 160 161 162
        if (virNetSocketAccept(lsock[i], &ssock) != -1 && ssock) {
            char c = 'a';
            if (virNetSocketWrite(ssock, &c, 1) != -1 &&
                virNetSocketRead(ssock, &c, 1) != -1) {
                VIR_DEBUG("Unexpected client socket present");
                goto cleanup;
            }
        }
163
        virObjectUnref(ssock);
164 165 166 167 168
        ssock = NULL;
    }

    ret = 0;

169
 cleanup:
170
    virObjectUnref(ssock);
171
    for (i = 0; i < nlsock; i++)
172
        virObjectUnref(lsock[i]);
173 174 175 176 177 178 179 180 181 182 183 184 185 186
    VIR_FREE(lsock);
    return ret;
}
#endif


#ifndef WIN32
static int testSocketUNIXAccept(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr lsock = NULL; /* Listen socket */
    virNetSocketPtr ssock = NULL; /* Server socket */
    virNetSocketPtr csock = NULL; /* Client socket */
    int ret = -1;

187
    char *path = NULL;
188 189 190 191 192
    char *tmpdir;
    char template[] = "/tmp/libvirt_XXXXXX";

    tmpdir = mkdtemp(template);
    if (tmpdir == NULL) {
193
        VIR_WARN("Failed to create temporary directory");
194 195
        goto cleanup;
    }
196
    if (virAsprintf(&path, "%s/test.sock", tmpdir) < 0)
197
        goto cleanup;
198

199
    if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &lsock) < 0)
200 201
        goto cleanup;

202
    if (virNetSocketListen(lsock, 0) < 0)
203 204 205 206 207
        goto cleanup;

    if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
        goto cleanup;

208
    virObjectUnref(csock);
209 210 211 212 213 214 215 216 217 218 219

    if (virNetSocketAccept(lsock, &ssock) != -1) {
        char c = 'a';
        if (virNetSocketWrite(ssock, &c, 1) != -1) {
            VIR_DEBUG("Unexpected client socket present");
            goto cleanup;
        }
    }

    ret = 0;

220
 cleanup:
221
    VIR_FREE(path);
222 223
    virObjectUnref(lsock);
    virObjectUnref(ssock);
224 225
    if (tmpdir)
        rmdir(tmpdir);
226 227 228 229 230 231 232 233 234 235 236
    return ret;
}


static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr lsock = NULL; /* Listen socket */
    virNetSocketPtr ssock = NULL; /* Server socket */
    virNetSocketPtr csock = NULL; /* Client socket */
    int ret = -1;

237
    char *path = NULL;
238 239 240 241 242
    char *tmpdir;
    char template[] = "/tmp/libvirt_XXXXXX";

    tmpdir = mkdtemp(template);
    if (tmpdir == NULL) {
243
        VIR_WARN("Failed to create temporary directory");
244 245
        goto cleanup;
    }
246
    if (virAsprintf(&path, "%s/test.sock", tmpdir) < 0)
247
        goto cleanup;
248

249
    if (virNetSocketNewListenUNIX(path, 0700, -1, getegid(), &lsock) < 0)
250 251
        goto cleanup;

252
    if (STRNEQ(virNetSocketLocalAddrString(lsock), "127.0.0.1;0")) {
253 254 255 256 257 258 259 260 261
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

    if (virNetSocketRemoteAddrString(lsock) != NULL) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

262
    if (virNetSocketListen(lsock, 0) < 0)
263 264 265 266 267
        goto cleanup;

    if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
        goto cleanup;

268
    if (STRNEQ(virNetSocketLocalAddrString(csock), "127.0.0.1;0")) {
269 270 271 272
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

273
    if (STRNEQ(virNetSocketRemoteAddrString(csock), "127.0.0.1;0")) {
274 275 276 277
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

278 279 280 281 282
    if (STRNEQ(virNetSocketRemoteAddrStringURI(csock), "127.0.0.1:0")) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

283 284 285 286 287 288 289

    if (virNetSocketAccept(lsock, &ssock) < 0) {
        VIR_DEBUG("Unexpected client socket missing");
        goto cleanup;
    }


290
    if (STRNEQ(virNetSocketLocalAddrString(ssock), "127.0.0.1;0")) {
291 292 293 294
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

295
    if (STRNEQ(virNetSocketRemoteAddrString(ssock), "127.0.0.1;0")) {
296 297 298 299
        VIR_DEBUG("Unexpected local address");
        goto cleanup;
    }

300 301 302 303 304
    if (STRNEQ(virNetSocketRemoteAddrStringURI(ssock), "127.0.0.1:0")) {
        VIR_DEBUG("Unexpected remote address");
        goto cleanup;
    }

305 306 307

    ret = 0;

308
 cleanup:
309
    VIR_FREE(path);
310 311 312
    virObjectUnref(lsock);
    virObjectUnref(ssock);
    virObjectUnref(csock);
313 314
    if (tmpdir)
        rmdir(tmpdir);
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334
    return ret;
}

static int testSocketCommandNormal(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr csock = NULL; /* Client socket */
    char buf[100];
    size_t i;
    int ret = -1;
    virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/zero", NULL);
    virCommandAddEnvPassCommon(cmd);

    if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
        goto cleanup;

    virNetSocketSetBlocking(csock, true);

    if (virNetSocketRead(csock, buf, sizeof(buf)) < 0)
        goto cleanup;

335
    for (i = 0; i < sizeof(buf); i++)
336 337 338 339 340
        if (buf[i] != '\0')
            goto cleanup;

    ret = 0;

341
 cleanup:
342
    virObjectUnref(csock);
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
    return ret;
}

static int testSocketCommandFail(const void *data ATTRIBUTE_UNUSED)
{
    virNetSocketPtr csock = NULL; /* Client socket */
    char buf[100];
    int ret = -1;
    virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/does-not-exist", NULL);
    virCommandAddEnvPassCommon(cmd);

    if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
        goto cleanup;

    virNetSocketSetBlocking(csock, true);

    if (virNetSocketRead(csock, buf, sizeof(buf)) == 0)
        goto cleanup;

    ret = 0;

364
 cleanup:
365
    virObjectUnref(csock);
366 367 368 369 370 371 372 373 374
    return ret;
}

struct testSSHData {
    const char *nodename;
    const char *service;
    const char *binary;
    const char *username;
    bool noTTY;
375
    bool noVerify;
376
    const char *netcat;
377
    const char *keyfile;
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396
    const char *path;

    const char *expectOut;
    bool failConnect;
    bool dieEarly;
};

static int testSocketSSH(const void *opaque)
{
    const struct testSSHData *data = opaque;
    virNetSocketPtr csock = NULL; /* Client socket */
    int ret = -1;
    char buf[1024];

    if (virNetSocketNewConnectSSH(data->nodename,
                                  data->service,
                                  data->binary,
                                  data->username,
                                  data->noTTY,
397
                                  data->noVerify,
398
                                  data->netcat,
399
                                  data->keyfile,
400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
                                  data->path,
                                  &csock) < 0)
        goto cleanup;

    virNetSocketSetBlocking(csock, true);

    if (data->failConnect) {
        if (virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
            VIR_DEBUG("Expected connect failure, but got some socket data");
            goto cleanup;
        }
    } else {
        ssize_t rv;
        if ((rv = virNetSocketRead(csock, buf, sizeof(buf)-1)) < 0) {
            VIR_DEBUG("Didn't get any socket data");
            goto cleanup;
        }
        buf[rv] = '\0';

419
        if (STRNEQ(buf, data->expectOut)) {
420
            virTestDifference(stderr, data->expectOut, buf);
421 422 423 424 425 426 427 428 429 430 431 432
            goto cleanup;
        }

        if (data->dieEarly &&
            virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
            VIR_DEBUG("Got too much socket data");
            goto cleanup;
        }
    }

    ret = 0;

433
 cleanup:
434
    virObjectUnref(csock);
435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454
    return ret;
}

#endif


static int
mymain(void)
{
    int ret = 0;
#ifdef HAVE_IFADDRS_H
    bool hasIPv4, hasIPv6;
    int freePort;
#endif

    signal(SIGPIPE, SIG_IGN);

#ifdef HAVE_IFADDRS_H
    if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) {
        fprintf(stderr, "Cannot identify IPv4/6 availability\n");
455
        return EXIT_FAILURE;
456 457 458 459
    }

    if (hasIPv4) {
        struct testTCPData tcpData = { "127.0.0.1", freePort, "127.0.0.1" };
460
        if (virTestRun("Socket TCP/IPv4 Accept", testSocketTCPAccept, &tcpData) < 0)
461 462 463 464
            ret = -1;
    }
    if (hasIPv6) {
        struct testTCPData tcpData = { "::1", freePort, "::1" };
465
        if (virTestRun("Socket TCP/IPv6 Accept", testSocketTCPAccept, &tcpData) < 0)
466 467 468 469
            ret = -1;
    }
    if (hasIPv6 && hasIPv4) {
        struct testTCPData tcpData = { NULL, freePort, "127.0.0.1" };
470
        if (virTestRun("Socket TCP/IPv4+IPv6 Accept", testSocketTCPAccept, &tcpData) < 0)
471 472 473
            ret = -1;

        tcpData.cnode = "::1";
474
        if (virTestRun("Socket TCP/IPv4+IPv6 Accept", testSocketTCPAccept, &tcpData) < 0)
475 476 477 478 479
            ret = -1;
    }
#endif

#ifndef WIN32
480
    if (virTestRun("Socket UNIX Accept", testSocketUNIXAccept, NULL) < 0)
481 482
        ret = -1;

483
    if (virTestRun("Socket UNIX Addrs", testSocketUNIXAddrs, NULL) < 0)
484 485
        ret = -1;

486
    if (virTestRun("Socket External Command /dev/zero", testSocketCommandNormal, NULL) < 0)
487
        ret = -1;
488
    if (virTestRun("Socket External Command /dev/does-not-exist", testSocketCommandFail, NULL) < 0)
489 490 491 492 493
        ret = -1;

    struct testSSHData sshData1 = {
        .nodename = "somehost",
        .path = "/tmp/socket",
494
        .expectOut = "somehost sh -c 'if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
495 496 497 498
                                         "ARG=-q0;"
                                     "else "
                                         "ARG=;"
                                     "fi;"
499
                                     "'nc' $ARG -U /tmp/socket'\n",
500
    };
501
    if (virTestRun("SSH test 1", testSocketSSH, &sshData1) < 0)
502 503 504 505 506 507 508 509
        ret = -1;

    struct testSSHData sshData2 = {
        .nodename = "somehost",
        .service = "9000",
        .username = "fred",
        .netcat = "netcat",
        .noTTY = true,
510
        .noVerify = false,
511
        .path = "/tmp/socket",
512
        .expectOut = "-p 9000 -l fred -T -o BatchMode=yes -e none somehost sh -c '"
513
                     "if 'netcat' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
514 515 516 517
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
518
                     "'netcat' $ARG -U /tmp/socket'\n",
519
    };
520
    if (virTestRun("SSH test 2", testSocketSSH, &sshData2) < 0)
521 522 523
        ret = -1;

    struct testSSHData sshData3 = {
524 525 526 527 528 529
        .nodename = "somehost",
        .service = "9000",
        .username = "fred",
        .netcat = "netcat",
        .noTTY = false,
        .noVerify = true,
530
        .path = "/tmp/socket",
531
        .expectOut = "-p 9000 -l fred -o StrictHostKeyChecking=no somehost sh -c '"
532
                     "if 'netcat' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
533 534 535 536
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
537
                     "'netcat' $ARG -U /tmp/socket'\n",
538
    };
539
    if (virTestRun("SSH test 3", testSocketSSH, &sshData3) < 0)
540 541 542
        ret = -1;

    struct testSSHData sshData4 = {
543 544 545 546
        .nodename = "nosuchhost",
        .path = "/tmp/socket",
        .failConnect = true,
    };
547
    if (virTestRun("SSH test 4", testSocketSSH, &sshData4) < 0)
548 549 550
        ret = -1;

    struct testSSHData sshData5 = {
551 552
        .nodename = "crashyhost",
        .path = "/tmp/socket",
553
        .expectOut = "crashyhost sh -c "
554
                     "'if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
555 556 557 558
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
559
                     "'nc' $ARG -U /tmp/socket'\n",
560 561
        .dieEarly = true,
    };
562
    if (virTestRun("SSH test 5", testSocketSSH, &sshData5) < 0)
563 564
        ret = -1;

565 566 567 568 569
    struct testSSHData sshData6 = {
        .nodename = "example.com",
        .path = "/tmp/socket",
        .keyfile = "/root/.ssh/example_key",
        .noVerify = true,
570
        .expectOut = "-i /root/.ssh/example_key -o StrictHostKeyChecking=no example.com sh -c '"
571
                     "if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
572 573 574 575
                         "ARG=-q0;"
                     "else "
                         "ARG=;"
                     "fi;"
576
                     "'nc' $ARG -U /tmp/socket'\n",
577
    };
578
    if (virTestRun("SSH test 6", testSocketSSH, &sshData6) < 0)
579 580
        ret = -1;

581 582 583 584 585 586 587 588 589 590 591
    struct testSSHData sshData7 = {
        .nodename = "somehost",
        .netcat = "nc -4",
        .path = "/tmp/socket",
        .expectOut = "somehost sh -c 'if ''nc -4'' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
                                         "ARG=-q0;"
                                     "else "
                                         "ARG=;"
                                     "fi;"
                                     "''nc -4'' $ARG -U /tmp/socket'\n",
    };
592
    if (virTestRun("SSH test 7", testSocketSSH, &sshData7) < 0)
593 594
        ret = -1;

595 596
#endif

597
    return ret == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
598 599 600
}

VIRT_TEST_MAIN(mymain)