// SPDX-License-Identifier: GPL-2.0-only
/*
 * bpf_jit_comp.c: BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>

#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
	if (len == 1)
		*ptr = bytes;
	else if (len == 2)
		*(u16 *)ptr = bytes;
	else {
		*(u32 *)ptr = bytes;
		barrier();
	}
	return ptr + len;
}

#define EMIT(bytes, len) \
	do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)		EMIT(b1, 1)
#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)

#define EMIT1_off32(b1, off) \
	do { EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
	do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
	do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
	do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
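
/*
 * Example (illustrative): EMIT3(0x48, 0x89, 0xE5) appends the three bytes
 * of 'mov rbp, rsp' to the image buffer and advances cnt by 3.
 */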

static bool is_imm8(int value)
{
	return value <= 127 && value >= -128;
}

static bool is_simm32(s64 value)
{
	return value == (s64)(s32)value;
}

static bool is_uimm32(u64 value)
{
	return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC)								 \
	do {										 \
		if (DST != SRC)								 \
			EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
	} while (0)

static int bpf_size_to_x86_bytes(int bpf_size)
{
	if (bpf_size == BPF_W)
		return 4;
	else if (bpf_size == BPF_H)
		return 2;
	else if (bpf_size == BPF_B)
		return 1;
	else if (bpf_size == BPF_DW)
		return 4; /* imm32 */
	else
		return 0;
}

/*
 * List of x86 conditional jump opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F

/* Pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)

/*
 * The following table maps BPF registers to x86-64 registers.
 *
 * x86-64 register R12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 * Also x86-64 register R9 is unused. x86-64 register R10 is
 * used for blinding (if enabled).
 */
static const int reg2hex[] = {
	[BPF_REG_0] = 0,  /* RAX */
	[BPF_REG_1] = 7,  /* RDI */
	[BPF_REG_2] = 6,  /* RSI */
	[BPF_REG_3] = 2,  /* RDX */
	[BPF_REG_4] = 1,  /* RCX */
	[BPF_REG_5] = 0,  /* R8  */
	[BPF_REG_6] = 3,  /* RBX callee saved */
	[BPF_REG_7] = 5,  /* R13 callee saved */
	[BPF_REG_8] = 6,  /* R14 callee saved */
	[BPF_REG_9] = 7,  /* R15 callee saved */
	[BPF_REG_FP] = 5, /* RBP readonly */
	[BPF_REG_AX] = 2, /* R10 temp register */
	[AUX_REG] = 3,    /* R11 temp register */
};

/*
 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
 * which need extra byte of encoding.
 * rax,rcx,...,rbp have simpler encoding
 */
static bool is_ereg(u32 reg)
{
	return (1 << reg) & (BIT(BPF_REG_5) |
			     BIT(AUX_REG) |
			     BIT(BPF_REG_7) |
			     BIT(BPF_REG_8) |
			     BIT(BPF_REG_9) |
			     BIT(BPF_REG_AX));
}
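
/*
 * For example, is_ereg(BPF_REG_7) is true: BPF R7 lives in x86-64 R13,
 * which can only be encoded with a REX prefix.
 */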

static bool is_axreg(u32 reg)
{
	return reg == BPF_REG_0;
}

/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
static u8 add_1mod(u8 byte, u32 reg)
{
	if (is_ereg(reg))
		byte |= 1;
	return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
	if (is_ereg(r1))
		byte |= 1;
	if (is_ereg(r2))
		byte |= 4;
	return byte;
}

/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
	return byte + reg2hex[dst_reg];
}

/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}
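
/*
 * Worked example (illustrative): EMIT_mov(BPF_REG_0, BPF_REG_1) expands to
 * EMIT3(add_2mod(0x48, dst, src), 0x89, add_2reg(0xC0, dst, src)) and emits
 * 48 89 F8, i.e. 'mov rax, rdi': REX.W prefix, opcode 0x89, ModRM byte
 * 0xC0 + rax(0) + (rdi(7) << 3).
 */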

static void jit_fill_hole(void *area, unsigned int size)
{
	/* Fill whole space with INT3 instructions */
	memset(area, 0xcc, size);
}
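
/*
 * 0xcc is the INT3 opcode, so a stray jump into the unused area traps
 * instead of executing whatever bytes happen to be there.
 */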

struct jit_context {
	int cleanup_addr; /* Epilogue code offset */
};

/* Maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE	128
#define BPF_INSN_SAFETY		64

#define PROLOGUE_SIZE		20

/*
 * Emit x86-64 prologue code for BPF program and check its size.
 * bpf_tail_call helper will skip it while jumping into another program
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
	u8 *prog = *pprog;
	int cnt = 0;

	EMIT1(0x55);             /* push rbp */
	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
	/* sub rsp, rounded_stack_depth */
	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
	EMIT1(0x53);             /* push rbx */
	EMIT2(0x41, 0x55);       /* push r13 */
	EMIT2(0x41, 0x56);       /* push r14 */
	EMIT2(0x41, 0x57);       /* push r15 */
	if (!ebpf_from_cbpf) {
		/* zero init tail_call_cnt */
		EMIT2(0x6a, 0x00);
		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
	}
	*pprog = prog;
}
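
/*
 * For illustration, with stack_depth = 128 and a native eBPF program the
 * prologue above assembles to:
 *
 *   push rbp
 *   mov  rbp, rsp
 *   sub  rsp, 128
 *   push rbx
 *   push r13
 *   push r14
 *   push r15
 *   push 0          ; zero-init tail_call_cnt
 */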

/*
 * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
	u8 *prog = *pprog;
	int label1, label2, label3;
	int cnt = 0;

	/*
	 * rdi - pointer to ctx
	 * rsi - pointer to bpf_array
	 * rdx - index in bpf_array
	 */

	/*
	 * if (index >= array->map.max_entries)
	 *	goto out;
	 */
	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
	      offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
	label1 = cnt;

	/*
	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
	 *	goto out;
	 */
	EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JA, OFFSET2);                   /* ja out */
	label2 = cnt;
	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
	EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */

	/* prog = array->ptrs[index]; */
	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
		    offsetof(struct bpf_array, ptrs));

	/*
	 * if (prog == NULL)
	 *	goto out;
	 */
	EMIT3(0x48, 0x85, 0xC0);		  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JE, OFFSET3);                   /* je out */
	label3 = cnt;

	/* goto *(prog->bpf_func + prologue_size); */
	EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
	      offsetof(struct bpf_prog, bpf_func));
	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

	/*
	 * Wow we're ready to jump into next BPF program
	 * rdi == ctx (1st arg)
	 * rax == prog->bpf_func + prologue_size
	 */
	RETPOLINE_RAX_BPF_JIT();

	/* out: */
	BUILD_BUG_ON(cnt - label1 != OFFSET1);
	BUILD_BUG_ON(cnt - label2 != OFFSET2);
	BUILD_BUG_ON(cnt - label3 != OFFSET3);
	*pprog = prog;
}

static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
			   u32 dst_reg, const u32 imm32)
{
	u8 *prog = *pprog;
	u8 b1, b2, b3;
	int cnt = 0;

	/*
	 * Optimization: if imm32 is positive, use 'mov %eax, imm32'
	 * (which zero-extends imm32) to save 2 bytes.
	 */
	if (sign_propagate && (s32)imm32 < 0) {
		/* 'mov %rax, imm32' sign extends imm32 */
		b1 = add_1mod(0x48, dst_reg);
		b2 = 0xC7;
		b3 = 0xC0;
		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
		goto done;
	}

	/*
	 * Optimization: if imm32 is zero, use 'xor %eax, %eax'
	 * to save 3 bytes.
	 */
	if (imm32 == 0) {
		if (is_ereg(dst_reg))
			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
		b2 = 0x31; /* xor */
		b3 = 0xC0;
		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
		goto done;
	}

	/* mov %eax, imm32 */
	if (is_ereg(dst_reg))
		EMIT1(add_1mod(0x40, dst_reg));
	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
	*pprog = prog;
}

static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
			   const u32 imm32_hi, const u32 imm32_lo)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
		/*
		 * For emitting plain u32, where sign bit must not be
		 * propagated LLVM tends to load imm64 over mov32
		 * directly, so save couple of bytes by just doing
		 * 'mov %eax, imm32' instead.
		 */
		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
	} else {
		/* movabsq %rax, imm64 */
		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
		EMIT(imm32_lo, 4);
		EMIT(imm32_hi, 4);
	}

	*pprog = prog;
}
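
/*
 * Illustrative sizes: a BPF_LD_IMM64 whose value fits in 32 unsigned bits
 * (e.g. 0xdeadbeef) becomes the 5-byte 'mov eax, imm32', while a value
 * with any of the upper 32 bits set needs the full 10-byte
 * 'movabsq rax, imm64'.
 */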

static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is64) {
		/* mov dst, src */
		EMIT_mov(dst_reg, src_reg);
	} else {
		/* mov32 dst, src */
		if (is_ereg(dst_reg) || is_ereg(src_reg))
			EMIT1(add_2mod(0x40, dst_reg, src_reg));
		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
	}

	*pprog = prog;
}

static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
	struct bpf_insn *insn = bpf_prog->insnsi;
	int insn_cnt = bpf_prog->len;
	bool seen_exit = false;
	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
	int i, cnt = 0;
	int proglen = 0;
	u8 *prog = temp;

	emit_prologue(&prog, bpf_prog->aux->stack_depth,
		      bpf_prog_was_classic(bpf_prog));
	addrs[0] = prog - temp;

	for (i = 1; i <= insn_cnt; i++, insn++) {
		const s32 imm32 = insn->imm;
		u32 dst_reg = insn->dst_reg;
		u32 src_reg = insn->src_reg;
		u8 b2 = 0, b3 = 0;
		s64 jmp_offset;
		u8 jmp_cond;
		int ilen;
		u8 *func;

		switch (insn->code) {
			/* ALU */
		case BPF_ALU | BPF_ADD | BPF_X:
		case BPF_ALU | BPF_SUB | BPF_X:
		case BPF_ALU | BPF_AND | BPF_X:
		case BPF_ALU | BPF_OR | BPF_X:
		case BPF_ALU | BPF_XOR | BPF_X:
		case BPF_ALU64 | BPF_ADD | BPF_X:
		case BPF_ALU64 | BPF_SUB | BPF_X:
		case BPF_ALU64 | BPF_AND | BPF_X:
		case BPF_ALU64 | BPF_OR | BPF_X:
		case BPF_ALU64 | BPF_XOR | BPF_X:
			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b2 = 0x01; break;
			case BPF_SUB: b2 = 0x29; break;
			case BPF_AND: b2 = 0x21; break;
			case BPF_OR: b2 = 0x09; break;
			case BPF_XOR: b2 = 0x31; break;
			}
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
			break;

		case BPF_ALU64 | BPF_MOV | BPF_X:
		case BPF_ALU | BPF_MOV | BPF_X:
			emit_mov_reg(&prog,
				     BPF_CLASS(insn->code) == BPF_ALU64,
				     dst_reg, src_reg);
			break;

			/* neg dst */
		case BPF_ALU | BPF_NEG:
		case BPF_ALU64 | BPF_NEG:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
			break;

		case BPF_ALU | BPF_ADD | BPF_K:
		case BPF_ALU | BPF_SUB | BPF_K:
		case BPF_ALU | BPF_AND | BPF_K:
		case BPF_ALU | BPF_OR | BPF_K:
		case BPF_ALU | BPF_XOR | BPF_K:
		case BPF_ALU64 | BPF_ADD | BPF_K:
		case BPF_ALU64 | BPF_SUB | BPF_K:
		case BPF_ALU64 | BPF_AND | BPF_K:
		case BPF_ALU64 | BPF_OR | BPF_K:
		case BPF_ALU64 | BPF_XOR | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			/*
			 * b3 holds 'normal' opcode, b2 short form only valid
			 * in case dst is eax/rax.
			 */
			switch (BPF_OP(insn->code)) {
			case BPF_ADD:
				b3 = 0xC0;
				b2 = 0x05;
				break;
			case BPF_SUB:
				b3 = 0xE8;
				b2 = 0x2D;
				break;
			case BPF_AND:
				b3 = 0xE0;
				b2 = 0x25;
				break;
			case BPF_OR:
				b3 = 0xC8;
				b2 = 0x0D;
				break;
			case BPF_XOR:
				b3 = 0xF0;
				b2 = 0x35;
				break;
			}

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
			else if (is_axreg(dst_reg))
				EMIT1_off32(b2, imm32);
			else
				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU64 | BPF_MOV | BPF_K:
		case BPF_ALU | BPF_MOV | BPF_K:
			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
				       dst_reg, imm32);
			break;

		case BPF_LD | BPF_IMM | BPF_DW:
			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
			insn++;
			i++;
			break;

			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
		case BPF_ALU | BPF_MOD | BPF_X:
		case BPF_ALU | BPF_DIV | BPF_X:
		case BPF_ALU | BPF_MOD | BPF_K:
		case BPF_ALU | BPF_DIV | BPF_K:
		case BPF_ALU64 | BPF_MOD | BPF_X:
		case BPF_ALU64 | BPF_DIV | BPF_X:
		case BPF_ALU64 | BPF_MOD | BPF_K:
		case BPF_ALU64 | BPF_DIV | BPF_K:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov r11, src_reg */
				EMIT_mov(AUX_REG, src_reg);
			else
				/* mov r11, imm32 */
				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

			/* mov rax, dst_reg */
			EMIT_mov(BPF_REG_0, dst_reg);

			/*
			 * xor edx, edx
			 * equivalent to 'xor rdx, rdx', but one byte less
			 */
			EMIT2(0x31, 0xd2);

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				/* div r11 */
				EMIT3(0x49, 0xF7, 0xF3);
			else
				/* div r11d */
				EMIT3(0x41, 0xF7, 0xF3);

			if (BPF_OP(insn->code) == BPF_MOD)
				/* mov r11, rdx */
				EMIT3(0x49, 0x89, 0xD3);
			else
				/* mov r11, rax */
				EMIT3(0x49, 0x89, 0xC3);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_MUL | BPF_K:
		case BPF_ALU | BPF_MUL | BPF_X:
		case BPF_ALU64 | BPF_MUL | BPF_K:
		case BPF_ALU64 | BPF_MUL | BPF_X:
		{
			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

			if (dst_reg != BPF_REG_0)
				EMIT1(0x50); /* push rax */
			if (dst_reg != BPF_REG_3)
				EMIT1(0x52); /* push rdx */

			/* mov r11, dst_reg */
			EMIT_mov(AUX_REG, dst_reg);

			if (BPF_SRC(insn->code) == BPF_X)
				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
			else
				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

			if (is64)
				EMIT1(add_1mod(0x48, AUX_REG));
			else if (is_ereg(AUX_REG))
				EMIT1(add_1mod(0x40, AUX_REG));
			/* mul(q) r11 */
			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

			if (dst_reg != BPF_REG_3)
				EMIT1(0x5A); /* pop rdx */
			if (dst_reg != BPF_REG_0) {
				/* mov dst_reg, rax */
				EMIT_mov(dst_reg, BPF_REG_0);
				EMIT1(0x58); /* pop rax */
			}
			break;
		}
			/* Shifts */
		case BPF_ALU | BPF_LSH | BPF_K:
		case BPF_ALU | BPF_RSH | BPF_K:
		case BPF_ALU | BPF_ARSH | BPF_K:
		case BPF_ALU64 | BPF_LSH | BPF_K:
		case BPF_ALU64 | BPF_RSH | BPF_K:
		case BPF_ALU64 | BPF_ARSH | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}

			if (imm32 == 1)
				EMIT2(0xD1, add_1reg(b3, dst_reg));
			else
				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU | BPF_LSH | BPF_X:
		case BPF_ALU | BPF_RSH | BPF_X:
		case BPF_ALU | BPF_ARSH | BPF_X:
		case BPF_ALU64 | BPF_LSH | BPF_X:
		case BPF_ALU64 | BPF_RSH | BPF_X:
		case BPF_ALU64 | BPF_ARSH | BPF_X:

			/* Check for bad case when dst_reg == rcx */
			if (dst_reg == BPF_REG_4) {
				/* mov r11, dst_reg */
				EMIT_mov(AUX_REG, dst_reg);
				dst_reg = AUX_REG;
			}

			if (src_reg != BPF_REG_4) { /* common case */
				EMIT1(0x51); /* push rcx */

				/* mov rcx, src_reg */
				EMIT_mov(BPF_REG_4, src_reg);
			}

			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT2(0xD3, add_1reg(b3, dst_reg));

			if (src_reg != BPF_REG_4)
				EMIT1(0x59); /* pop rcx */

			if (insn->dst_reg == BPF_REG_4)
				/* mov dst_reg, r11 */
				EMIT_mov(insn->dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_END | BPF_FROM_BE:
			switch (imm32) {
			case 16:
				/* Emit 'ror %ax, 8' to swap lower 2 bytes */
				EMIT1(0x66);
				if (is_ereg(dst_reg))
					EMIT1(0x41);
				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

				/* Emit 'movzwl eax, ax' */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'bswap eax' to swap lower 4 bytes */
				if (is_ereg(dst_reg))
					EMIT2(0x41, 0x0F);
				else
					EMIT1(0x0F);
				EMIT1(add_1reg(0xC8, dst_reg));
				break;
			case 64:
				/* Emit 'bswap rax' to swap 8 bytes */
				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
				      add_1reg(0xC8, dst_reg));
				break;
			}
			break;

		case BPF_ALU | BPF_END | BPF_FROM_LE:
			switch (imm32) {
			case 16:
				/*
				 * Emit 'movzwl eax, ax' to zero extend 16-bit
				 * into 64 bit
				 */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* Emit 'mov eax, eax' to clear upper 32-bits */
				if (is_ereg(dst_reg))
					EMIT1(0x45);
				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 64:
				/* nop */
				break;
			}
			break;

			/* ST: *(u8*)(dst_reg + off) = imm */
		case BPF_ST | BPF_MEM | BPF_B:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC6);
			else
				EMIT1(0xC6);
			goto st;
		case BPF_ST | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg))
				EMIT3(0x66, 0x41, 0xC7);
			else
				EMIT2(0x66, 0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC7);
			else
				EMIT1(0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_DW:
			EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:			if (is_imm8(insn->off))
				EMIT2(add_1reg(0x40, dst_reg), insn->off);
			else
				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
			break;

			/* STX: *(u8*)(dst_reg + off) = src_reg */
		case BPF_STX | BPF_MEM | BPF_B:
			/* Emit 'mov byte ptr [rax + off], al' */
			if (is_ereg(dst_reg) || is_ereg(src_reg) ||
			    /* We have to add extra byte for x86 SIL, DIL regs */
			    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
			else
				EMIT1(0x88);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT2(0x66, 0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT1(0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_DW:
			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* LDX: dst_reg = *(u8*)(src_reg + off) */
		case BPF_LDX | BPF_MEM | BPF_B:
			/* Emit 'movzx rax, byte ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_H:
			/* Emit 'movzx rax, word ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_W:
			/* Emit 'mov eax, dword ptr [rax+0x14]' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
			else
				EMIT1(0x8B);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_DW:
			/* Emit 'mov rax, qword ptr [rax+0x14]' */
			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:			/*
			 * If insn->off == 0 we can save one extra byte, but
			 * special case of x86 R13 which always needs an offset
			 * is not worth the hassle
			 */
			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
					    insn->off);
			break;

			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
		case BPF_STX | BPF_XADD | BPF_W:
			/* Emit 'lock add dword ptr [rax + off], eax' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
			else
				EMIT2(0xF0, 0x01);
			goto xadd;
		case BPF_STX | BPF_XADD | BPF_DW:
			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* call */
		case BPF_JMP | BPF_CALL:
			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);
			if (!imm32 || !is_simm32(jmp_offset)) {
				pr_err("unsupported BPF func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			EMIT1_off32(0xE8, jmp_offset);
			break;

		case BPF_JMP | BPF_TAIL_CALL:
			emit_bpf_tail_call(&prog);
			break;

			/* cond jump */
		case BPF_JMP | BPF_JEQ | BPF_X:
		case BPF_JMP | BPF_JNE | BPF_X:
		case BPF_JMP | BPF_JGT | BPF_X:
		case BPF_JMP | BPF_JLT | BPF_X:
		case BPF_JMP | BPF_JGE | BPF_X:
		case BPF_JMP | BPF_JLE | BPF_X:
		case BPF_JMP | BPF_JSGT | BPF_X:
		case BPF_JMP | BPF_JSLT | BPF_X:
		case BPF_JMP | BPF_JSGE | BPF_X:
		case BPF_JMP | BPF_JSLE | BPF_X:
		case BPF_JMP32 | BPF_JEQ | BPF_X:
		case BPF_JMP32 | BPF_JNE | BPF_X:
		case BPF_JMP32 | BPF_JGT | BPF_X:
		case BPF_JMP32 | BPF_JLT | BPF_X:
		case BPF_JMP32 | BPF_JGE | BPF_X:
		case BPF_JMP32 | BPF_JLE | BPF_X:
		case BPF_JMP32 | BPF_JSGT | BPF_X:
		case BPF_JMP32 | BPF_JSLT | BPF_X:
		case BPF_JMP32 | BPF_JSGE | BPF_X:
		case BPF_JMP32 | BPF_JSLE | BPF_X:
			/* cmp dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_X:
		case BPF_JMP32 | BPF_JSET | BPF_X:
			/* test dst_reg, src_reg */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_K:
		case BPF_JMP32 | BPF_JSET | BPF_K:
			/* test dst_reg, imm32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JEQ | BPF_K:
		case BPF_JMP | BPF_JNE | BPF_K:
		case BPF_JMP | BPF_JGT | BPF_K:
		case BPF_JMP | BPF_JLT | BPF_K:
		case BPF_JMP | BPF_JGE | BPF_K:
		case BPF_JMP | BPF_JLE | BPF_K:
		case BPF_JMP | BPF_JSGT | BPF_K:
		case BPF_JMP | BPF_JSLT | BPF_K:
		case BPF_JMP | BPF_JSGE | BPF_K:
		case BPF_JMP | BPF_JSLE | BPF_K:
		case BPF_JMP32 | BPF_JEQ | BPF_K:
		case BPF_JMP32 | BPF_JNE | BPF_K:
		case BPF_JMP32 | BPF_JGT | BPF_K:
		case BPF_JMP32 | BPF_JLT | BPF_K:
		case BPF_JMP32 | BPF_JGE | BPF_K:
		case BPF_JMP32 | BPF_JLE | BPF_K:
		case BPF_JMP32 | BPF_JSGT | BPF_K:
		case BPF_JMP32 | BPF_JSLT | BPF_K:
		case BPF_JMP32 | BPF_JSGE | BPF_K:
		case BPF_JMP32 | BPF_JSLE | BPF_K:
			/* cmp dst_reg, imm8/32 */
			if (BPF_CLASS(insn->code) == BPF_JMP)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:		/* Convert BPF opcode to x86 */
			switch (BPF_OP(insn->code)) {
			case BPF_JEQ:
				jmp_cond = X86_JE;
				break;
			case BPF_JSET:
			case BPF_JNE:
				jmp_cond = X86_JNE;
				break;
			case BPF_JGT:
				/* GT is unsigned '>', JA in x86 */
				jmp_cond = X86_JA;
				break;
			case BPF_JLT:
				/* LT is unsigned '<', JB in x86 */
				jmp_cond = X86_JB;
				break;
			case BPF_JGE:
				/* GE is unsigned '>=', JAE in x86 */
				jmp_cond = X86_JAE;
				break;
			case BPF_JLE:
				/* LE is unsigned '<=', JBE in x86 */
				jmp_cond = X86_JBE;
				break;
			case BPF_JSGT:
				/* Signed '>', GT in x86 */
				jmp_cond = X86_JG;
				break;
			case BPF_JSLT:
				/* Signed '<', LT in x86 */
				jmp_cond = X86_JL;
				break;
			case BPF_JSGE:
				/* Signed '>=', GE in x86 */
				jmp_cond = X86_JGE;
				break;
			case BPF_JSLE:
				/* Signed '<=', LE in x86 */
				jmp_cond = X86_JLE;
				break;
			default: /* to silence GCC warning */
				return -EFAULT;
			}
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (is_imm8(jmp_offset)) {
				EMIT2(jmp_cond, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
			} else {
				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}

			break;

		case BPF_JMP | BPF_JA:
			if (insn->off == -1)
				/* -1 jmp instructions will always jump
				 * backwards two bytes. Explicitly handling
				 * this case avoids wasting too many passes
				 * when there are long sequences of replaced
				 * dead code.
				 */
				jmp_offset = -2;
			else
				jmp_offset = addrs[i + insn->off] - addrs[i];

			if (!jmp_offset)
				/* Optimize out nop jumps */
				break;
emit_jmp:
			if (is_imm8(jmp_offset)) {
				EMIT2(0xEB, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT1_off32(0xE9, jmp_offset);
			} else {
				pr_err("jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}
			break;

		case BPF_JMP | BPF_EXIT:
			if (seen_exit) {
				jmp_offset = ctx->cleanup_addr - addrs[i];
				goto emit_jmp;
			}
			seen_exit = true;
			/* Update cleanup_addr */
			ctx->cleanup_addr = proglen;
			if (!bpf_prog_was_classic(bpf_prog))
				EMIT1(0x5B); /* get rid of tail_call_cnt */
			EMIT2(0x41, 0x5F);   /* pop r15 */
			EMIT2(0x41, 0x5E);   /* pop r14 */
			EMIT2(0x41, 0x5D);   /* pop r13 */
			EMIT1(0x5B);         /* pop rbx */
			EMIT1(0xC9);         /* leave */
			EMIT1(0xC3);         /* ret */
			break;

		default:
			/*
			 * By design the x86-64 JIT should support all BPF instructions.
			 * This error will be seen if a new instruction was added
			 * to the interpreter, but not to the JIT, or if there is
			 * junk in bpf_prog.
			 */
			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
			return -EINVAL;
		}

		ilen = prog - temp;
		if (ilen > BPF_MAX_INSN_SIZE) {
			pr_err("bpf_jit: fatal insn size error\n");
			return -EFAULT;
		}

		if (image) {
			if (unlikely(proglen + ilen > oldproglen)) {
				pr_err("bpf_jit: fatal error\n");
				return -EFAULT;
			}
			memcpy(image + proglen, temp, ilen);
		}
		proglen += ilen;
		addrs[i] = proglen;
		prog = temp;
	}
	return proglen;
}

struct x64_jit_data {
	struct bpf_binary_header *header;
	int *addrs;
	u8 *image;
	int proglen;
	struct jit_context ctx;
};
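
/*
 * This state is kept across bpf_int_jit_compile() calls for programs with
 * BPF-to-BPF calls (prog->is_func), so the extra pass below can re-emit
 * the image once the final call addresses are known.
 */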

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct x64_jit_data *jit_data;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	bool extra_pass = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/*
	 * If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	addrs = jit_data->addrs;
	if (addrs) {
		ctx = jit_data->ctx;
		oldproglen = jit_data->proglen;
		image = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		goto skip_init_addrs;
	}
	addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out_addrs;
	}

	/*
	 * Before the first pass, make a rough estimation of addrs[]:
	 * each BPF instruction is translated to less than 64 bytes.
	 */
	for (proglen = 0, i = 0; i <= prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;
skip_init_addrs:

	/*
	 * JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large BPF programs
	 * may converge on the last pass. In such a case, do one more
	 * pass to emit the final image.
	 */
	for (pass = 0; pass < 20 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			header = bpf_jit_binary_alloc(proglen, &image,
						      1, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
		}
		oldproglen = proglen;
		cond_resched();
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		if (!prog->is_func || extra_pass) {
			bpf_jit_binary_lock_ro(header);
		} else {
			jit_data->addrs = addrs;
			jit_data->ctx = ctx;
			jit_data->proglen = proglen;
			jit_data->image = image;
			jit_data->header = header;
		}
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		prog = orig_prog;
	}

	if (!image || !prog->is_func || extra_pass) {
		if (image)
			bpf_prog_fill_jited_linfo(prog, addrs + 1);
out_addrs:
		kfree(addrs);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}