/* bpf_jit_comp.c : BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; version 2
 * of the License.
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>

#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

/*
 * assembly code in arch/x86/net/bpf_jit.S
 */
extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
extern u8 sk_load_byte_positive_offset[];
extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
extern u8 sk_load_byte_negative_offset[];

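/* write 1, 2 or 4 bytes of machine code at ptr and return ptr advanced past them */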
static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
	if (len == 1)
		*ptr = bytes;
	else if (len == 2)
		*(u16 *)ptr = bytes;
	else {
		*(u32 *)ptr = bytes;
		barrier();
	}
	return ptr + len;
}

#define EMIT(bytes, len) \
	do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)		EMIT(b1, 1)
#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
#define EMIT1_off32(b1, off) \
	do {EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
	do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
	do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
	do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)

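/* true if the value fits in a sign-extended 8-bit immediate */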
static bool is_imm8(int value)
{
	return value <= 127 && value >= -128;
}

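/* true if the value fits in a sign-extended 32-bit immediate */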
static bool is_simm32(s64 value)
{
	return value == (s64)(s32)value;
}

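/* true if the value fits in a zero-extended 32-bit immediate */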
static bool is_uimm32(u64 value)
{
	return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC) \
	do {if (DST != SRC) \
		EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
	} while (0)

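/* number of immediate bytes emitted for the given BPF size; BPF_DW stores still use a 32-bit immediate */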
static int bpf_size_to_x86_bytes(int bpf_size)
{
	if (bpf_size == BPF_W)
		return 4;
	else if (bpf_size == BPF_H)
		return 2;
	else if (bpf_size == BPF_B)
		return 1;
	else if (bpf_size == BPF_DW)
		return 4; /* imm32 */
	else
		return 0;
}

/* list of x86 cond jumps opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F

#define CHOOSE_LOAD_FUNC(K, func) \
	((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)

/* pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)

/* The following table maps BPF registers to x64 registers.
 *
 * x64 register r12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 *  r9 caches skb->len - skb->data_len
 * r10 caches skb->data, and used for blinding (if enabled)
 */
static const int reg2hex[] = {
	[BPF_REG_0] = 0,  /* rax */
	[BPF_REG_1] = 7,  /* rdi */
	[BPF_REG_2] = 6,  /* rsi */
	[BPF_REG_3] = 2,  /* rdx */
	[BPF_REG_4] = 1,  /* rcx */
	[BPF_REG_5] = 0,  /* r8 */
	[BPF_REG_6] = 3,  /* rbx callee saved */
	[BPF_REG_7] = 5,  /* r13 callee saved */
	[BPF_REG_8] = 6,  /* r14 callee saved */
	[BPF_REG_9] = 7,  /* r15 callee saved */
	[BPF_REG_FP] = 5, /* rbp readonly */
	[BPF_REG_AX] = 2, /* r10 temp register */
	[AUX_REG] = 3,    /* r11 temp register */
};

/* is_ereg() == true if BPF register 'reg' maps to x64 r8..r15
 * which need extra byte of encoding.
 * rax,rcx,...,rbp have simpler encoding
 */
static bool is_ereg(u32 reg)
{
	return (1 << reg) & (BIT(BPF_REG_5) |
			     BIT(AUX_REG) |
			     BIT(BPF_REG_7) |
			     BIT(BPF_REG_8) |
			     BIT(BPF_REG_9) |
			     BIT(BPF_REG_AX));
}

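/* rax (BPF_REG_0) allows the shorter ALU-with-imm32 encodings */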
static bool is_axreg(u32 reg)
{
	return reg == BPF_REG_0;
}

/* add modifiers if 'reg' maps to x64 registers r8..r15 */
static u8 add_1mod(u8 byte, u32 reg)
{
	if (is_ereg(reg))
		byte |= 1;
	return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
	if (is_ereg(r1))
		byte |= 1;
	if (is_ereg(r2))
		byte |= 4;
	return byte;
}

/* encode 'dst_reg' register into x64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
	return byte + reg2hex[dst_reg];
}

/* encode 'dst_reg' and 'src_reg' registers into x64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}

static void jit_fill_hole(void *area, unsigned int size)
{
	/* fill whole space with int3 instructions */
	memset(area, 0xcc, size);
}

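/* per-program JIT state carried across do_jit() passes */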
struct jit_context {
	int cleanup_addr; /* epilogue code offset */
	bool seen_ld_abs;
	bool seen_ax_reg;
};

/* maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE	128
#define BPF_INSN_SAFETY		64

#define AUX_STACK_SPACE \
	(32 /* space for rbx, r13, r14, r15 */ + \
	 8 /* space for skb_copy_bits() buffer */)

#define PROLOGUE_SIZE 37

/* emit x64 prologue code for BPF program and check its size.
 * bpf_tail_call helper will skip it while jumping into another program
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
	u8 *prog = *pprog;
	int cnt = 0;

	EMIT1(0x55); /* push rbp */
	EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */

	/* sub rsp, rounded_stack_depth + AUX_STACK_SPACE */
	EMIT3_off32(0x48, 0x81, 0xEC,
		    round_up(stack_depth, 8) + AUX_STACK_SPACE);

	/* sub rbp, AUX_STACK_SPACE */
	EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE);

	/* all classic BPF filters use R6(rbx), so save it */

	/* mov qword ptr [rbp+0],rbx */
	EMIT4(0x48, 0x89, 0x5D, 0);

	/* bpf_convert_filter() maps classic BPF register X to R7 and uses R8
	 * as temporary, so all tcpdump filters need to spill/fill R7(r13) and
	 * R8(r14). R9(r15) spill could be made conditional, but there is only
	 * one 'bpf_error' return path out of helper functions inside bpf_jit.S
	 * The overhead of extra spill is negligible for any filter other
	 * than synthetic ones. Therefore not worth adding complexity.
	 */

	/* mov qword ptr [rbp+8],r13 */
	EMIT4(0x4C, 0x89, 0x6D, 8);
	/* mov qword ptr [rbp+16],r14 */
	EMIT4(0x4C, 0x89, 0x75, 16);
	/* mov qword ptr [rbp+24],r15 */
	EMIT4(0x4C, 0x89, 0x7D, 24);

	if (!ebpf_from_cbpf) {
		/* Clear the tail call counter (tail_call_cnt): for eBPF tail
		 * calls we need to reset the counter to 0. It's done in two
		 * instructions, resetting rax register to 0, and moving it
		 * to the counter location.
		 */

		/* xor eax, eax */
		EMIT2(0x31, 0xc0);
		/* mov qword ptr [rbp+32], rax */
		EMIT4(0x48, 0x89, 0x45, 32);

		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
	}

	*pprog = prog;
}

/* generate the following code:
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
	u8 *prog = *pprog;
	int label1, label2, label3;
	int cnt = 0;

	/* rdi - pointer to ctx
	 * rsi - pointer to bpf_array
	 * rdx - index in bpf_array
	 */

	/* if (index >= array->map.max_entries)
	 *   goto out;
	 */
	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
	      offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
	label1 = cnt;

	/* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
	 *   goto out;
	 */
	EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JA, OFFSET2);                   /* ja out */
	label2 = cnt;
	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
	EMIT2_off32(0x89, 0x85, 36);              /* mov dword ptr [rbp + 36], eax */

	/* prog = array->ptrs[index]; */
	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
		    offsetof(struct bpf_array, ptrs));

	/* if (prog == NULL)
	 *   goto out;
	 */
	EMIT3(0x48, 0x85, 0xC0);		  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
	EMIT2(X86_JE, OFFSET3);                   /* je out */
	label3 = cnt;

	/* goto *(prog->bpf_func + prologue_size); */
	EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
	      offsetof(struct bpf_prog, bpf_func));
	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

	/* now we're ready to jump into next BPF program
	 * rdi == ctx (1st arg)
	 * rax == prog->bpf_func + prologue_size
	 */
	RETPOLINE_RAX_BPF_JIT();

	/* out: */
	BUILD_BUG_ON(cnt - label1 != OFFSET1);
	BUILD_BUG_ON(cnt - label2 != OFFSET2);
	BUILD_BUG_ON(cnt - label3 != OFFSET3);
	*pprog = prog;
}


static void emit_load_skb_data_hlen(u8 **pprog)
{
	u8 *prog = *pprog;
	int cnt = 0;

	/* r9d = skb->len - skb->data_len (headlen)
	 * r10 = skb->data
	 */
	/* mov %r9d, off32(%rdi) */
	EMIT3_off32(0x44, 0x8b, 0x8f, offsetof(struct sk_buff, len));

	/* sub %r9d, off32(%rdi) */
	EMIT3_off32(0x44, 0x2b, 0x8f, offsetof(struct sk_buff, data_len));

	/* mov %r10, off32(%rdi) */
	EMIT3_off32(0x4c, 0x8b, 0x97, offsetof(struct sk_buff, data));
	*pprog = prog;
}

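/* load a 32-bit immediate into dst_reg; negative values are sign-extended to 64 bit when sign_propagate is set */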
static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
			   u32 dst_reg, const u32 imm32)
{
	u8 *prog = *pprog;
	u8 b1, b2, b3;
	int cnt = 0;

	/* optimization: if imm32 is positive, use 'mov %eax, imm32'
	 * (which zero-extends imm32) to save 2 bytes.
	 */
	if (sign_propagate && (s32)imm32 < 0) {
		/* 'mov %rax, imm32' sign extends imm32 */
		b1 = add_1mod(0x48, dst_reg);
		b2 = 0xC7;
		b3 = 0xC0;
		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
		goto done;
	}

	/* optimization: if imm32 is zero, use 'xor %eax, %eax'
	 * to save 3 bytes.
	 */
	if (imm32 == 0) {
		if (is_ereg(dst_reg))
			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
		b2 = 0x31; /* xor */
		b3 = 0xC0;
		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
		goto done;
	}

	/* mov %eax, imm32 */
	if (is_ereg(dst_reg))
		EMIT1(add_1mod(0x40, dst_reg));
	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
	*pprog = prog;
}

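/* load a 64-bit immediate, using the shorter mov32 encoding when the value fits in 32 bits */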
static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
			   const u32 imm32_hi, const u32 imm32_lo)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
		/* For emitting plain u32, where sign bit must not be
		 * propagated LLVM tends to load imm64 over mov32
		 * directly, so save couple of bytes by just doing
		 * 'mov %eax, imm32' instead.
		 */
		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
	} else {
		/* movabsq %rax, imm64 */
		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
		EMIT(imm32_lo, 4);
		EMIT(imm32_hi, 4);
	}

	*pprog = prog;
}

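/* register-to-register move, 32-bit or 64-bit depending on is64 */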
static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
	u8 *prog = *pprog;
	int cnt = 0;

	if (is64) {
		/* mov dst, src */
		EMIT_mov(dst_reg, src_reg);
	} else {
		/* mov32 dst, src */
		if (is_ereg(dst_reg) || is_ereg(src_reg))
			EMIT1(add_2mod(0x40, dst_reg, src_reg));
		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
	}

	*pprog = prog;
}

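/* do one full code generation pass over the program; returns the total image length or a negative errno */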
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
	struct bpf_insn *insn = bpf_prog->insnsi;
	int insn_cnt = bpf_prog->len;
	bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
	bool seen_ax_reg = ctx->seen_ax_reg | (oldproglen == 0);
	bool seen_exit = false;
	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
	int i, cnt = 0;
	int proglen = 0;
	u8 *prog = temp;

	emit_prologue(&prog, bpf_prog->aux->stack_depth,
		      bpf_prog_was_classic(bpf_prog));

	if (seen_ld_abs)
		emit_load_skb_data_hlen(&prog);

	for (i = 0; i < insn_cnt; i++, insn++) {
		const s32 imm32 = insn->imm;
		u32 dst_reg = insn->dst_reg;
		u32 src_reg = insn->src_reg;
		u8 b2 = 0, b3 = 0;
		s64 jmp_offset;
		u8 jmp_cond;
		bool reload_skb_data;
		int ilen;
		u8 *func;

		if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
			ctx->seen_ax_reg = seen_ax_reg = true;

		switch (insn->code) {
			/* ALU */
		case BPF_ALU | BPF_ADD | BPF_X:
		case BPF_ALU | BPF_SUB | BPF_X:
		case BPF_ALU | BPF_AND | BPF_X:
		case BPF_ALU | BPF_OR | BPF_X:
		case BPF_ALU | BPF_XOR | BPF_X:
		case BPF_ALU64 | BPF_ADD | BPF_X:
		case BPF_ALU64 | BPF_SUB | BPF_X:
		case BPF_ALU64 | BPF_AND | BPF_X:
		case BPF_ALU64 | BPF_OR | BPF_X:
		case BPF_ALU64 | BPF_XOR | BPF_X:
			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b2 = 0x01; break;
			case BPF_SUB: b2 = 0x29; break;
			case BPF_AND: b2 = 0x21; break;
			case BPF_OR: b2 = 0x09; break;
			case BPF_XOR: b2 = 0x31; break;
			}
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
			break;

		case BPF_ALU64 | BPF_MOV | BPF_X:
		case BPF_ALU | BPF_MOV | BPF_X:
			emit_mov_reg(&prog,
				     BPF_CLASS(insn->code) == BPF_ALU64,
				     dst_reg, src_reg);
			break;

			/* neg dst */
		case BPF_ALU | BPF_NEG:
		case BPF_ALU64 | BPF_NEG:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
			break;

		case BPF_ALU | BPF_ADD | BPF_K:
		case BPF_ALU | BPF_SUB | BPF_K:
		case BPF_ALU | BPF_AND | BPF_K:
		case BPF_ALU | BPF_OR | BPF_K:
		case BPF_ALU | BPF_XOR | BPF_K:
		case BPF_ALU64 | BPF_ADD | BPF_K:
		case BPF_ALU64 | BPF_SUB | BPF_K:
		case BPF_ALU64 | BPF_AND | BPF_K:
		case BPF_ALU64 | BPF_OR | BPF_K:
		case BPF_ALU64 | BPF_XOR | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			/* b3 holds 'normal' opcode, b2 short form only valid
			 * in case dst is eax/rax.
			 */
			switch (BPF_OP(insn->code)) {
			case BPF_ADD:
				b3 = 0xC0;
				b2 = 0x05;
				break;
			case BPF_SUB:
				b3 = 0xE8;
				b2 = 0x2D;
				break;
			case BPF_AND:
				b3 = 0xE0;
				b2 = 0x25;
				break;
			case BPF_OR:
				b3 = 0xC8;
				b2 = 0x0D;
				break;
			case BPF_XOR:
				b3 = 0xF0;
				b2 = 0x35;
				break;
			}

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
			else if (is_axreg(dst_reg))
				EMIT1_off32(b2, imm32);
			else
				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU64 | BPF_MOV | BPF_K:
		case BPF_ALU | BPF_MOV | BPF_K:
			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
				       dst_reg, imm32);
			break;

		case BPF_LD | BPF_IMM | BPF_DW:
			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
			insn++;
			i++;
			break;

			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
		case BPF_ALU | BPF_MOD | BPF_X:
		case BPF_ALU | BPF_DIV | BPF_X:
		case BPF_ALU | BPF_MOD | BPF_K:
		case BPF_ALU | BPF_DIV | BPF_K:
		case BPF_ALU64 | BPF_MOD | BPF_X:
		case BPF_ALU64 | BPF_DIV | BPF_X:
		case BPF_ALU64 | BPF_MOD | BPF_K:
		case BPF_ALU64 | BPF_DIV | BPF_K:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov r11, src_reg */
				EMIT_mov(AUX_REG, src_reg);
			else
				/* mov r11, imm32 */
				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

			/* mov rax, dst_reg */
			EMIT_mov(BPF_REG_0, dst_reg);

			/* xor edx, edx
			 * equivalent to 'xor rdx, rdx', but one byte less
			 */
			EMIT2(0x31, 0xd2);

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				/* div r11 */
				EMIT3(0x49, 0xF7, 0xF3);
			else
				/* div r11d */
				EMIT3(0x41, 0xF7, 0xF3);

			if (BPF_OP(insn->code) == BPF_MOD)
				/* mov r11, rdx */
				EMIT3(0x49, 0x89, 0xD3);
			else
				/* mov r11, rax */
				EMIT3(0x49, 0x89, 0xC3);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_MUL | BPF_K:
		case BPF_ALU | BPF_MUL | BPF_X:
		case BPF_ALU64 | BPF_MUL | BPF_K:
		case BPF_ALU64 | BPF_MUL | BPF_X:
		{
			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

			if (dst_reg != BPF_REG_0)
				EMIT1(0x50); /* push rax */
			if (dst_reg != BPF_REG_3)
				EMIT1(0x52); /* push rdx */

			/* mov r11, dst_reg */
			EMIT_mov(AUX_REG, dst_reg);

			if (BPF_SRC(insn->code) == BPF_X)
				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
			else
				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

			if (is64)
				EMIT1(add_1mod(0x48, AUX_REG));
			else if (is_ereg(AUX_REG))
				EMIT1(add_1mod(0x40, AUX_REG));
			/* mul(q) r11 */
			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

			if (dst_reg != BPF_REG_3)
				EMIT1(0x5A); /* pop rdx */
			if (dst_reg != BPF_REG_0) {
				/* mov dst_reg, rax */
				EMIT_mov(dst_reg, BPF_REG_0);
				EMIT1(0x58); /* pop rax */
			}
			break;
		}
			/* shifts */
		case BPF_ALU | BPF_LSH | BPF_K:
		case BPF_ALU | BPF_RSH | BPF_K:
		case BPF_ALU | BPF_ARSH | BPF_K:
		case BPF_ALU64 | BPF_LSH | BPF_K:
		case BPF_ALU64 | BPF_RSH | BPF_K:
		case BPF_ALU64 | BPF_ARSH | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}

			if (imm32 == 1)
				EMIT2(0xD1, add_1reg(b3, dst_reg));
			else
				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU | BPF_LSH | BPF_X:
		case BPF_ALU | BPF_RSH | BPF_X:
		case BPF_ALU | BPF_ARSH | BPF_X:
		case BPF_ALU64 | BPF_LSH | BPF_X:
		case BPF_ALU64 | BPF_RSH | BPF_X:
		case BPF_ALU64 | BPF_ARSH | BPF_X:

			/* check for bad case when dst_reg == rcx */
			if (dst_reg == BPF_REG_4) {
				/* mov r11, dst_reg */
				EMIT_mov(AUX_REG, dst_reg);
				dst_reg = AUX_REG;
			}

			if (src_reg != BPF_REG_4) { /* common case */
				EMIT1(0x51); /* push rcx */

				/* mov rcx, src_reg */
				EMIT_mov(BPF_REG_4, src_reg);
			}

			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT2(0xD3, add_1reg(b3, dst_reg));

			if (src_reg != BPF_REG_4)
				EMIT1(0x59); /* pop rcx */

			if (insn->dst_reg == BPF_REG_4)
				/* mov dst_reg, r11 */
				EMIT_mov(insn->dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_END | BPF_FROM_BE:
			switch (imm32) {
			case 16:
				/* emit 'ror %ax, 8' to swap lower 2 bytes */
				EMIT1(0x66);
				if (is_ereg(dst_reg))
					EMIT1(0x41);
				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

				/* emit 'movzwl eax, ax' */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* emit 'bswap eax' to swap lower 4 bytes */
				if (is_ereg(dst_reg))
					EMIT2(0x41, 0x0F);
				else
					EMIT1(0x0F);
				EMIT1(add_1reg(0xC8, dst_reg));
				break;
			case 64:
				/* emit 'bswap rax' to swap 8 bytes */
				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
				      add_1reg(0xC8, dst_reg));
				break;
			}
			break;

		case BPF_ALU | BPF_END | BPF_FROM_LE:
			switch (imm32) {
			case 16:
				/* emit 'movzwl eax, ax' to zero extend 16-bit
				 * into 64 bit
				 */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* emit 'mov eax, eax' to clear upper 32-bits */
				if (is_ereg(dst_reg))
					EMIT1(0x45);
				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 64:
				/* nop */
				break;
			}
			break;

			/* ST: *(u8*)(dst_reg + off) = imm */
		case BPF_ST | BPF_MEM | BPF_B:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC6);
			else
				EMIT1(0xC6);
			goto st;
		case BPF_ST | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg))
				EMIT3(0x66, 0x41, 0xC7);
			else
				EMIT2(0x66, 0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC7);
			else
				EMIT1(0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_DW:
			EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:			if (is_imm8(insn->off))
				EMIT2(add_1reg(0x40, dst_reg), insn->off);
			else
				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
			break;

			/* STX: *(u8*)(dst_reg + off) = src_reg */
		case BPF_STX | BPF_MEM | BPF_B:
			/* emit 'mov byte ptr [rax + off], al' */
			if (is_ereg(dst_reg) || is_ereg(src_reg) ||
			    /* have to add extra byte for x86 SIL, DIL regs */
			    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
			else
				EMIT1(0x88);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT2(0x66, 0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT1(0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_DW:
			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* LDX: dst_reg = *(u8*)(src_reg + off) */
		case BPF_LDX | BPF_MEM | BPF_B:
			/* emit 'movzx rax, byte ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_H:
			/* emit 'movzx rax, word ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_W:
			/* emit 'mov eax, dword ptr [rax+0x14]' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
			else
				EMIT1(0x8B);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_DW:
			/* emit 'mov rax, qword ptr [rax+0x14]' */
			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:			/* if insn->off == 0 we can save one extra byte, but
			 * special case of x86 r13 which always needs an offset
			 * is not worth the hassle
			 */
			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
					    insn->off);
			break;

			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
		case BPF_STX | BPF_XADD | BPF_W:
			/* emit 'lock add dword ptr [rax + off], eax' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
			else
				EMIT2(0xF0, 0x01);
			goto xadd;
		case BPF_STX | BPF_XADD | BPF_DW:
			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* call */
		case BPF_JMP | BPF_CALL:
			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);
			if (seen_ld_abs) {
				reload_skb_data = bpf_helper_changes_pkt_data(func);
				if (reload_skb_data) {
					EMIT1(0x57); /* push %rdi */
					jmp_offset += 22; /* pop, mov, sub, mov */
				} else {
					EMIT2(0x41, 0x52); /* push %r10 */
					EMIT2(0x41, 0x51); /* push %r9 */
					/* need to adjust jmp offset, since
					 * pop %r9, pop %r10 take 4 bytes after call insn
					 */
					jmp_offset += 4;
				}
			}
			if (!imm32 || !is_simm32(jmp_offset)) {
				pr_err("unsupported bpf func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			EMIT1_off32(0xE8, jmp_offset);
			if (seen_ld_abs) {
				if (reload_skb_data) {
					EMIT1(0x5F); /* pop %rdi */
					emit_load_skb_data_hlen(&prog);
				} else {
					EMIT2(0x41, 0x59); /* pop %r9 */
					EMIT2(0x41, 0x5A); /* pop %r10 */
				}
			}
			break;

		case BPF_JMP | BPF_TAIL_CALL:
			emit_bpf_tail_call(&prog);
			break;

			/* cond jump */
		case BPF_JMP | BPF_JEQ | BPF_X:
		case BPF_JMP | BPF_JNE | BPF_X:
		case BPF_JMP | BPF_JGT | BPF_X:
		case BPF_JMP | BPF_JLT | BPF_X:
		case BPF_JMP | BPF_JGE | BPF_X:
		case BPF_JMP | BPF_JLE | BPF_X:
		case BPF_JMP | BPF_JSGT | BPF_X:
		case BPF_JMP | BPF_JSLT | BPF_X:
		case BPF_JMP | BPF_JSGE | BPF_X:
		case BPF_JMP | BPF_JSLE | BPF_X:
			/* cmp dst_reg, src_reg */
			EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x39,
			      add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_X:
			/* test dst_reg, src_reg */
			EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x85,
			      add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_K:
			/* test dst_reg, imm32 */
			EMIT1(add_1mod(0x48, dst_reg));
			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JEQ | BPF_K:
		case BPF_JMP | BPF_JNE | BPF_K:
		case BPF_JMP | BPF_JGT | BPF_K:
		case BPF_JMP | BPF_JLT | BPF_K:
		case BPF_JMP | BPF_JGE | BPF_K:
		case BPF_JMP | BPF_JLE | BPF_K:
		case BPF_JMP | BPF_JSGT | BPF_K:
		case BPF_JMP | BPF_JSLT | BPF_K:
		case BPF_JMP | BPF_JSGE | BPF_K:
		case BPF_JMP | BPF_JSLE | BPF_K:
			/* cmp dst_reg, imm8/32 */
			EMIT1(add_1mod(0x48, dst_reg));

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:		/* convert BPF opcode to x86 */
			switch (BPF_OP(insn->code)) {
			case BPF_JEQ:
				jmp_cond = X86_JE;
				break;
			case BPF_JSET:
			case BPF_JNE:
				jmp_cond = X86_JNE;
				break;
			case BPF_JGT:
				/* GT is unsigned '>', JA in x86 */
				jmp_cond = X86_JA;
				break;
			case BPF_JLT:
				/* LT is unsigned '<', JB in x86 */
				jmp_cond = X86_JB;
				break;
			case BPF_JGE:
				/* GE is unsigned '>=', JAE in x86 */
				jmp_cond = X86_JAE;
				break;
			case BPF_JLE:
				/* LE is unsigned '<=', JBE in x86 */
				jmp_cond = X86_JBE;
				break;
			case BPF_JSGT:
				/* signed '>', GT in x86 */
				jmp_cond = X86_JG;
				break;
			case BPF_JSLT:
				/* signed '<', LT in x86 */
				jmp_cond = X86_JL;
				break;
			case BPF_JSGE:
				/* signed '>=', GE in x86 */
				jmp_cond = X86_JGE;
				break;
			case BPF_JSLE:
				/* signed '<=', LE in x86 */
				jmp_cond = X86_JLE;
				break;
			default: /* to silence gcc warning */
				return -EFAULT;
			}
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (is_imm8(jmp_offset)) {
				EMIT2(jmp_cond, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
			} else {
				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}

			break;

		case BPF_JMP | BPF_JA:
			if (insn->off == -1)
				/* -1 jmp instructions will always jump
				 * backwards two bytes. Explicitly handling
				 * this case avoids wasting too many passes
				 * when there are long sequences of replaced
				 * dead code.
				 */
				jmp_offset = -2;
			else
				jmp_offset = addrs[i + insn->off] - addrs[i];

			if (!jmp_offset)
				/* optimize out nop jumps */
				break;
emit_jmp:
			if (is_imm8(jmp_offset)) {
				EMIT2(0xEB, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT1_off32(0xE9, jmp_offset);
			} else {
				pr_err("jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}
			break;

		case BPF_LD | BPF_IND | BPF_W:
			func = sk_load_word;
			goto common_load;
		case BPF_LD | BPF_ABS | BPF_W:
			func = CHOOSE_LOAD_FUNC(imm32, sk_load_word);
common_load:
			ctx->seen_ld_abs = seen_ld_abs = true;
			jmp_offset = func - (image + addrs[i]);
			if (!func || !is_simm32(jmp_offset)) {
				pr_err("unsupported bpf func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			if (BPF_MODE(insn->code) == BPF_ABS) {
				/* mov %esi, imm32 */
				EMIT1_off32(0xBE, imm32);
			} else {
				/* mov %rsi, src_reg */
				EMIT_mov(BPF_REG_2, src_reg);
				if (imm32) {
					if (is_imm8(imm32))
						/* add %esi, imm8 */
						EMIT3(0x83, 0xC6, imm32);
					else
						/* add %esi, imm32 */
						EMIT2_off32(0x81, 0xC6, imm32);
				}
			}
			/* skb pointer is in R6 (%rbx), it will be copied into
			 * %rdi if skb_copy_bits() call is necessary.
			 * sk_load_* helpers also use %r10 and %r9d.
			 * See bpf_jit.S
			 */
			if (seen_ax_reg)
				/* r10 = skb->data, mov %r10, off32(%rbx) */
				EMIT3_off32(0x4c, 0x8b, 0x93,
					    offsetof(struct sk_buff, data));
			EMIT1_off32(0xE8, jmp_offset); /* call */
			break;

		case BPF_LD | BPF_IND | BPF_H:
			func = sk_load_half;
			goto common_load;
		case BPF_LD | BPF_ABS | BPF_H:
			func = CHOOSE_LOAD_FUNC(imm32, sk_load_half);
			goto common_load;
		case BPF_LD | BPF_IND | BPF_B:
			func = sk_load_byte;
			goto common_load;
		case BPF_LD | BPF_ABS | BPF_B:
			func = CHOOSE_LOAD_FUNC(imm32, sk_load_byte);
			goto common_load;

		case BPF_JMP | BPF_EXIT:
			if (seen_exit) {
				jmp_offset = ctx->cleanup_addr - addrs[i];
				goto emit_jmp;
			}
			seen_exit = true;
			/* update cleanup_addr */
			ctx->cleanup_addr = proglen;
			/* mov rbx, qword ptr [rbp+0] */
			EMIT4(0x48, 0x8B, 0x5D, 0);
			/* mov r13, qword ptr [rbp+8] */
			EMIT4(0x4C, 0x8B, 0x6D, 8);
			/* mov r14, qword ptr [rbp+16] */
			EMIT4(0x4C, 0x8B, 0x75, 16);
			/* mov r15, qword ptr [rbp+24] */
			EMIT4(0x4C, 0x8B, 0x7D, 24);

			/* add rbp, AUX_STACK_SPACE */
			EMIT4(0x48, 0x83, 0xC5, AUX_STACK_SPACE);
			EMIT1(0xC9); /* leave */
			EMIT1(0xC3); /* ret */
			break;

		default:
			/* By design x64 JIT should support all BPF instructions
			 * This error will be seen if new instruction was added
			 * to interpreter, but not to JIT
			 * or if there is junk in bpf_prog
			 */
			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
			return -EINVAL;
		}

		ilen = prog - temp;
		if (ilen > BPF_MAX_INSN_SIZE) {
			pr_err("bpf_jit: fatal insn size error\n");
			return -EFAULT;
		}

		if (image) {
			if (unlikely(proglen + ilen > oldproglen)) {
				pr_err("bpf_jit: fatal error\n");
				return -EFAULT;
			}
			memcpy(image + proglen, temp, ilen);
		}
		proglen += ilen;
		addrs[i] = proglen;
		prog = temp;
	}
	return proglen;
}

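/* JIT state stashed in prog->aux->jit_data between the initial JIT and the extra pass for subprograms */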
struct x64_jit_data {
	struct bpf_binary_header *header;
	int *addrs;
	u8 *image;
	int proglen;
	struct jit_context ctx;
};

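/* main JIT entry point: blind constants if requested, then rerun do_jit() until the image size converges */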
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct x64_jit_data *jit_data;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	bool extra_pass = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/* If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	addrs = jit_data->addrs;
	if (addrs) {
		ctx = jit_data->ctx;
		oldproglen = jit_data->proglen;
		image = jit_data->image;
		header = jit_data->header;
		extra_pass = true;
		goto skip_init_addrs;
	}
	addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out_addrs;
	}

	/* Before first pass, make a rough estimation of addrs[]
	 * each bpf instruction is translated to less than 64 bytes
	 */
	for (proglen = 0, i = 0; i < prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;
skip_init_addrs:

	/* JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large bpf programs
	 * may converge on the last pass. In such case do one more
	 * pass to emit the final image
	 */
	for (pass = 0; pass < 20 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			header = bpf_jit_binary_alloc(proglen, &image,
						      1, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
		}
		oldproglen = proglen;
		cond_resched();
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		if (!prog->is_func || extra_pass) {
			bpf_jit_binary_lock_ro(header);
		} else {
			jit_data->addrs = addrs;
			jit_data->ctx = ctx;
			jit_data->proglen = proglen;
			jit_data->image = image;
			jit_data->header = header;
		}
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		prog = orig_prog;
	}

	if (!image || !prog->is_func || extra_pass) {
out_addrs:
		kfree(addrs);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}