// SPDX-License-Identifier: GPL-2.0-only
/*
 * bpf_jit_comp.c: BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>

#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
	if (len == 1)
		*ptr = bytes;
	else if (len == 2)
		*(u16 *)ptr = bytes;
	else {
		*(u32 *)ptr = bytes;
		barrier();
	}
	return ptr + len;
}

#define EMIT(bytes, len) \
	do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)		EMIT(b1, 1)
#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)

#define EMIT1_off32(b1, off) \
	do { EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
	do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
	do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
	do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
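
/*
 * The byte arguments are packed into a little-endian word, so the first
 * argument is the first byte emitted: EMIT3(0x48, 0x89, 0xC7), for example,
 * lays down 48 89 C7 in that order.  For len == 3, emit_code() stores a
 * full u32 while only advancing the pointer by 3; the extra byte written
 * past the end lands in the BPF_INSN_SAFETY slack of the temporary buffer
 * used in do_jit().
 */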

static bool is_imm8(int value)
{
	return value <= 127 && value >= -128;
}

static bool is_simm32(s64 value)
{
	return value == (s64)(s32)value;
}

static bool is_uimm32(u64 value)
{
	return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC)								 \
	do {										 \
		if (DST != SRC)								 \
			EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
	} while (0)

static int bpf_size_to_x86_bytes(int bpf_size)
{
	if (bpf_size == BPF_W)
		return 4;
	else if (bpf_size == BPF_H)
		return 2;
	else if (bpf_size == BPF_B)
		return 1;
	else if (bpf_size == BPF_DW)
		return 4; /* imm32 */
	else
		return 0;
}
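
/*
 * BPF_DW maps to 4 bytes here because x86-64 has no store instruction
 * taking a full 64-bit immediate; the ST | BPF_DW case below emits
 * REX.W + C7, whose 32-bit immediate is sign-extended to 64 bit.
 */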

/*
 * List of x86 cond jumps opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F
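
/*
 * For example, the near form "je .+imm8" is 74 <imm8>; adding 0x10 and the
 * 0x0F prefix yields the far form 0F 84 <imm32>, which is what the
 * conditional jump emission in do_jit() relies on when the target is out
 * of s8 range.
 */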

/* Pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)

/*
 * The following table maps BPF registers to x86-64 registers.
 *
 * x86-64 register R12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 * Also x86-64 register R9 is unused. x86-64 register R10 is
 * used for blinding (if enabled).
 */
static const int reg2hex[] = {
	[BPF_REG_0] = 0,  /* RAX */
	[BPF_REG_1] = 7,  /* RDI */
	[BPF_REG_2] = 6,  /* RSI */
	[BPF_REG_3] = 2,  /* RDX */
	[BPF_REG_4] = 1,  /* RCX */
	[BPF_REG_5] = 0,  /* R8  */
	[BPF_REG_6] = 3,  /* RBX callee saved */
	[BPF_REG_7] = 5,  /* R13 callee saved */
	[BPF_REG_8] = 6,  /* R14 callee saved */
	[BPF_REG_9] = 7,  /* R15 callee saved */
	[BPF_REG_FP] = 5, /* RBP readonly */
	[BPF_REG_AX] = 2, /* R10 temp register */
	[AUX_REG] = 3,    /* R11 temp register */
};

/*
 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
 * which need extra byte of encoding.
 * rax,rcx,...,rbp have simpler encoding
 */
static bool is_ereg(u32 reg)
{
	return (1 << reg) & (BIT(BPF_REG_5) |
			     BIT(AUX_REG) |
			     BIT(BPF_REG_7) |
			     BIT(BPF_REG_8) |
			     BIT(BPF_REG_9) |
			     BIT(BPF_REG_AX));
}

static bool is_axreg(u32 reg)
{
	return reg == BPF_REG_0;
}

/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
static u8 add_1mod(u8 byte, u32 reg)
{
	if (is_ereg(reg))
		byte |= 1;
	return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
	if (is_ereg(r1))
		byte |= 1;
	if (is_ereg(r2))
		byte |= 4;
	return byte;
}

/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
	return byte + reg2hex[dst_reg];
}

/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}
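
/*
 * For example, EMIT_mov(BPF_REG_1, BPF_REG_6) combines these helpers into
 * add_2mod(0x48, ...) = 0x48 (neither register needs an extra REX bit),
 * opcode 0x89 and add_2reg(0xC0, ...) = 0xC0 + 7 + (3 << 3) = 0xDF,
 * i.e. the three bytes 48 89 DF ("mov rdi, rbx").
 */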

static void jit_fill_hole(void *area, unsigned int size)
{
	/* Fill whole space with INT3 instructions */
	memset(area, 0xcc, size);
}

struct jit_context {
	int cleanup_addr; /* Epilogue code offset */
};

/* Maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE	128
#define BPF_INSN_SAFETY		64

#define PROLOGUE_SIZE		20

/*
 * Emit x86-64 prologue code for BPF program and check its size.
 * bpf_tail_call helper will skip it while jumping into another program
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
	u8 *prog = *pprog;
	int cnt = 0;

	EMIT1(0x55);             /* push rbp */
	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
	/* sub rsp, rounded_stack_depth */
	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
	EMIT1(0x53);             /* push rbx */
	EMIT2(0x41, 0x55);       /* push r13 */
	EMIT2(0x41, 0x56);       /* push r14 */
	EMIT2(0x41, 0x57);       /* push r15 */
	if (!ebpf_from_cbpf) {
		/* zero init tail_call_cnt */
		EMIT2(0x6a, 0x00);
		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
	}
	*pprog = prog;
}
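
/*
 * PROLOGUE_SIZE accounting: push rbp (1) + mov rbp, rsp (3) +
 * sub rsp, imm32 (7) + push rbx (1) + push r13/r14/r15 (3 * 2) +
 * push 0 for tail_call_cnt (2) = 20 bytes, which the BUILD_BUG_ON
 * above verifies.
 */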

/*
 * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
	u8 *prog = *pprog;
	int label1, label2, label3;
	int cnt = 0;

	/*
	 * rdi - pointer to ctx
	 * rsi - pointer to bpf_array
	 * rdx - index in bpf_array
	 */

	/*
	 * if (index >= array->map.max_entries)
	 *	goto out;
	 */
	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
	      offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
	label1 = cnt;

	/*
	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
	 *	goto out;
	 */
	EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JA, OFFSET2);                   /* ja out */
	label2 = cnt;
	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
	EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */

	/* prog = array->ptrs[index]; */
	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
		    offsetof(struct bpf_array, ptrs));

	/*
	 * if (prog == NULL)
	 *	goto out;
	 */
	EMIT3(0x48, 0x85, 0xC0);		  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JE, OFFSET3);                   /* je out */
	label3 = cnt;

	/* goto *(prog->bpf_func + prologue_size); */
	EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
	      offsetof(struct bpf_prog, bpf_func));
	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

	/*
	 * Wow we're ready to jump into next BPF program
	 * rdi == ctx (1st arg)
	 * rax == prog->bpf_func + prologue_size
	 */
	RETPOLINE_RAX_BPF_JIT();

	/* out: */
	BUILD_BUG_ON(cnt - label1 != OFFSET1);
	BUILD_BUG_ON(cnt - label2 != OFFSET2);
	BUILD_BUG_ON(cnt - label3 != OFFSET3);
	*pprog = prog;
}

static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
			   u32 dst_reg, const u32 imm32)
{
	u8 *prog = *pprog;
	u8 b1, b2, b3;
	int cnt = 0;

	/*
	 * Optimization: if imm32 is positive, use 'mov %eax, imm32'
	 * (which zero-extends imm32) to save 2 bytes.
	 */
	if (sign_propagate && (s32)imm32 < 0) {
		/* 'mov %rax, imm32' sign extends imm32 */
		b1 = add_1mod(0x48, dst_reg);
		b2 = 0xC7;
		b3 = 0xC0;
		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
		goto done;
	}

	/*
	 * Optimization: if imm32 is zero, use 'xor %eax, %eax'
	 * to save 3 bytes.
	 */
	if (imm32 == 0) {
		if (is_ereg(dst_reg))
			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
		b2 = 0x31; /* xor */
		b3 = 0xC0;
		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
		goto done;
	}

	/* mov %eax, imm32 */
	if (is_ereg(dst_reg))
		EMIT1(add_1mod(0x40, dst_reg));
	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
	*pprog = prog;
}

static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
			   const u32 imm32_hi, const u32 imm32_lo)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
		/*
		 * For emitting plain u32, where sign bit must not be
		 * propagated LLVM tends to load imm64 over mov32
		 * directly, so save couple of bytes by just doing
		 * 'mov %eax, imm32' instead.
		 */
		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
	} else {
		/* movabsq %rax, imm64 */
		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
		EMIT(imm32_lo, 4);
		EMIT(imm32_hi, 4);
	}

	*pprog = prog;
}
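
/*
 * For example, a BPF_LD | BPF_IMM | BPF_DW load of 0x12345678 (upper 32
 * bits zero) goes through emit_mov_imm32() and becomes a 5-byte mov of
 * the 32-bit register (B8+reg, imm32; B8 78 56 34 12 for rax), which
 * zero-extends into the full 64-bit register, instead of the 10-byte
 * movabsq encoding.
 */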

static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is64) {
		/* mov dst, src */
		EMIT_mov(dst_reg, src_reg);
	} else {
		/* mov32 dst, src */
		if (is_ereg(dst_reg) || is_ereg(src_reg))
			EMIT1(add_2mod(0x40, dst_reg, src_reg));
		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
	}

	*pprog = prog;
}

static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
	struct bpf_insn *insn = bpf_prog->insnsi;
	int insn_cnt = bpf_prog->len;
	bool seen_exit = false;
	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
	int i, cnt = 0;
	int proglen = 0;
	u8 *prog = temp;

	emit_prologue(&prog, bpf_prog->aux->stack_depth,
		      bpf_prog_was_classic(bpf_prog));

	for (i = 0; i < insn_cnt; i++, insn++) {
		const s32 imm32 = insn->imm;
		u32 dst_reg = insn->dst_reg;
		u32 src_reg = insn->src_reg;
		u8 b2 = 0, b3 = 0;
		s64 jmp_offset;
		u8 jmp_cond;
		int ilen;
		u8 *func;

		switch (insn->code) {
			/* ALU */
		case BPF_ALU | BPF_ADD | BPF_X:
		case BPF_ALU | BPF_SUB | BPF_X:
		case BPF_ALU | BPF_AND | BPF_X:
		case BPF_ALU | BPF_OR | BPF_X:
		case BPF_ALU | BPF_XOR | BPF_X:
		case BPF_ALU64 | BPF_ADD | BPF_X:
		case BPF_ALU64 | BPF_SUB | BPF_X:
		case BPF_ALU64 | BPF_AND | BPF_X:
		case BPF_ALU64 | BPF_OR | BPF_X:
		case BPF_ALU64 | BPF_XOR | BPF_X:
			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b2 = 0x01; break;
			case BPF_SUB: b2 = 0x29; break;
			case BPF_AND: b2 = 0x21; break;
			case BPF_OR: b2 = 0x09; break;
			case BPF_XOR: b2 = 0x31; break;
			}
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
			break;

		case BPF_ALU64 | BPF_MOV | BPF_X:
		case BPF_ALU | BPF_MOV | BPF_X:
			emit_mov_reg(&prog,
				     BPF_CLASS(insn->code) == BPF_ALU64,
				     dst_reg, src_reg);
			break;

			/* neg dst */
		case BPF_ALU | BPF_NEG:
		case BPF_ALU64 | BPF_NEG:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
			break;

		case BPF_ALU | BPF_ADD | BPF_K:
		case BPF_ALU | BPF_SUB | BPF_K:
		case BPF_ALU | BPF_AND | BPF_K:
		case BPF_ALU | BPF_OR | BPF_K:
		case BPF_ALU | BPF_XOR | BPF_K:
		case BPF_ALU64 | BPF_ADD | BPF_K:
		case BPF_ALU64 | BPF_SUB | BPF_K:
		case BPF_ALU64 | BPF_AND | BPF_K:
		case BPF_ALU64 | BPF_OR | BPF_K:
		case BPF_ALU64 | BPF_XOR | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			/*
			 * b3 holds 'normal' opcode, b2 short form only valid
			 * in case dst is eax/rax.
			 */
			switch (BPF_OP(insn->code)) {
			case BPF_ADD:
				b3 = 0xC0;
				b2 = 0x05;
				break;
			case BPF_SUB:
				b3 = 0xE8;
				b2 = 0x2D;
				break;
			case BPF_AND:
				b3 = 0xE0;
				b2 = 0x25;
				break;
			case BPF_OR:
				b3 = 0xC8;
				b2 = 0x0D;
				break;
			case BPF_XOR:
				b3 = 0xF0;
				b2 = 0x35;
				break;
			}

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
			else if (is_axreg(dst_reg))
				EMIT1_off32(b2, imm32);
			else
				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU64 | BPF_MOV | BPF_K:
		case BPF_ALU | BPF_MOV | BPF_K:
			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
				       dst_reg, imm32);
			break;

		case BPF_LD | BPF_IMM | BPF_DW:
			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
			insn++;
			i++;
			break;

			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
		case BPF_ALU | BPF_MOD | BPF_X:
		case BPF_ALU | BPF_DIV | BPF_X:
		case BPF_ALU | BPF_MOD | BPF_K:
		case BPF_ALU | BPF_DIV | BPF_K:
		case BPF_ALU64 | BPF_MOD | BPF_X:
		case BPF_ALU64 | BPF_DIV | BPF_X:
		case BPF_ALU64 | BPF_MOD | BPF_K:
		case BPF_ALU64 | BPF_DIV | BPF_K:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov r11, src_reg */
				EMIT_mov(AUX_REG, src_reg);
			else
				/* mov r11, imm32 */
				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

			/* mov rax, dst_reg */
			EMIT_mov(BPF_REG_0, dst_reg);

			/*
			 * xor edx, edx
			 * equivalent to 'xor rdx, rdx', but one byte less
			 */
			EMIT2(0x31, 0xd2);

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				/* div r11 */
				EMIT3(0x49, 0xF7, 0xF3);
			else
				/* div r11d */
				EMIT3(0x41, 0xF7, 0xF3);

			if (BPF_OP(insn->code) == BPF_MOD)
				/* mov r11, rdx */
				EMIT3(0x49, 0x89, 0xD3);
			else
				/* mov r11, rax */
				EMIT3(0x49, 0x89, 0xC3);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_MUL | BPF_K:
		case BPF_ALU | BPF_MUL | BPF_X:
		case BPF_ALU64 | BPF_MUL | BPF_K:
		case BPF_ALU64 | BPF_MUL | BPF_X:
		{
			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

			if (dst_reg != BPF_REG_0)
				EMIT1(0x50); /* push rax */
			if (dst_reg != BPF_REG_3)
				EMIT1(0x52); /* push rdx */

			/* mov r11, dst_reg */
			EMIT_mov(AUX_REG, dst_reg);

			if (BPF_SRC(insn->code) == BPF_X)
				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
			else
				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

			if (is64)
				EMIT1(add_1mod(0x48, AUX_REG));
			else if (is_ereg(AUX_REG))
				EMIT1(add_1mod(0x40, AUX_REG));
			/* mul(q) r11 */
			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

			if (dst_reg != BPF_REG_3)
				EMIT1(0x5A); /* pop rdx */
			if (dst_reg != BPF_REG_0) {
				/* mov dst_reg, rax */
				EMIT_mov(dst_reg, BPF_REG_0);
				EMIT1(0x58); /* pop rax */
			}
			break;
		}
			/* Shifts */
		case BPF_ALU | BPF_LSH | BPF_K:
		case BPF_ALU | BPF_RSH | BPF_K:
		case BPF_ALU | BPF_ARSH | BPF_K:
		case BPF_ALU64 | BPF_LSH | BPF_K:
		case BPF_ALU64 | BPF_RSH | BPF_K:
		case BPF_ALU64 | BPF_ARSH | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}

			if (imm32 == 1)
				EMIT2(0xD1, add_1reg(b3, dst_reg));
			else
				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU | BPF_LSH | BPF_X:
		case BPF_ALU | BPF_RSH | BPF_X:
		case BPF_ALU | BPF_ARSH | BPF_X:
		case BPF_ALU64 | BPF_LSH | BPF_X:
		case BPF_ALU64 | BPF_RSH | BPF_X:
		case BPF_ALU64 | BPF_ARSH | BPF_X:

			/* Check for bad case when dst_reg == rcx */
			if (dst_reg == BPF_REG_4) {
				/* mov r11, dst_reg */
				EMIT_mov(AUX_REG, dst_reg);
				dst_reg = AUX_REG;
			}

			if (src_reg != BPF_REG_4) { /* common case */
				EMIT1(0x51); /* push rcx */

				/* mov rcx, src_reg */
				EMIT_mov(BPF_REG_4, src_reg);
			}

			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT2(0xD3, add_1reg(b3, dst_reg));

			if (src_reg != BPF_REG_4)
				EMIT1(0x59); /* pop rcx */

			if (insn->dst_reg == BPF_REG_4)
				/* mov dst_reg, r11 */
				EMIT_mov(insn->dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_END | BPF_FROM_BE:
			switch (imm32) {
			case 16:
				/* Emit 'ror %ax, 8' to swap lower 2 bytes */
				EMIT1(0x66);
				if (is_ereg(dst_reg))
					EMIT1(0x41);
				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

				/* Emit 'movzwl eax, ax' */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'bswap eax' to swap lower 4 bytes */
				if (is_ereg(dst_reg))
					EMIT2(0x41, 0x0F);
				else
					EMIT1(0x0F);
				EMIT1(add_1reg(0xC8, dst_reg));
				break;
			case 64:
				/* Emit 'bswap rax' to swap 8 bytes */
				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
				      add_1reg(0xC8, dst_reg));
				break;
			}
			break;

		case BPF_ALU | BPF_END | BPF_FROM_LE:
			switch (imm32) {
			case 16:
				/*
				 * Emit 'movzwl eax, ax' to zero extend 16-bit
				 * into 64 bit
				 */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'mov eax, eax' to clear upper 32-bits */
				if (is_ereg(dst_reg))
					EMIT1(0x45);
				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 64:
				/* nop */
				break;
			}
			break;

			/* ST: *(u8*)(dst_reg + off) = imm */
		case BPF_ST | BPF_MEM | BPF_B:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC6);
			else
				EMIT1(0xC6);
			goto st;
		case BPF_ST | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg))
				EMIT3(0x66, 0x41, 0xC7);
			else
				EMIT2(0x66, 0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC7);
			else
				EMIT1(0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_DW:
			EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:			if (is_imm8(insn->off))
				EMIT2(add_1reg(0x40, dst_reg), insn->off);
			else
				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
			break;

			/* STX: *(u8*)(dst_reg + off) = src_reg */
		case BPF_STX | BPF_MEM | BPF_B:
			/* Emit 'mov byte ptr [rax + off], al' */
			if (is_ereg(dst_reg) || is_ereg(src_reg) ||
			    /* We have to add extra byte for x86 SIL, DIL regs */
			    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
			else
				EMIT1(0x88);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT2(0x66, 0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT1(0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_DW:
			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* LDX: dst_reg = *(u8*)(src_reg + off) */
		case BPF_LDX | BPF_MEM | BPF_B:
			/* Emit 'movzx rax, byte ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_H:
			/* Emit 'movzx rax, word ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_W:
			/* Emit 'mov eax, dword ptr [rax+0x14]' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
			else
				EMIT1(0x8B);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_DW:
			/* Emit 'mov rax, qword ptr [rax+0x14]' */
			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:			/*
			 * If insn->off == 0 we can save one extra byte, but
			 * special case of x86 R13 which always needs an offset
			 * is not worth the hassle
			 */
			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
					    insn->off);
			break;

			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
		case BPF_STX | BPF_XADD | BPF_W:
			/* Emit 'lock add dword ptr [rax + off], eax' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
			else
				EMIT2(0xF0, 0x01);
			goto xadd;
		case BPF_STX | BPF_XADD | BPF_DW:
			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* call */
		case BPF_JMP | BPF_CALL:
			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);
			if (!imm32 || !is_simm32(jmp_offset)) {
				pr_err("unsupported BPF func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			EMIT1_off32(0xE8, jmp_offset);
			break;

		case BPF_JMP | BPF_TAIL_CALL:
			emit_bpf_tail_call(&prog);
			break;

			/* cond jump */
		case BPF_JMP | BPF_JEQ | BPF_X:
		case BPF_JMP | BPF_JNE | BPF_X:
		case BPF_JMP | BPF_JGT | BPF_X:
		case BPF_JMP | BPF_JLT | BPF_X:
		case BPF_JMP | BPF_JGE | BPF_X:
		case BPF_JMP | BPF_JLE | BPF_X:
		case BPF_JMP | BPF_JSGT | BPF_X:
		case BPF_JMP | BPF_JSLT | BPF_X:
		case BPF_JMP | BPF_JSGE | BPF_X:
		case BPF_JMP | BPF_JSLE | BPF_X:
		case BPF_JMP32 | BPF_JEQ | BPF_X:
		case BPF_JMP32 | BPF_JNE | BPF_X:
		case BPF_JMP32 | BPF_JGT | BPF_X:
		case BPF_JMP32 | BPF_JLT | BPF_X:
		case BPF_JMP32 | BPF_JGE | BPF_X:
		case BPF_JMP32 | BPF_JLE | BPF_X:
		case BPF_JMP32 | BPF_JSGT | BPF_X:
		case BPF_JMP32 | BPF_JSLT | BPF_X:
		case BPF_JMP32 | BPF_JSGE | BPF_X:
		case BPF_JMP32 | BPF_JSLE | BPF_X:
			/* cmp dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_X:
		case BPF_JMP32 | BPF_JSET | BPF_X:
			/* test dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_K:
		case BPF_JMP32 | BPF_JSET | BPF_K:
			/* test dst_reg, imm32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JEQ | BPF_K:
		case BPF_JMP | BPF_JNE | BPF_K:
		case BPF_JMP | BPF_JGT | BPF_K:
		case BPF_JMP | BPF_JLT | BPF_K:
		case BPF_JMP | BPF_JGE | BPF_K:
		case BPF_JMP | BPF_JLE | BPF_K:
		case BPF_JMP | BPF_JSGT | BPF_K:
		case BPF_JMP | BPF_JSLT | BPF_K:
		case BPF_JMP | BPF_JSGE | BPF_K:
		case BPF_JMP | BPF_JSLE | BPF_K:
		case BPF_JMP32 | BPF_JEQ | BPF_K:
		case BPF_JMP32 | BPF_JNE | BPF_K:
		case BPF_JMP32 | BPF_JGT | BPF_K:
		case BPF_JMP32 | BPF_JLT | BPF_K:
		case BPF_JMP32 | BPF_JGE | BPF_K:
		case BPF_JMP32 | BPF_JLE | BPF_K:
		case BPF_JMP32 | BPF_JSGT | BPF_K:
		case BPF_JMP32 | BPF_JSLT | BPF_K:
		case BPF_JMP32 | BPF_JSGE | BPF_K:
		case BPF_JMP32 | BPF_JSLE | BPF_K:
			/* cmp dst_reg, imm8/32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:		/* Convert BPF opcode to x86 */
			switch (BPF_OP(insn->code)) {
			case BPF_JEQ:
				jmp_cond = X86_JE;
				break;
			case BPF_JSET:
			case BPF_JNE:
				jmp_cond = X86_JNE;
				break;
			case BPF_JGT:
				/* GT is unsigned '>', JA in x86 */
				jmp_cond = X86_JA;
				break;
			case BPF_JLT:
				/* LT is unsigned '<', JB in x86 */
				jmp_cond = X86_JB;
				break;
			case BPF_JGE:
				/* GE is unsigned '>=', JAE in x86 */
				jmp_cond = X86_JAE;
				break;
			case BPF_JLE:
				/* LE is unsigned '<=', JBE in x86 */
				jmp_cond = X86_JBE;
				break;
			case BPF_JSGT:
				/* Signed '>', GT in x86 */
				jmp_cond = X86_JG;
				break;
			case BPF_JSLT:
				/* Signed '<', LT in x86 */
				jmp_cond = X86_JL;
				break;
			case BPF_JSGE:
				/* Signed '>=', GE in x86 */
				jmp_cond = X86_JGE;
				break;
			case BPF_JSLE:
				/* Signed '<=', LE in x86 */
				jmp_cond = X86_JLE;
				break;
			default: /* to silence GCC warning */
				return -EFAULT;
			}
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (is_imm8(jmp_offset)) {
				EMIT2(jmp_cond, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
			} else {
				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}

			break;

		case BPF_JMP | BPF_JA:
			if (insn->off == -1)
				/* -1 jmp instructions will always jump
				 * backwards two bytes. Explicitly handling
				 * this case avoids wasting too many passes
				 * when there are long sequences of replaced
				 * dead code.
				 */
				jmp_offset = -2;
			else
				jmp_offset = addrs[i + insn->off] - addrs[i];

			if (!jmp_offset)
				/* Optimize out nop jumps */
				break;
emit_jmp:
			if (is_imm8(jmp_offset)) {
				EMIT2(0xEB, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT1_off32(0xE9, jmp_offset);
			} else {
				pr_err("jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}
			break;

		case BPF_JMP | BPF_EXIT:
			if (seen_exit) {
				jmp_offset = ctx->cleanup_addr - addrs[i];
				goto emit_jmp;
			}
			seen_exit = true;
			/* Update cleanup_addr */
			ctx->cleanup_addr = proglen;
			if (!bpf_prog_was_classic(bpf_prog))
				EMIT1(0x5B); /* get rid of tail_call_cnt */
			EMIT2(0x41, 0x5F);   /* pop r15 */
			EMIT2(0x41, 0x5E);   /* pop r14 */
			EMIT2(0x41, 0x5D);   /* pop r13 */
			EMIT1(0x5B);         /* pop rbx */
			EMIT1(0xC9);         /* leave */
			EMIT1(0xC3);         /* ret */
			break;

		default:
			/*
			 * By design x86-64 JIT should support all BPF instructions.
			 * This error will be seen if new instruction was added
			 * to the interpreter, but not to the JIT, or if there is
			 * junk in bpf_prog.
			 */
			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
			return -EINVAL;
		}

		ilen = prog - temp;
		if (ilen > BPF_MAX_INSN_SIZE) {
			pr_err("bpf_jit: fatal insn size error\n");
			return -EFAULT;
		}

		if (image) {
			if (unlikely(proglen + ilen > oldproglen)) {
				pr_err("bpf_jit: fatal error\n");
				return -EFAULT;
			}
			memcpy(image + proglen, temp, ilen);
		}
		proglen += ilen;
		addrs[i] = proglen;
		prog = temp;
	}
	return proglen;
}

struct x64_jit_data {
	struct bpf_binary_header *header;
	int *addrs;
	u8 *image;
	int proglen;
	struct jit_context ctx;
};

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct x64_jit_data *jit_data;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	bool extra_pass = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	addrs = jit_data->addrs;
	if (addrs) {
		ctx = jit_data->ctx;
		oldproglen = jit_data->proglen;
		image = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		goto skip_init_addrs;
	}
	addrs = kmalloc_array(prog->len, sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out_addrs;
	}

	/*
	 * Before first pass, make a rough estimation of addrs[]
	 * each BPF instruction is translated to less than 64 bytes
	 */
	for (proglen = 0, i = 0; i < prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;
skip_init_addrs:

	/*
	 * JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large BPF programs
	 * may converge on the last pass. In such case do one more
	 * pass to emit the final image.
	 */
	for (pass = 0; pass < 20 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			header = bpf_jit_binary_alloc(proglen, &image,
						      1, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
		}
		oldproglen = proglen;
		cond_resched();
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		if (!prog->is_func || extra_pass) {
			bpf_jit_binary_lock_ro(header);
		} else {
			jit_data->addrs = addrs;
			jit_data->ctx = ctx;
			jit_data->proglen = proglen;
			jit_data->image = image;
			jit_data->header = header;
		}
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		prog = orig_prog;
	}

	if (!image || !prog->is_func || extra_pass) {
		if (image)
			bpf_prog_fill_jited_linfo(prog, addrs);
out_addrs:
		kfree(addrs);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}