select_reuseport.c 21.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2018 Facebook */

#include <stdlib.h>
#include <unistd.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>
#include <assert.h>
#include <fcntl.h>
#include <linux/bpf.h>
#include <linux/err.h>
#include <linux/types.h>
#include <linux/if_ether.h>
#include <sys/types.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>
#include "bpf_rlimit.h"
#include "bpf_util.h"
23 24

#include "test_progs.h"
25 26
#include "test_select_reuseport_common.h"

27
/* Size of the buffer used to format a subtest name. */
#define MAX_TEST_NAME 80
/* Baseline header lengths used for the expected data_check length. */
#define MIN_TCPHDR_LEN 20
#define UDPHDR_LEN 8

/* procfs knobs that this test reads, changes and restores. */
#define TCP_SYNCOOKIE_SYSCTL "/proc/sys/net/ipv4/tcp_syncookies"
#define TCP_FO_SYSCTL "/proc/sys/net/ipv4/tcp_fastopen"
/* Number of sockets in sk_fds[] and of entries in the reuseport_array map. */
#define REUSEPORT_ARRAY_SIZE 32

static int result_map, tmp_index_ovr_map, linum_map, data_check_map;
static enum result expected_results[NR_RESULTS];
static int sk_fds[REUSEPORT_ARRAY_SIZE];
38
static int reuseport_array = -1, outer_map = -1;
39
static int select_by_skb_data_prog;
40
static int saved_tcp_syncookie = -1;
41
static struct bpf_object *obj;
42
static int saved_tcp_fo = -1;
43 44 45 46 47 48 49 50 51
static __u32 index_zero;
static int epfd;

static union sa46 {
	struct sockaddr_in6 v6;
	struct sockaddr_in v4;
	sa_family_t family;
} srv_sa;

52 53 54 55
/*
 * When CHECK_FAIL(condition) is true, print "<tag> <format...>" and
 * return from the enclosing function, which must return void.
 */
#define RET_IF(condition, tag, format...) ({				\
	if (CHECK_FAIL(condition)) {					\
		printf(tag " " format);					\
		return;							\
	}								\
})

59
/*
 * When CHECK_FAIL(condition) is true, print "<tag> <format...>" and
 * return -1 from the enclosing function.
 */
#define RET_ERR(condition, tag, format...) ({				\
	if (CHECK_FAIL(condition)) {					\
		printf(tag " " format);					\
		return -1;						\
	}								\
})

static int create_maps(void)
67 68 69 70 71 72 73 74 75 76 77
{
	struct bpf_create_map_attr attr = {};

	/* Creating reuseport_array */
	attr.name = "reuseport_array";
	attr.map_type = BPF_MAP_TYPE_REUSEPORT_SOCKARRAY;
	attr.key_size = sizeof(__u32);
	attr.value_size = sizeof(__u32);
	attr.max_entries = REUSEPORT_ARRAY_SIZE;

	reuseport_array = bpf_create_map_xattr(&attr);
78 79
	RET_ERR(reuseport_array == -1, "creating reuseport_array",
		"reuseport_array:%d errno:%d\n", reuseport_array, errno);
80 81 82 83 84 85 86 87 88

	/* Creating outer_map */
	attr.name = "outer_map";
	attr.map_type = BPF_MAP_TYPE_ARRAY_OF_MAPS;
	attr.key_size = sizeof(__u32);
	attr.value_size = sizeof(__u32);
	attr.max_entries = 1;
	attr.inner_map_fd = reuseport_array;
	outer_map = bpf_create_map_xattr(&attr);
89 90 91 92
	RET_ERR(outer_map == -1, "creating outer_map",
		"outer_map:%d errno:%d\n", outer_map, errno);

	return 0;
93 94
}

95
static int prepare_bpf_obj(void)
96 97 98 99 100
{
	struct bpf_program *prog;
	struct bpf_map *map;
	int err;

101
	obj = bpf_object__open("test_select_reuseport_kern.o");
102 103
	RET_ERR(IS_ERR_OR_NULL(obj), "open test_select_reuseport_kern.o",
		"obj:%p PTR_ERR(obj):%ld\n", obj, PTR_ERR(obj));
104 105

	map = bpf_object__find_map_by_name(obj, "outer_map");
106
	RET_ERR(!map, "find outer_map", "!map\n");
107
	err = bpf_map__reuse_fd(map, outer_map);
108
	RET_ERR(err, "reuse outer_map", "err:%d\n", err);
109 110

	err = bpf_object__load(obj);
111
	RET_ERR(err, "load bpf_object", "err:%d\n", err);
112

113
	prog = bpf_program__next(NULL, obj);
114
	RET_ERR(!prog, "get first bpf_program", "!prog\n");
115
	select_by_skb_data_prog = bpf_program__fd(prog);
116 117
	RET_ERR(select_by_skb_data_prog == -1, "get prog fd",
		"select_by_skb_data_prog:%d\n", select_by_skb_data_prog);
118 119

	map = bpf_object__find_map_by_name(obj, "result_map");
120
	RET_ERR(!map, "find result_map", "!map\n");
121
	result_map = bpf_map__fd(map);
122 123
	RET_ERR(result_map == -1, "get result_map fd",
		"result_map:%d\n", result_map);
124 125

	map = bpf_object__find_map_by_name(obj, "tmp_index_ovr_map");
126
	RET_ERR(!map, "find tmp_index_ovr_map\n", "!map");
127
	tmp_index_ovr_map = bpf_map__fd(map);
128 129
	RET_ERR(tmp_index_ovr_map == -1, "get tmp_index_ovr_map fd",
		"tmp_index_ovr_map:%d\n", tmp_index_ovr_map);
130 131

	map = bpf_object__find_map_by_name(obj, "linum_map");
132
	RET_ERR(!map, "find linum_map", "!map\n");
133
	linum_map = bpf_map__fd(map);
134 135
	RET_ERR(linum_map == -1, "get linum_map fd",
		"linum_map:%d\n", linum_map);
136 137

	map = bpf_object__find_map_by_name(obj, "data_check_map");
138
	RET_ERR(!map, "find data_check_map", "!map\n");
139
	data_check_map = bpf_map__fd(map);
140 141 142 143
	RET_ERR(data_check_map == -1, "get data_check_map fd",
		"data_check_map:%d\n", data_check_map);

	return 0;
144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
}

/*
 * Zero @sa and fill it with the loopback address of @family
 * (IPv6 when @family is AF_INET6, IPv4 otherwise).  The port stays 0.
 */
static void sa46_init_loopback(union sa46 *sa, sa_family_t family)
{
	memset(sa, 0, sizeof(*sa));
	sa->family = family;

	if (family != AF_INET6) {
		sa->v4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
		return;
	}

	sa->v6.sin6_addr = in6addr_loopback;
}

/*
 * Zero @sa and fill it with the wildcard ("any") address of @family
 * (IPv6 when @family is AF_INET6, IPv4 otherwise).  The port stays 0.
 */
static void sa46_init_inany(union sa46 *sa, sa_family_t family)
{
	memset(sa, 0, sizeof(*sa));
	sa->family = family;

	if (family != AF_INET6) {
		sa->v4.sin_addr.s_addr = INADDR_ANY;
		return;
	}

	sa->v6.sin6_addr = in6addr_any;
}

/*
 * Read the integer stored in the procfs file @sysctl.
 *
 * Returns the parsed value, or -1 when the file cannot be opened or
 * read (the failure is reported through RET_ERR()).
 */
static int read_int_sysctl(const char *sysctl)
{
	int fd, ret, saved_errno;
	char buf[16];

	fd = open(sysctl, O_RDONLY);
	RET_ERR(fd == -1, "open(sysctl)",
		"sysctl:%s fd:%d errno:%d\n", sysctl, fd, errno);

	/* Leave room for the terminating NUL that atoi() relies on. */
	ret = read(fd, buf, sizeof(buf) - 1);
	/* Close before the check so the fd is not leaked on failure. */
	saved_errno = errno;
	close(fd);
	errno = saved_errno;
	RET_ERR(ret <= 0, "read(sysctl)",
		"sysctl:%s ret:%d errno:%d\n", sysctl, ret, errno);

	buf[ret] = '\0';
	return atoi(buf);
}

183
/*
 * Write the decimal representation of @v to the procfs file @sysctl.
 *
 * Returns 0 on success, or -1 when the file cannot be opened or the
 * write is short (the failure is reported through RET_ERR()).
 */
static int write_int_sysctl(const char *sysctl, int v)
{
	int fd, ret, size, saved_errno;
	char buf[16];

	fd = open(sysctl, O_RDWR);
	RET_ERR(fd == -1, "open(sysctl)",
		"sysctl:%s fd:%d errno:%d\n", sysctl, fd, errno);

	size = snprintf(buf, sizeof(buf), "%d", v);
	ret = write(fd, buf, size);
	/* Close before the check so the fd is not leaked on failure. */
	saved_errno = errno;
	close(fd);
	errno = saved_errno;
	RET_ERR(ret != size, "write(sysctl)",
		"sysctl:%s ret:%d size:%d errno:%d\n",
		sysctl, ret, size, errno);

	return 0;
}

static void restore_sysctls(void)
{
204 205 206 207
	if (saved_tcp_fo != -1)
		write_int_sysctl(TCP_FO_SYSCTL, saved_tcp_fo);
	if (saved_tcp_syncookie != -1)
		write_int_sysctl(TCP_SYNCOOKIE_SYSCTL, saved_tcp_syncookie);
208 209
}

210
static int enable_fastopen(void)
211 212 213 214
{
	int fo;

	fo = read_int_sysctl(TCP_FO_SYSCTL);
215 216 217 218
	if (fo < 0)
		return -1;

	return write_int_sysctl(TCP_FO_SYSCTL, fo | 7);
219 220
}

221
static int enable_syncookie(void)
222
{
223
	return write_int_sysctl(TCP_SYNCOOKIE_SYSCTL, 2);
224 225
}

226
static int disable_syncookie(void)
227
{
228
	return write_int_sysctl(TCP_SYNCOOKIE_SYSCTL, 0);
229 230
}

231
static long get_linum(void)
232 233 234 235 236
{
	__u32 linum;
	int err;

	err = bpf_map_lookup_elem(linum_map, &index_zero, &linum);
237 238
	RET_ERR(err == -1, "lookup_elem(linum_map)", "err:%d errno:%d\n",
		err, errno);
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253

	return linum;
}

static void check_data(int type, sa_family_t family, const struct cmd *cmd,
		       int cli_fd)
{
	struct data_check expected = {}, result;
	union sa46 cli_sa;
	socklen_t addrlen;
	int err;

	addrlen = sizeof(cli_sa);
	err = getsockname(cli_fd, (struct sockaddr *)&cli_sa,
			  &addrlen);
254 255
	RET_IF(err == -1, "getsockname(cli_fd)", "err:%d errno:%d\n",
	       err, errno);
256 257

	err = bpf_map_lookup_elem(data_check_map, &index_zero, &result);
258 259
	RET_IF(err == -1, "lookup_elem(data_check_map)", "err:%d errno:%d\n",
	       err, errno);
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300

	if (type == SOCK_STREAM) {
		expected.len = MIN_TCPHDR_LEN;
		expected.ip_protocol = IPPROTO_TCP;
	} else {
		expected.len = UDPHDR_LEN;
		expected.ip_protocol = IPPROTO_UDP;
	}

	if (family == AF_INET6) {
		expected.eth_protocol = htons(ETH_P_IPV6);
		expected.bind_inany = !srv_sa.v6.sin6_addr.s6_addr32[3] &&
			!srv_sa.v6.sin6_addr.s6_addr32[2] &&
			!srv_sa.v6.sin6_addr.s6_addr32[1] &&
			!srv_sa.v6.sin6_addr.s6_addr32[0];

		memcpy(&expected.skb_addrs[0], cli_sa.v6.sin6_addr.s6_addr32,
		       sizeof(cli_sa.v6.sin6_addr));
		memcpy(&expected.skb_addrs[4], &in6addr_loopback,
		       sizeof(in6addr_loopback));
		expected.skb_ports[0] = cli_sa.v6.sin6_port;
		expected.skb_ports[1] = srv_sa.v6.sin6_port;
	} else {
		expected.eth_protocol = htons(ETH_P_IP);
		expected.bind_inany = !srv_sa.v4.sin_addr.s_addr;

		expected.skb_addrs[0] = cli_sa.v4.sin_addr.s_addr;
		expected.skb_addrs[1] = htonl(INADDR_LOOPBACK);
		expected.skb_ports[0] = cli_sa.v4.sin_port;
		expected.skb_ports[1] = srv_sa.v4.sin_port;
	}

	if (memcmp(&result, &expected, offsetof(struct data_check,
						equal_check_end))) {
		printf("unexpected data_check\n");
		printf("  result: (0x%x, %u, %u)\n",
		       result.eth_protocol, result.ip_protocol,
		       result.bind_inany);
		printf("expected: (0x%x, %u, %u)\n",
		       expected.eth_protocol, expected.ip_protocol,
		       expected.bind_inany);
301 302
		RET_IF(1, "data_check result != expected",
		       "bpf_prog_linum:%ld\n", get_linum());
303 304
	}

305 306
	RET_IF(!result.hash, "data_check result.hash empty",
	       "result.hash:%u", result.hash);
307 308 309

	expected.len += cmd ? sizeof(*cmd) : 0;
	if (type == SOCK_STREAM)
310 311 312
		RET_IF(expected.len > result.len, "expected.len > result.len",
		       "expected.len:%u result.len:%u bpf_prog_linum:%ld\n",
		       expected.len, result.len, get_linum());
313
	else
314 315 316
		RET_IF(expected.len != result.len, "expected.len != result.len",
		       "expected.len:%u result.len:%u bpf_prog_linum:%ld\n",
		       expected.len, result.len, get_linum());
317 318
}

319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
static const char *result_to_str(enum result res)
{
	switch (res) {
	case DROP_ERR_INNER_MAP:
		return "DROP_ERR_INNER_MAP";
	case DROP_ERR_SKB_DATA:
		return "DROP_ERR_SKB_DATA";
	case DROP_ERR_SK_SELECT_REUSEPORT:
		return "DROP_ERR_SK_SELECT_REUSEPORT";
	case DROP_MISC:
		return "DROP_MISC";
	case PASS:
		return "PASS";
	case PASS_ERR_SK_SELECT_REUSEPORT:
		return "PASS_ERR_SK_SELECT_REUSEPORT";
	default:
		return "UNKNOWN";
	}
}

339 340 341 342 343 344 345 346
static void check_results(void)
{
	__u32 results[NR_RESULTS];
	__u32 i, broken = 0;
	int err;

	for (i = 0; i < NR_RESULTS; i++) {
		err = bpf_map_lookup_elem(result_map, &i, &results[i]);
347 348
		RET_IF(err == -1, "lookup_elem(result_map)",
		       "i:%u err:%d errno:%d\n", i, err, errno);
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
	}

	for (i = 0; i < NR_RESULTS; i++) {
		if (results[i] != expected_results[i]) {
			broken = i;
			break;
		}
	}

	if (i == NR_RESULTS)
		return;

	printf("unexpected result\n");
	printf(" result: [");
	printf("%u", results[0]);
	for (i = 1; i < NR_RESULTS; i++)
		printf(", %u", results[i]);
	printf("]\n");

	printf("expected: [");
	printf("%u", expected_results[0]);
	for (i = 1; i < NR_RESULTS; i++)
		printf(", %u", expected_results[i]);
	printf("]\n");

374 375 376 377
	printf("mismatch on %s (bpf_prog_linum:%ld)\n", result_to_str(broken),
	       get_linum());

	CHECK_FAIL(true);
378 379 380 381 382 383 384 385 386
}

static int send_data(int type, sa_family_t family, void *data, size_t len,
		     enum result expected)
{
	union sa46 cli_sa;
	int fd, err;

	fd = socket(family, type, 0);
387
	RET_ERR(fd == -1, "socket()", "fd:%d errno:%d\n", fd, errno);
388 389 390

	sa46_init_loopback(&cli_sa, family);
	err = bind(fd, (struct sockaddr *)&cli_sa, sizeof(cli_sa));
391
	RET_ERR(fd == -1, "bind(cli_sa)", "err:%d errno:%d\n", err, errno);
392 393 394

	err = sendto(fd, data, len, MSG_FASTOPEN, (struct sockaddr *)&srv_sa,
		     sizeof(srv_sa));
395 396 397
	RET_ERR(err != len && expected >= PASS,
		"sendto()", "family:%u err:%d errno:%d expected:%d\n",
		family, err, errno, expected);
398 399 400 401 402 403 404 405 406 407 408 409 410 411

	return fd;
}

static void do_test(int type, sa_family_t family, struct cmd *cmd,
		    enum result expected)
{
	int nev, srv_fd, cli_fd;
	struct epoll_event ev;
	struct cmd rcv_cmd;
	ssize_t nread;

	cli_fd = send_data(type, family, cmd, cmd ? sizeof(*cmd) : 0,
			   expected);
412 413
	if (cli_fd < 0)
		return;
414
	nev = epoll_wait(epfd, &ev, 1, expected >= PASS ? 5 : 0);
415 416 417 418 419 420 421
	RET_IF((nev <= 0 && expected >= PASS) ||
	       (nev > 0 && expected < PASS),
	       "nev <> expected",
	       "nev:%d expected:%d type:%d family:%d data:(%d, %d)\n",
	       nev, expected, type, family,
	       cmd ? cmd->reuseport_index : -1,
	       cmd ? cmd->pass_on_failure : -1);
422 423 424 425 426 427
	check_results();
	check_data(type, family, cmd, cli_fd);

	if (expected < PASS)
		return;

428 429 430 431 432
	RET_IF(expected != PASS_ERR_SK_SELECT_REUSEPORT &&
	       cmd->reuseport_index != ev.data.u32,
	       "check cmd->reuseport_index",
	       "cmd:(%u, %u) ev.data.u32:%u\n",
	       cmd->pass_on_failure, cmd->reuseport_index, ev.data.u32);
433 434 435 436 437

	srv_fd = sk_fds[ev.data.u32];
	if (type == SOCK_STREAM) {
		int new_fd = accept(srv_fd, NULL, 0);

438 439 440
		RET_IF(new_fd == -1, "accept(srv_fd)",
		       "ev.data.u32:%u new_fd:%d errno:%d\n",
		       ev.data.u32, new_fd, errno);
441 442

		nread = recv(new_fd, &rcv_cmd, sizeof(rcv_cmd), MSG_DONTWAIT);
443 444 445 446
		RET_IF(nread != sizeof(rcv_cmd),
		       "recv(new_fd)",
		       "ev.data.u32:%u nread:%zd sizeof(rcv_cmd):%zu errno:%d\n",
		       ev.data.u32, nread, sizeof(rcv_cmd), errno);
447 448 449 450

		close(new_fd);
	} else {
		nread = recv(srv_fd, &rcv_cmd, sizeof(rcv_cmd), MSG_DONTWAIT);
451 452 453 454
		RET_IF(nread != sizeof(rcv_cmd),
		       "recv(sk_fds)",
		       "ev.data.u32:%u nread:%zd sizeof(rcv_cmd):%zu errno:%d\n",
		       ev.data.u32, nread, sizeof(rcv_cmd), errno);
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526
	}

	close(cli_fd);
}

static void test_err_inner_map(int type, sa_family_t family)
{
	struct cmd cmd = {
		.reuseport_index = 0,
		.pass_on_failure = 0,
	};

	expected_results[DROP_ERR_INNER_MAP]++;
	do_test(type, family, &cmd, DROP_ERR_INNER_MAP);
}

static void test_err_skb_data(int type, sa_family_t family)
{
	expected_results[DROP_ERR_SKB_DATA]++;
	do_test(type, family, NULL, DROP_ERR_SKB_DATA);
}

static void test_err_sk_select_port(int type, sa_family_t family)
{
	struct cmd cmd = {
		.reuseport_index = REUSEPORT_ARRAY_SIZE,
		.pass_on_failure = 0,
	};

	expected_results[DROP_ERR_SK_SELECT_REUSEPORT]++;
	do_test(type, family, &cmd, DROP_ERR_SK_SELECT_REUSEPORT);
}

static void test_pass(int type, sa_family_t family)
{
	struct cmd cmd;
	int i;

	cmd.pass_on_failure = 0;
	for (i = 0; i < REUSEPORT_ARRAY_SIZE; i++) {
		expected_results[PASS]++;
		cmd.reuseport_index = i;
		do_test(type, family, &cmd, PASS);
	}
}

/*
 * TCP-only: with syncookies enabled, check that the SYN and the ACK of
 * one connection can be steered to two different sockets.
 *
 * Syncookies are switched off again on every path after they have been
 * switched on.
 */
static void test_syncookie(int type, sa_family_t family)
{
	int err, tmp_index = 1;
	struct cmd cmd = {
		.reuseport_index = 0,
		.pass_on_failure = 0,
	};

	if (type != SOCK_STREAM)
		return;

	/* The failure has already been reported by write_int_sysctl(). */
	if (enable_syncookie())
		return;

	/*
	 * +1 for TCP-SYN and
	 * +1 for the TCP-ACK (ack the syncookie)
	 */
	expected_results[PASS] += 2;
	/*
	 * Simulate TCP-SYN and TCP-ACK are handled by two different sk:
	 * TCP-SYN: select sk_fds[tmp_index = 1] tmp_index is from the
	 *          tmp_index_ovr_map
	 * TCP-ACK: select sk_fds[reuseport_index = 0] reuseport_index
	 *          is from the cmd.reuseport_index
	 */
	err = bpf_map_update_elem(tmp_index_ovr_map, &index_zero,
				  &tmp_index, BPF_ANY);
	if (CHECK_FAIL(err == -1)) {
		printf("update_elem(tmp_index_ovr_map, 0, 1) "
		       "err:%d errno:%d\n", err, errno);
		goto out;
	}

	do_test(type, family, &cmd, PASS);

	/* The override is expected to have been reset to -1 by now. */
	err = bpf_map_lookup_elem(tmp_index_ovr_map, &index_zero,
				  &tmp_index);
	if (CHECK_FAIL(err == -1 || tmp_index != -1))
		printf("lookup_elem(tmp_index_ovr_map) "
		       "err:%d errno:%d tmp_index:%d\n",
		       err, errno, tmp_index);

out:
	disable_syncookie();
}

static void test_pass_on_err(int type, sa_family_t family)
{
	struct cmd cmd = {
		.reuseport_index = REUSEPORT_ARRAY_SIZE,
		.pass_on_failure = 1,
	};

	expected_results[PASS_ERR_SK_SELECT_REUSEPORT] += 1;
	do_test(type, family, &cmd, PASS_ERR_SK_SELECT_REUSEPORT);
}

550 551 552 553 554 555 556 557 558 559 560
#ifdef SO_DETACH_REUSEPORT_BPF
/*
 * Add up all counters of result_map into *@total.
 * Returns 0 on success and -1 when a map lookup fails.
 */
static int sum_result_map(__u32 *total)
{
	__u32 i, tmp;
	int err;

	*total = 0;
	for (i = 0; i < NR_RESULTS; i++) {
		err = bpf_map_lookup_elem(result_map, &i, &tmp);
		RET_ERR(err == -1, "lookup_elem(result_map)",
			"i:%u err:%d errno:%d\n", i, err, errno);
		*total += tmp;
	}

	return 0;
}
#endif

/*
 * Detach the reuseport BPF program through sk_fds[0], check that a
 * second detach through sk_fds[1] fails with ENOENT, then send a packet
 * and verify that it still arrives while the result_map counters stay
 * unchanged.
 *
 * The subtest is skipped when SO_DETACH_REUSEPORT_BPF is not defined.
 */
static void test_detach_bpf(int type, sa_family_t family)
{
#ifdef SO_DETACH_REUSEPORT_BPF
	__u32 nr_run_before = 0, nr_run_after = 0;
	struct epoll_event ev;
	int cli_fd, err, nev;
	struct cmd cmd = {};
	int optvalue = 0;

	err = setsockopt(sk_fds[0], SOL_SOCKET, SO_DETACH_REUSEPORT_BPF,
			 &optvalue, sizeof(optvalue));
	RET_IF(err == -1, "setsockopt(SO_DETACH_REUSEPORT_BPF)",
	       "err:%d errno:%d\n", err, errno);

	err = setsockopt(sk_fds[1], SOL_SOCKET, SO_DETACH_REUSEPORT_BPF,
			 &optvalue, sizeof(optvalue));
	RET_IF(err == 0 || errno != ENOENT,
	       "setsockopt(SO_DETACH_REUSEPORT_BPF)",
	       "err:%d errno:%d\n", err, errno);

	if (sum_result_map(&nr_run_before))
		return;

	cli_fd = send_data(type, family, &cmd, sizeof(cmd), PASS);
	if (cli_fd < 0)
		return;

	/* From here on every exit path must close cli_fd. */
	nev = epoll_wait(epfd, &ev, 1, 5);
	if (CHECK_FAIL(nev <= 0)) {
		printf("nev <= 0 "
		       "nev:%d expected:1 type:%d family:%d data:(0, 0)\n",
		       nev,  type, family);
		goto out;
	}

	if (sum_result_map(&nr_run_after))
		goto out;

	if (CHECK_FAIL(nr_run_before != nr_run_after))
		printf("nr_run_before != nr_run_after "
		       "nr_run_before:%u nr_run_after:%u\n",
		       nr_run_before, nr_run_after);

out:
	close(cli_fd);
#else
	test__skip();
#endif
}

603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621
static void prepare_sk_fds(int type, sa_family_t family, bool inany)
{
	const int first = REUSEPORT_ARRAY_SIZE - 1;
	int i, err, optval = 1;
	struct epoll_event ev;
	socklen_t addrlen;

	if (inany)
		sa46_init_inany(&srv_sa, family);
	else
		sa46_init_loopback(&srv_sa, family);
	addrlen = sizeof(srv_sa);

	/*
	 * The sk_fds[] is filled from the back such that the order
	 * is exactly opposite to the (struct sock_reuseport *)reuse->socks[].
	 */
	for (i = first; i >= 0; i--) {
		sk_fds[i] = socket(family, type, 0);
622 623
		RET_IF(sk_fds[i] == -1, "socket()", "sk_fds[%d]:%d errno:%d\n",
		       i, sk_fds[i], errno);
624 625
		err = setsockopt(sk_fds[i], SOL_SOCKET, SO_REUSEPORT,
				 &optval, sizeof(optval));
626 627 628
		RET_IF(err == -1, "setsockopt(SO_REUSEPORT)",
		       "sk_fds[%d] err:%d errno:%d\n",
		       i, err, errno);
629 630 631 632 633 634

		if (i == first) {
			err = setsockopt(sk_fds[i], SOL_SOCKET,
					 SO_ATTACH_REUSEPORT_EBPF,
					 &select_by_skb_data_prog,
					 sizeof(select_by_skb_data_prog));
635 636
			RET_IF(err == -1, "setsockopt(SO_ATTACH_REUEPORT_EBPF)",
			       "err:%d errno:%d\n", err, errno);
637 638 639
		}

		err = bind(sk_fds[i], (struct sockaddr *)&srv_sa, addrlen);
640 641
		RET_IF(err == -1, "bind()", "sk_fds[%d] err:%d errno:%d\n",
		       i, err, errno);
642 643 644

		if (type == SOCK_STREAM) {
			err = listen(sk_fds[i], 10);
645 646 647
			RET_IF(err == -1, "listen()",
			       "sk_fds[%d] err:%d errno:%d\n",
			       i, err, errno);
648 649 650 651
		}

		err = bpf_map_update_elem(reuseport_array, &i, &sk_fds[i],
					  BPF_NOEXIST);
652 653
		RET_IF(err == -1, "update_elem(reuseport_array)",
		       "sk_fds[%d] err:%d errno:%d\n", i, err, errno);
654 655 656 657 658 659

		if (i == first) {
			socklen_t addrlen = sizeof(srv_sa);

			err = getsockname(sk_fds[i], (struct sockaddr *)&srv_sa,
					  &addrlen);
660 661
			RET_IF(err == -1, "getsockname()",
			       "sk_fds[%d] err:%d errno:%d\n", i, err, errno);
662 663 664 665
		}
	}

	epfd = epoll_create(1);
666 667
	RET_IF(epfd == -1, "epoll_create(1)",
	       "epfd:%d errno:%d\n", epfd, errno);
668 669 670 671 672

	ev.events = EPOLLIN;
	for (i = 0; i < REUSEPORT_ARRAY_SIZE; i++) {
		ev.data.u32 = i;
		err = epoll_ctl(epfd, EPOLL_CTL_ADD, sk_fds[i], &ev);
673
		RET_IF(err, "epoll_ctl(EPOLL_CTL_ADD)", "sk_fds[%d]\n", i);
674 675 676
	}
}

677 678
static void setup_per_test(int type, sa_family_t family, bool inany,
			   bool no_inner_map)
679 680 681 682 683 684
{
	int ovr = -1, err;

	prepare_sk_fds(type, family, inany);
	err = bpf_map_update_elem(tmp_index_ovr_map, &index_zero, &ovr,
				  BPF_ANY);
685 686
	RET_IF(err == -1, "update_elem(tmp_index_ovr_map, 0, -1)",
	       "err:%d errno:%d\n", err, errno);
687 688 689 690 691 692 693

	/* Install reuseport_array to outer_map? */
	if (no_inner_map)
		return;

	err = bpf_map_update_elem(outer_map, &index_zero, &reuseport_array,
				  BPF_ANY);
694 695
	RET_IF(err == -1, "update_elem(outer_map, 0, reuseport_array)",
	       "err:%d errno:%d\n", err, errno);
696 697
}

698
static void cleanup_per_test(bool no_inner_map)
699 700 701 702 703 704 705
{
	int i, err;

	for (i = 0; i < REUSEPORT_ARRAY_SIZE; i++)
		close(sk_fds[i]);
	close(epfd);

706 707 708 709
	/* Delete reuseport_array from outer_map? */
	if (no_inner_map)
		return;

710
	err = bpf_map_delete_elem(outer_map, &index_zero);
711 712
	RET_IF(err == -1, "delete_elem(outer_map)",
	       "err:%d errno:%d\n", err, errno);
713 714 715 716
}

static void cleanup(void)
{
717 718 719 720 721 722
	if (outer_map != -1)
		close(outer_map);
	if (reuseport_array != -1)
		close(reuseport_array);
	if (obj)
		bpf_object__close(obj);
723 724
}

725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748
/* Printable name of an address family: "IPv4", "IPv6" or "unknown". */
static const char *family_str(sa_family_t family)
{
	if (family == AF_INET)
		return "IPv4";
	if (family == AF_INET6)
		return "IPv6";

	return "unknown";
}

/* Printable name of a socket type: "TCP", "UDP" or "unknown". */
static const char *sotype_str(int sotype)
{
	if (sotype == SOCK_STREAM)
		return "TCP";
	if (sotype == SOCK_DGRAM)
		return "UDP";

	return "unknown";
}

749 750
#define TEST_INIT(fn, ...) { fn, #fn, __VA_ARGS__ }

751
static void test_config(int sotype, sa_family_t family, bool inany)
752
{
753 754
	const struct test {
		void (*fn)(int sotype, sa_family_t family);
755
		const char *name;
756 757
		bool no_inner_map;
	} tests[] = {
758 759 760 761 762 763 764
		TEST_INIT(test_err_inner_map, true /* no_inner_map */),
		TEST_INIT(test_err_skb_data),
		TEST_INIT(test_err_sk_select_port),
		TEST_INIT(test_pass),
		TEST_INIT(test_syncookie),
		TEST_INIT(test_pass_on_err),
		TEST_INIT(test_detach_bpf),
765
	};
766
	char s[MAX_TEST_NAME];
767
	const struct test *t;
768

769
	for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
770 771 772 773 774 775 776 777 778
		snprintf(s, sizeof(s), "%s/%s %s %s",
			 family_str(family), sotype_str(sotype),
			 inany ? "INANY" : "LOOPBACK", t->name);

		if (!test__start_subtest(s))
			continue;

		setup_per_test(sotype, family, inany, t->no_inner_map);
		t->fn(sotype, family);
779 780
		cleanup_per_test(t->no_inner_map);
	}
781
}
782

783
/* Readable spelling of "true" for the inany member of struct config. */
#define BIND_INANY true

/*
 * Run test_config() for every supported combination of socket type,
 * address family and bind address.
 */
static void test_all(void)
{
	const struct config {
		int sotype;
		sa_family_t family;
		bool inany;
	} configs[] = {
		{ SOCK_STREAM, AF_INET },
		{ SOCK_STREAM, AF_INET, BIND_INANY },
		{ SOCK_STREAM, AF_INET6 },
		{ SOCK_STREAM, AF_INET6, BIND_INANY },
		{ SOCK_DGRAM, AF_INET },
		{ SOCK_DGRAM, AF_INET6 },
	};
	size_t idx;

	for (idx = 0; idx < ARRAY_SIZE(configs); idx++) {
		const struct config *cfg = &configs[idx];

		test_config(cfg->sotype, cfg->family, cfg->inany);
	}
}

805
void test_select_reuseport(void)
806
{
807 808 809 810 811
	if (create_maps())
		goto out;
	if (prepare_bpf_obj())
		goto out;

812 813
	saved_tcp_fo = read_int_sysctl(TCP_FO_SYSCTL);
	saved_tcp_syncookie = read_int_sysctl(TCP_SYNCOOKIE_SYSCTL);
814 815
	if (saved_tcp_syncookie < 0 || saved_tcp_syncookie < 0)
		goto out;
816

817 818 819 820 821
	if (enable_fastopen())
		goto out;
	if (disable_syncookie())
		goto out;

822
	test_all();
823
out:
824
	cleanup();
825
	restore_sysctls();
826
}