提交 75fc792d 编写于 作者: T tensor-tang

Fix the case where the table width is larger than 64

test=develop
上级 40402d5e
...@@ -312,7 +312,7 @@ void BenchEmbSeqPoolKernel() { ...@@ -312,7 +312,7 @@ void BenchEmbSeqPoolKernel() {
const T* table_data = table.data<T>(); const T* table_data = table.data<T>();
for (auto type : pool_types) { for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) { for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 10, 16}) { for (int idx_h : {1, 2, 9, 13, 16}) {
int64_t out_w = tbl_w * idx_w; int64_t out_w = tbl_w * idx_w;
jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w, jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
type); type);
......
...@@ -53,7 +53,6 @@ void EmbSeqPoolJitCode::genCode() { ...@@ -53,7 +53,6 @@ void EmbSeqPoolJitCode::genCode() {
xor_(reg_idx_w_i_in_byte, reg_idx_w_i_in_byte); xor_(reg_idx_w_i_in_byte, reg_idx_w_i_in_byte);
mov(reg_ptr_dst_i, reg_ptr_param_dst); mov(reg_ptr_dst_i, reg_ptr_param_dst);
add(reg_ptr_dst_i, acc_num_regs * block_size); add(reg_ptr_dst_i, acc_num_regs * block_size);
add(param_tbl, acc_num_regs * block_size);
L(l_next_idx_w); L(l_next_idx_w);
{ {
...@@ -113,8 +112,10 @@ void EmbSeqPoolJitCode::genCode() { ...@@ -113,8 +112,10 @@ void EmbSeqPoolJitCode::genCode() {
cmp(reg_idx_w_i_in_byte, reg_idx_width_in_byte); cmp(reg_idx_w_i_in_byte, reg_idx_width_in_byte);
jl(l_next_idx_w, T_NEAR); jl(l_next_idx_w, T_NEAR);
} // end of idx w } // end of idx w
acc_num_regs += num_regs; acc_num_regs += num_regs;
} // end of groups add(param_tbl, num_regs * block_size); // do not use acc_num_regs
} // end of groups
postCode(); postCode();
} }
......
...@@ -625,7 +625,7 @@ void TestEmbSeqPoolKernel() { ...@@ -625,7 +625,7 @@ void TestEmbSeqPoolKernel() {
const T* table_data = table.data(); const T* table_data = table.data();
for (auto type : pool_types) { for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) { for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 10, 16}) { for (int idx_h : {1, 2, 9, 13, 16}) {
auto ref = jit::GetRefer<KT, jit::EmbSeqPoolTuples<T>>(); auto ref = jit::GetRefer<KT, jit::EmbSeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<int64_t> idx(idx_h * idx_w); std::vector<int64_t> idx(idx_h * idx_w);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册