未验证 提交 42cefe1b 编写于 作者: H hong19860320 提交者: GitHub

[UT] Fix the ut of search_aligned_mat_mul and search_seq_fc beacuse of 0*nan=nan (#4119)

test=develop test=xpu
上级 192be07b
...@@ -103,6 +103,11 @@ class SearchAlignedMatMulComputeTester : public arena::TestCase { ...@@ -103,6 +103,11 @@ class SearchAlignedMatMulComputeTester : public arena::TestCase {
out->Resize(out_dims); out->Resize(out_dims);
auto out_data = out->mutable_data<float>(); auto out_data = out->mutable_data<float>();
// Prevent 0*nan=nan in basic_gemm
int64_t out_num = out_dims.production();
for (int64_t i = 0; i < out_num; i++) {
out_data[i] = 0;
}
for (int i = 0; i < seq_num; i++) { for (int i = 0; i < seq_num; i++) {
basic_gemm<float, float>(x_transpose_, basic_gemm<float, float>(x_transpose_,
y_transpose_, y_transpose_,
......
...@@ -87,12 +87,18 @@ class SearchSeqFcOPTest : public arena::TestCase { ...@@ -87,12 +87,18 @@ class SearchSeqFcOPTest : public arena::TestCase {
} }
out->set_lod(x_lod); out->set_lod(x_lod);
out->Resize({x_dims[0], w_dims[0]}); DDim out_dims({x_dims[0], w_dims[0]});
out->Resize(out_dims);
int M = x_dims[0]; int M = x_dims[0];
int K = x_dims[1]; int K = x_dims[1];
int N = w_dims[0]; int N = w_dims[0];
auto out_data = out->mutable_data<float>(); auto out_data = out->mutable_data<float>();
// Prevent 0*nan=nan in basic_gemm
int64_t out_num = out_dims.production();
for (int64_t i = 0; i < out_num; i++) {
out_data[i] = 0;
}
basic_gemm<float, float>(false, basic_gemm<float, float>(false,
true, true,
M, M,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册