diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 9831b6ef922390005a7f81fdc75d58279097baa1..96196d26a80d29b017e5e8b1e563fa3120f65ce0 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -312,7 +312,7 @@ void BenchEmbSeqPoolKernel() { const T* table_data = table.data(); for (auto type : pool_types) { for (int idx_w : {1, 2, 10, 16}) { - for (int idx_h : {1, 2, 10, 16}) { + for (int idx_h : {1, 2, 9, 13, 16}) { int64_t out_w = tbl_w * idx_w; jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w, type); diff --git a/paddle/fluid/operators/jit/gen/embseqpool.cc b/paddle/fluid/operators/jit/gen/embseqpool.cc index 3f233acee90bdc6d886c8cbc07dc31890f4863bc..23837a3fb9886ae8a839d4b31bd57916168ea53c 100644 --- a/paddle/fluid/operators/jit/gen/embseqpool.cc +++ b/paddle/fluid/operators/jit/gen/embseqpool.cc @@ -53,7 +53,6 @@ void EmbSeqPoolJitCode::genCode() { xor_(reg_idx_w_i_in_byte, reg_idx_w_i_in_byte); mov(reg_ptr_dst_i, reg_ptr_param_dst); add(reg_ptr_dst_i, acc_num_regs * block_size); - add(param_tbl, acc_num_regs * block_size); L(l_next_idx_w); { @@ -113,8 +112,10 @@ void EmbSeqPoolJitCode::genCode() { cmp(reg_idx_w_i_in_byte, reg_idx_width_in_byte); jl(l_next_idx_w, T_NEAR); } // end of idx w + acc_num_regs += num_regs; - } // end of groups + add(param_tbl, num_regs * block_size); // do not use acc_num_regs + } // end of groups postCode(); } diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index c35b6aef232c44e9f08a8b4569305186e98c7ff7..15e29938240ddb95db1488dd51aed5b07b6ba8ae 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -625,7 +625,7 @@ void TestEmbSeqPoolKernel() { const T* table_data = table.data(); for (auto type : pool_types) { for (int idx_w : {1, 2, 10, 16}) { - for (int idx_h : {1, 2, 10, 16}) { + for (int idx_h : {1, 2, 9, 13, 16}) { auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); std::vector idx(idx_h * idx_w);