提交 75fc792d 编写于 作者: T tensor-tang

Fix EmbSeqPool JIT code generation when the table width is larger than 64

test=develop
上级 40402d5e
......@@ -312,7 +312,7 @@ void BenchEmbSeqPoolKernel() {
const T* table_data = table.data<T>();
for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 9, 13, 16}) {
int64_t out_w = tbl_w * idx_w;
jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
type);
......
......@@ -53,7 +53,6 @@ void EmbSeqPoolJitCode::genCode() {
xor_(reg_idx_w_i_in_byte, reg_idx_w_i_in_byte);
mov(reg_ptr_dst_i, reg_ptr_param_dst);
add(reg_ptr_dst_i, acc_num_regs * block_size);
add(param_tbl, acc_num_regs * block_size);
L(l_next_idx_w);
{
......@@ -113,8 +112,10 @@ void EmbSeqPoolJitCode::genCode() {
cmp(reg_idx_w_i_in_byte, reg_idx_width_in_byte);
jl(l_next_idx_w, T_NEAR);
} // end of idx w
acc_num_regs += num_regs;
} // end of groups
add(param_tbl, num_regs * block_size); // do not use acc_num_regs
} // end of groups
postCode();
}
......
......@@ -625,7 +625,7 @@ void TestEmbSeqPoolKernel() {
const T* table_data = table.data();
for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 9, 13, 16}) {
auto ref = jit::GetRefer<KT, jit::EmbSeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<int64_t> idx(idx_h * idx_w);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册