Re: [PATCH bpf-next v4 2/4] riscv, bpf: add RV32G eBPF JIT

From: Björn Töpel
Date: Tue Mar 03 2020 - 02:48:19 EST


On Tue, 3 Mar 2020 at 01:50, Luke Nelson <lukenels@xxxxxxxxxxxxxxxxx> wrote:
>
> This is an eBPF JIT for RV32G, adapted from the JIT for RV64G and
> the 32-bit ARM JIT.
>
> There are two main changes required for this to work compared to
> the RV64 JIT.
>
> First, eBPF registers are 64-bit, while RV32G registers are 32-bit.
> BPF registers either map directly to 2 RISC-V registers, or reside
> in stack scratch space and are saved and restored when used.
>
> Second, many 64-bit ALU operations do not trivially map to 32-bit
> operations. Operations that move bits between high and low words,
> such as ADD, LSH, MUL, and others must emulate the 64-bit behavior
> in terms of 32-bit instructions.
>
> Supported features:
>
> The RV32 JIT supports the same features and instructions as the
> RV64 JIT, with the following exceptions:
>
> - ALU64 DIV/MOD: Requires loops to implement on 32-bit hardware.
>
> - BPF_XADD | BPF_DW: There's no 8-byte atomic instruction in RV32.
>
> These features are also unsupported on other BPF JITs for 32-bit
> architectures.
>
> Testing:
>
> - lib/test_bpf.c
> test_bpf: Summary: 378 PASSED, 0 FAILED, [349/366 JIT'ed]

Which are the tests that fail to JIT, and is that due to div/mod +
xadd?

> test_bpf: test_skb_segment: Summary: 2 PASSED, 0 FAILED
>
> - tools/testing/selftests/bpf/test_verifier.c
> Summary: 1415 PASSED, 122 SKIPPED, 43 FAILED
>
> Tested both with and without BPF JIT hardening.
>
> This is the same set of tests that pass using the BPF interpreter
> with the JIT disabled.
>
> Verification and synthesis:
>
> We developed the RV32 JIT using our automated verification tool,
> Serval. We have used Serval in the past to verify patches to the
> RV64 JIT. We also used Serval to superoptimize the resulting code
> through program synthesis.
>
> You can find the tool and a guide to the approach and results here:
> https://github.com/uw-unsat/serval-bpf/tree/rv32-jit-v4
>

Very interesting work!

> Co-developed-by: Xi Wang <xi.wang@xxxxxxxxx>
> Signed-off-by: Xi Wang <xi.wang@xxxxxxxxx>
> Signed-off-by: Luke Nelson <luke.r.nels@xxxxxxxxx>
> ---
> arch/riscv/Kconfig | 2 +-
> arch/riscv/net/Makefile | 7 +-
> arch/riscv/net/bpf_jit_comp32.c | 1466 +++++++++++++++++++++++++++++++
> 3 files changed, 1473 insertions(+), 2 deletions(-)
> create mode 100644 arch/riscv/net/bpf_jit_comp32.c
[...]
> +
> +static const s8 *rv32_bpf_get_reg64(const s8 *reg, const s8 *tmp,
> + struct rv_jit_context *ctx)

Really a nit, but you're using rv32 as a prefix, and also as part of
many of the function names (e.g. emit_rv32). Everything in this file is
just for RV32, so maybe remove that implicit information from the
function names? Just a thought! :-)

> +{
> + if (is_stacked(hi(reg))) {
> + emit(rv_lw(hi(tmp), hi(reg), RV_REG_FP), ctx);
> + emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
> + reg = tmp;
> + }
> + return reg;
> +}
> +
> +static void rv32_bpf_put_reg64(const s8 *reg, const s8 *src,
> + struct rv_jit_context *ctx)
> +{
> + if (is_stacked(hi(reg))) {
> + emit(rv_sw(RV_REG_FP, hi(reg), hi(src)), ctx);
> + emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
> + }
> +}
> +
> +static const s8 *rv32_bpf_get_reg32(const s8 *reg, const s8 *tmp,
> + struct rv_jit_context *ctx)
> +{
> + if (is_stacked(lo(reg))) {
> + emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
> + reg = tmp;
> + }
> + return reg;
> +}
> +
> +static void rv32_bpf_put_reg32(const s8 *reg, const s8 *src,
> + struct rv_jit_context *ctx)
> +{
> + if (is_stacked(lo(reg))) {
> + emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
> + if (!ctx->prog->aux->verifier_zext)
> + emit(rv_sw(RV_REG_FP, hi(reg), RV_REG_ZERO), ctx);
> + } else if (!ctx->prog->aux->verifier_zext) {
> + emit(rv_addi(hi(reg), RV_REG_ZERO, 0), ctx);
> + }
> +}
> +
> +static void emit_jump_and_link(u8 rd, s32 rvoff, bool force_jalr,
> + struct rv_jit_context *ctx)
> +{
> + s32 upper, lower;
> +
> + if (rvoff && is_21b_int(rvoff) && !force_jalr) {
> + emit(rv_jal(rd, rvoff >> 1), ctx);
> + return;
> + }
> +
> + upper = (rvoff + (1 << 11)) >> 12;
> + lower = rvoff & 0xfff;
> + emit(rv_auipc(RV_REG_T1, upper), ctx);
> + emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
> +}
> +
> +static void emit_rv32_alu_i64(const s8 *dst, s32 imm,
> + struct rv_jit_context *ctx, const u8 op)
> +{
> + const s8 *tmp1 = bpf2rv32[TMP_REG_1];
> + const s8 *rd = rv32_bpf_get_reg64(dst, tmp1, ctx);
> +
> + switch (op) {
> + case BPF_MOV:
> + emit_imm32(rd, imm, ctx);
> + break;
> + case BPF_AND:
> + if (is_12b_int(imm)) {
> + emit(rv_andi(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + if (imm >= 0)
> + emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
> + break;
> + case BPF_OR:
> + if (is_12b_int(imm)) {
> + emit(rv_ori(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + if (imm < 0)
> + emit(rv_ori(hi(rd), RV_REG_ZERO, -1), ctx);
> + break;
> + case BPF_XOR:
> + if (is_12b_int(imm)) {
> + emit(rv_xori(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + if (imm < 0)
> + emit(rv_xori(hi(rd), hi(rd), -1), ctx);
> + break;
> + case BPF_LSH:
> + if (imm >= 32) {
> + emit(rv_slli(hi(rd), lo(rd), imm - 32), ctx);
> + emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
> + } else if (imm == 0) {
> + /* nop */

Can we get rid of this, and just do if/else if?

> + } else {
> + emit(rv_srli(RV_REG_T0, lo(rd), 32 - imm), ctx);
> + emit(rv_slli(hi(rd), hi(rd), imm), ctx);
> + emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
> + emit(rv_slli(lo(rd), lo(rd), imm), ctx);
> + }
> + break;
> + case BPF_RSH:
> + if (imm >= 32) {
> + emit(rv_srli(lo(rd), hi(rd), imm - 32), ctx);
> + emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
> + } else if (imm == 0) {
> + /* nop */

Ditto.

> + } else {
> + emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
> + emit(rv_srli(lo(rd), lo(rd), imm), ctx);
> + emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
> + emit(rv_srli(hi(rd), hi(rd), imm), ctx);
> + }
> + break;
> + case BPF_ARSH:
> + if (imm >= 32) {
> + emit(rv_srai(lo(rd), hi(rd), imm - 32), ctx);
> + emit(rv_srai(hi(rd), hi(rd), 31), ctx);
> + } else if (imm == 0) {
> + /* nop */

Ditto.

> + } else {
> + emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
> + emit(rv_srli(lo(rd), lo(rd), imm), ctx);
> + emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
> + emit(rv_srai(hi(rd), hi(rd), imm), ctx);
> + }
> + break;
> + }
> +
> + rv32_bpf_put_reg64(dst, rd, ctx);
> +}
> +
> +static void emit_rv32_alu_i32(const s8 *dst, s32 imm,
> + struct rv_jit_context *ctx, const u8 op)
> +{
> + const s8 *tmp1 = bpf2rv32[TMP_REG_1];
> + const s8 *rd = rv32_bpf_get_reg32(dst, tmp1, ctx);
> +
> + switch (op) {
> + case BPF_MOV:
> + emit_imm(lo(rd), imm, ctx);
> + break;
> + case BPF_ADD:
> + if (is_12b_int(imm)) {
> + emit(rv_addi(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_add(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_SUB:
> + if (is_12b_int(-imm)) {
> + emit(rv_addi(lo(rd), lo(rd), -imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_sub(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_AND:
> + if (is_12b_int(imm)) {
> + emit(rv_andi(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_OR:
> + if (is_12b_int(imm)) {
> + emit(rv_ori(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_XOR:
> + if (is_12b_int(imm)) {
> + emit(rv_xori(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_LSH:
> + if (is_12b_int(imm)) {
> + emit(rv_slli(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_sll(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_RSH:
> + if (is_12b_int(imm)) {
> + emit(rv_srli(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_srl(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;
> + case BPF_ARSH:
> + if (is_12b_int(imm)) {
> + emit(rv_srai(lo(rd), lo(rd), imm), ctx);
> + } else {
> + emit_imm(RV_REG_T0, imm, ctx);
> + emit(rv_sra(lo(rd), lo(rd), RV_REG_T0), ctx);
> + }
> + break;

Again nit; I like "early exit" code if possible. Instead of:

if (bleh) {
foo();
} else {
bar();
}

do:

if (bleh) {
foo()
return/break;
}
bar();

I find the latter easier to read -- but really a nit, and a matter of
style. There are a number of places where that could be applied in the
file.


> + }
[...]
> +
> +static int build_body(struct rv_jit_context *ctx, bool extra_pass, int *offset)
> +{
> + const struct bpf_prog *prog = ctx->prog;
> + int i;
> +
> + for (i = 0; i < prog->len; i++) {
> + const struct bpf_insn *insn = &prog->insnsi[i];
> + int ret;
> +
> + ret = emit_insn(insn, ctx, extra_pass);
> + if (ret > 0)
> + /*
> + * BPF_LD | BPF_IMM | BPF_DW:
> + * Skip the next instruction.
> + */
> + i++;
> + if (offset)
> + offset[i] = ctx->ninsns;
> + if (ret < 0)
> + return ret;
> + }
> + return 0;
> +}
> +
> +bool bpf_jit_needs_zext(void)
> +{
> + return true;
> +}
> +
> +struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
> +{
> + bool tmp_blinded = false, extra_pass = false;
> + struct bpf_prog *tmp, *orig_prog = prog;
> + int pass = 0, prev_ninsns = 0, i;
> + struct rv_jit_data *jit_data;
> + struct rv_jit_context *ctx;
> + unsigned int image_size = 0;
> +
> + if (!prog->jit_requested)
> + return orig_prog;
> +
> + tmp = bpf_jit_blind_constants(prog);
> + if (IS_ERR(tmp))
> + return orig_prog;
> + if (tmp != prog) {
> + tmp_blinded = true;
> + prog = tmp;
> + }
> +
> + jit_data = prog->aux->jit_data;
> + if (!jit_data) {
> + jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
> + if (!jit_data) {
> + prog = orig_prog;
> + goto out;
> + }
> + prog->aux->jit_data = jit_data;
> + }
> +
> + ctx = &jit_data->ctx;
> +
> + if (ctx->offset) {
> + extra_pass = true;
> + image_size = sizeof(u32) * ctx->ninsns;
> + goto skip_init_ctx;
> + }
> +
> + ctx->prog = prog;
> + ctx->offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
> + if (!ctx->offset) {
> + prog = orig_prog;
> + goto out_offset;
> + }
> + for (i = 0; i < prog->len; i++) {
> + prev_ninsns += 32;
> + ctx->offset[i] = prev_ninsns;
> + }
> +
> + for (i = 0; i < NR_JIT_ITERATIONS; i++) {
> + pass++;
> + ctx->ninsns = 0;
> + if (build_body(ctx, extra_pass, ctx->offset)) {
> + prog = orig_prog;
> + goto out_offset;
> + }
> + build_prologue(ctx);
> + ctx->epilogue_offset = ctx->ninsns;
> + build_epilogue(ctx);
> +
> + if (ctx->ninsns == prev_ninsns) {
> + if (jit_data->header)
> + break;
> +
> + image_size = sizeof(u32) * ctx->ninsns;
> + jit_data->header =
> + bpf_jit_binary_alloc(image_size,
> + &jit_data->image,
> + sizeof(u32),
> + bpf_fill_ill_insns);
> + if (!jit_data->header) {
> + prog = orig_prog;
> + goto out_offset;
> + }
> +
> + ctx->insns = (u32 *)jit_data->image;
> + }
> + prev_ninsns = ctx->ninsns;
> + }
> +
> + if (i == NR_JIT_ITERATIONS) {
> + pr_err("bpf-jit: image did not converge in <%d passes!\n", i);
> + bpf_jit_binary_free(jit_data->header);
> + prog = orig_prog;
> + goto out_offset;
> + }
> +
> +skip_init_ctx:
> + pass++;
> + ctx->ninsns = 0;
> +
> + build_prologue(ctx);
> + if (build_body(ctx, extra_pass, NULL)) {
> + bpf_jit_binary_free(jit_data->header);
> + prog = orig_prog;
> + goto out_offset;
> + }
> + build_epilogue(ctx);
> +
> + if (bpf_jit_enable > 1)
> + bpf_jit_dump(prog->len, image_size, 2, ctx->insns);
> +
> + prog->bpf_func = (void *)ctx->insns;
> + prog->jited = 1;
> + prog->jited_len = image_size;
> +
> + bpf_flush_icache(jit_data->header, ctx->insns + ctx->ninsns);
> +
> + if (!prog->is_func || extra_pass) {
> +out_offset:
> + kfree(ctx->offset);
> + kfree(jit_data);
> + prog->aux->jit_data = NULL;
> + }
> +out:
> +
> + if (tmp_blinded)
> + bpf_jit_prog_release_other(prog, prog == orig_prog ?
> + tmp : orig_prog);
> + return prog;
> +}

At this point of the series, let's introduce the shared code .c-file
containing the implementation for bpf_int_jit_compile() (with build_body
part of that) and bpf_jit_needs_zext(). That will make it easier to
catch bugs in both JITs and to avoid code duplication! Also, when
adding the stronger invariant suggested by Palmer [1], we only need to
do it in one place.

The pull out refactoring can be a separate commit.


Thanks!
Björn

[1] https://lore.kernel.org/bpf/mhng-6be38b2a-78df-4016-aaea-f35aa0acd7e0@palmerdabbelt-glaptop/