[patch 37/38] x86/bpf: Emit call depth accounting if required

From: Thomas Gleixner
Date: Sat Jul 16 2022 - 19:20:14 EST


Ensure that calls in BPF JITed programs emit call depth accounting when call
depth tracking is enabled, so that calls and returns stay balanced. The return
thunk jump is already injected due to the earlier retbleed mitigations.
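
For illustration, a minimal userspace sketch (not kernel code) of the
bookkeeping the JIT has to get right: the rel32 displacement of a near CALL is
computed relative to the end of the CALL instruction, so when the accounting
template is emitted in front of it, the call site handed to the displacement
calculation must be advanced by the template size. The helper names and the
17 byte template size are made up for the example; only the arithmetic mirrors
the patch.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define TMPL_SIZE	17	/* assumed size of the accounting template */
#define CALL_INSN_SIZE	5	/* 0xE8 opcode plus rel32 */

static uint8_t image[64];

/* Stand-in for x86_call_depth_emit_accounting(): emit template, return its size */
static int emit_accounting(uint8_t **pprog)
{
	memset(*pprog, 0x90, TMPL_SIZE);	/* NOPs instead of the real template */
	*pprog += TMPL_SIZE;
	return TMPL_SIZE;
}

/* Stand-in for emit_call()/emit_patch(): rel32 is relative to the end of the CALL */
static void emit_call_rel32(uint8_t **pprog, uint8_t *func, uint8_t *ip)
{
	int32_t rel = (int32_t)(func - (ip + CALL_INSN_SIZE));

	(*pprog)[0] = 0xE8;
	memcpy(*pprog + 1, &rel, sizeof(rel));
	*pprog += CALL_INSN_SIZE;
}

int main(void)
{
	uint8_t *prog = image;
	uint8_t *func = image + 0x30;	/* arbitrary in-buffer "target" */
	int32_t rel;
	int offs;

	/* Same pattern as the BPF_JMP | BPF_CALL path: shift the call site */
	offs = emit_accounting(&prog);
	emit_call_rel32(&prog, func, image + offs);

	memcpy(&rel, image + offs + 1, sizeof(rel));
	printf("call at offset %d resolves to offset %ld (target is at %ld)\n",
	       offs, (long)(offs + CALL_INSN_SIZE + rel), (long)(func - image));
	return 0;
}

With the assumed 17 byte template this prints a call at offset 17 whose
displacement resolves to offset 48, i.e. the intended target; dropping the
"+ offs" adjustment makes it resolve to offset 65 instead.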

Signed-off-by: Thomas Gleixner <tglx@xxxxxxxxxxxxx>
Cc: Alexei Starovoitov <ast@xxxxxxxxxx>
Cc: Daniel Borkmann <daniel@xxxxxxxxxxxxx>
---
arch/x86/include/asm/alternative.h | 6 +++++
arch/x86/kernel/callthunks.c | 19 ++++++++++++++++
arch/x86/net/bpf_jit_comp.c | 43 ++++++++++++++++++++++++-------------
3 files changed, 53 insertions(+), 15 deletions(-)

--- a/arch/x86/include/asm/alternative.h
+++ b/arch/x86/include/asm/alternative.h
@@ -95,6 +95,7 @@ extern void callthunks_patch_module_call
 extern void callthunks_module_free(struct module *mod);
 extern void *callthunks_translate_call_dest(void *dest);
 extern bool is_callthunk(void *addr);
+extern int x86_call_depth_emit_accounting(u8 **pprog, void *func);
 #else
 static __always_inline void callthunks_patch_builtin_calls(void) {}
 static __always_inline void
@@ -109,6 +110,11 @@ static __always_inline bool is_callthunk
 {
 	return false;
 }
+static __always_inline int x86_call_depth_emit_accounting(u8 **pprog,
+							   void *func)
+{
+	return 0;
+}
 #endif
 
 #ifdef CONFIG_SMP
--- a/arch/x86/kernel/callthunks.c
+++ b/arch/x86/kernel/callthunks.c
@@ -706,6 +706,25 @@ int callthunk_get_kallsym(unsigned int s
 	return ret;
 }
 
+#ifdef CONFIG_BPF_JIT
+int x86_call_depth_emit_accounting(u8 **pprog, void *func)
+{
+	unsigned int tmpl_size = callthunk_desc.template_size;
+	void *tmpl = callthunk_desc.template;
+
+	if (!thunks_initialized)
+		return 0;
+
+	/* Is function call target a thunk? */
+	if (is_callthunk(func))
+		return 0;
+
+	memcpy(*pprog, tmpl, tmpl_size);
+	*pprog += tmpl_size;
+	return tmpl_size;
+}
+#endif
+
 #ifdef CONFIG_MODULES
 void noinline callthunks_patch_module_calls(struct callthunk_sites *cs,
 					    struct module *mod)
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -340,6 +340,12 @@ static int emit_call(u8 **pprog, void *f
 	return emit_patch(pprog, func, ip, 0xE8);
 }
 
+static int emit_rsb_call(u8 **pprog, void *func, void *ip)
+{
+	ip += x86_call_depth_emit_accounting(pprog, func);
+	return emit_patch(pprog, func, ip, 0xE8);
+}
+
 static int emit_jump(u8 **pprog, void *func, void *ip)
 {
 	return emit_patch(pprog, func, ip, 0xE9);
@@ -1431,19 +1437,26 @@ st: if (is_imm8(insn->off))
 			break;
 
 			/* call */
-		case BPF_JMP | BPF_CALL:
+		case BPF_JMP | BPF_CALL: {
+			int offs;
+
 			func = (u8 *) __bpf_call_base + imm32;
 			if (tail_call_reachable) {
 				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
 				EMIT3_off32(0x48, 0x8B, 0x85,
 					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
-				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
+				if (!imm32)
 					return -EINVAL;
+				offs = 7 + x86_call_depth_emit_accounting(&prog, func);
 			} else {
-				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
+				if (!imm32)
 					return -EINVAL;
+				offs = x86_call_depth_emit_accounting(&prog, func);
 			}
+			if (emit_call(&prog, func, image + addrs[i - 1] + offs))
+				return -EINVAL;
 			break;
+		}
 
 		case BPF_JMP | BPF_TAIL_CALL:
 			if (imm32)
@@ -1808,10 +1821,10 @@ static int invoke_bpf_prog(const struct
 	/* arg2: lea rsi, [rbp - ctx_cookie_off] */
 	EMIT4(0x48, 0x8D, 0x75, -run_ctx_off);
 
-	if (emit_call(&prog,
-		      p->aux->sleepable ? __bpf_prog_enter_sleepable :
-		      __bpf_prog_enter, prog))
-			return -EINVAL;
+	if (emit_rsb_call(&prog,
+			  p->aux->sleepable ? __bpf_prog_enter_sleepable :
+			  __bpf_prog_enter, prog))
+		return -EINVAL;
 	/* remember prog start time returned by __bpf_prog_enter */
 	emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
 
@@ -1831,7 +1844,7 @@ static int invoke_bpf_prog(const struct
 			       (long) p->insnsi >> 32,
 			       (u32) (long) p->insnsi);
 	/* call JITed bpf program or interpreter */
-	if (emit_call(&prog, p->bpf_func, prog))
+	if (emit_rsb_call(&prog, p->bpf_func, prog))
 		return -EINVAL;
 
 	/*
@@ -1855,10 +1868,10 @@ static int invoke_bpf_prog(const struct
 	emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
 	/* arg3: lea rdx, [rbp - run_ctx_off] */
 	EMIT4(0x48, 0x8D, 0x55, -run_ctx_off);
-	if (emit_call(&prog,
-		      p->aux->sleepable ? __bpf_prog_exit_sleepable :
-		      __bpf_prog_exit, prog))
-			return -EINVAL;
+	if (emit_rsb_call(&prog,
+			  p->aux->sleepable ? __bpf_prog_exit_sleepable :
+			  __bpf_prog_exit, prog))
+		return -EINVAL;
 
 	*pprog = prog;
 	return 0;
@@ -2123,7 +2136,7 @@ int arch_prepare_bpf_trampoline(struct b
 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
 		/* arg1: mov rdi, im */
 		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
-		if (emit_call(&prog, __bpf_tramp_enter, prog)) {
+		if (emit_rsb_call(&prog, __bpf_tramp_enter, prog)) {
 			ret = -EINVAL;
 			goto cleanup;
 		}
@@ -2151,7 +2164,7 @@ int arch_prepare_bpf_trampoline(struct b
 		restore_regs(m, &prog, nr_args, regs_off);
 
 		/* call original function */
-		if (emit_call(&prog, orig_call, prog)) {
+		if (emit_rsb_call(&prog, orig_call, prog)) {
 			ret = -EINVAL;
 			goto cleanup;
 		}
@@ -2194,7 +2207,7 @@ int arch_prepare_bpf_trampoline(struct b
 		im->ip_epilogue = prog;
 		/* arg1: mov rdi, im */
 		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
-		if (emit_call(&prog, __bpf_tramp_exit, prog)) {
+		if (emit_rsb_call(&prog, __bpf_tramp_exit, prog)) {
 			ret = -EINVAL;
 			goto cleanup;
 		}