提交 9ee77a23 编写于 作者: Z zhangwen31

[arm][kernel][test] test: fix equality chk in elementwise_op test test=develop

上级 684c9197
...@@ -268,6 +268,14 @@ template void elementwise_imod_compute_ref<int32_t>( ...@@ -268,6 +268,14 @@ template void elementwise_imod_compute_ref<int32_t>(
template void elementwise_imod_compute_ref<int64_t>( template void elementwise_imod_compute_ref<int64_t>(
const operators::ElementwiseParam& param, const std::string act_type); const operators::ElementwiseParam& param, const std::string act_type);
template <class T>
bool is_fp_close(T v1, T v2, T rel_tol = 1e-4, T abs_tol = 1e-5) {
bool abs_chk = std::abs(v1 - v2) < abs_tol;
bool rel_chk =
(std::abs(v1 - v2) / std::min(std::abs(v1), std::abs(v2))) < rel_tol;
return abs_chk || rel_chk;
}
template <template <class, PrecisionType> class ElementWiseComputeTemplate, template <template <class, PrecisionType> class ElementWiseComputeTemplate,
typename T, typename T,
PrecisionType PType> PrecisionType PType>
...@@ -345,13 +353,14 @@ void do_elementwise_compute(const char* op_type_str) { ...@@ -345,13 +353,14 @@ void do_elementwise_compute(const char* op_type_str) {
elementwise_compute_ref<T>(param, op_type_str, ""); elementwise_compute_ref<T>(param, op_type_str, "");
if (std::is_floating_point<T>::value) { if (std::is_floating_point<T>::value) {
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
ASSERT_NEAR(output_data[i], output_ref_data[i], 1e-5) ASSERT_EQ(is_fp_close(output_data[i], output_ref_data[i]),
<< "Value differ at index " << i; true)
<< op_type_str << "Value differ at index " << i;
} }
} else { } else {
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
ASSERT_EQ(output_data[i], output_ref_data[i]) ASSERT_EQ(output_data[i], output_ref_data[i])
<< "Value differ at index " << i; << op_type_str << "Value differ at index " << i;
} }
} }
} }
...@@ -362,7 +371,7 @@ void do_elementwise_compute(const char* op_type_str) { ...@@ -362,7 +371,7 @@ void do_elementwise_compute(const char* op_type_str) {
} }
} }
TEST(elementwise_add, compute_fp32) { TEST(elementwise_op, compute_fp32) {
do_elementwise_compute<ElementwiseAddCompute, float, PRECISION(kFloat)>( do_elementwise_compute<ElementwiseAddCompute, float, PRECISION(kFloat)>(
"add"); "add");
do_elementwise_compute<ElementwiseSubCompute, float, PRECISION(kFloat)>( do_elementwise_compute<ElementwiseSubCompute, float, PRECISION(kFloat)>(
...@@ -376,7 +385,7 @@ TEST(elementwise_add, compute_fp32) { ...@@ -376,7 +385,7 @@ TEST(elementwise_add, compute_fp32) {
} }
} }
TEST(elementwise_add, compute_i32) { TEST(elementwise_op, compute_i32) {
do_elementwise_compute<ElementwiseAddCompute, int32_t, PRECISION(kInt32)>( do_elementwise_compute<ElementwiseAddCompute, int32_t, PRECISION(kInt32)>(
"add"); "add");
do_elementwise_compute<ElementwiseSubCompute, int32_t, PRECISION(kInt32)>( do_elementwise_compute<ElementwiseSubCompute, int32_t, PRECISION(kInt32)>(
...@@ -390,7 +399,7 @@ TEST(elementwise_add, compute_i32) { ...@@ -390,7 +399,7 @@ TEST(elementwise_add, compute_i32) {
} }
} }
TEST(elementwise_add, compute_i64) { TEST(elementwise_op, compute_i64) {
do_elementwise_compute<ElementwiseAddCompute, int64_t, PRECISION(kInt64)>( do_elementwise_compute<ElementwiseAddCompute, int64_t, PRECISION(kInt64)>(
"add"); "add");
do_elementwise_compute<ElementwiseSubCompute, int64_t, PRECISION(kInt64)>( do_elementwise_compute<ElementwiseSubCompute, int64_t, PRECISION(kInt64)>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册