提交 b059fb33 编写于 作者: Z zhangwen31

[arm][kernel]refactor: elementwise_add's compute test uses template now test=develop

add: i32 and i64 test cases
上级 32b5caa0
...@@ -254,8 +254,9 @@ template void elementwise_imod_compute_ref<int32_t>( ...@@ -254,8 +254,9 @@ template void elementwise_imod_compute_ref<int32_t>(
template void elementwise_imod_compute_ref<int64_t>( template void elementwise_imod_compute_ref<int64_t>(
const operators::ElementwiseParam& param, const std::string act_type); const operators::ElementwiseParam& param, const std::string act_type);
TEST(elementwise_add, compute) { template <typename T, PrecisionType PType>
ElementwiseAddCompute<float, PRECISION(kFloat)> elementwise_add; void elementwise_add_compute() {
ElementwiseAddCompute<T, PType> elementwise_add;
operators::ElementwiseParam param; operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref; lite::Tensor x, y, output, output_ref;
...@@ -305,10 +306,10 @@ TEST(elementwise_add, compute) { ...@@ -305,10 +306,10 @@ TEST(elementwise_add, compute) {
y.Resize(y_dim); y.Resize(y_dim);
output.Resize(x_dim); output.Resize(x_dim);
output_ref.Resize(x_dim); output_ref.Resize(x_dim);
auto* x_data = x.mutable_data<float>(); T* x_data = x.mutable_data<T>();
auto* y_data = y.mutable_data<float>(); T* y_data = y.mutable_data<T>();
auto* output_data = output.mutable_data<float>(); T* output_data = output.mutable_data<T>();
auto* output_ref_data = output_ref.mutable_data<float>(); T* output_ref_data = output_ref.mutable_data<T>();
for (int i = 0; i < x_dim.production(); i++) { for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i; x_data[i] = i;
} }
...@@ -322,9 +323,17 @@ TEST(elementwise_add, compute) { ...@@ -322,9 +323,17 @@ TEST(elementwise_add, compute) {
elementwise_add.SetParam(param); elementwise_add.SetParam(param);
elementwise_add.Run(); elementwise_add.Run();
param.Out = &output_ref; param.Out = &output_ref;
elementwise_compute_ref<float>(param, "add", ""); elementwise_compute_ref<T>(param, "add", "");
if (std::is_floating_point<T>::value) {
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5)
<< "Value differ at index " << i;
}
} else {
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_EQ(output_data[i], output_ref_data[i])
<< "Value differ at index " << i;
}
} }
} }
} }
...@@ -334,6 +343,27 @@ TEST(elementwise_add, compute) { ...@@ -334,6 +343,27 @@ TEST(elementwise_add, compute) {
} }
} }
// Runs the templated elementwise-add kernel test with float data at
// kFloat precision.
TEST(elementwise_add, compute_fp32) {
  elementwise_add_compute<float, PRECISION(kFloat)>();
  // The helper records failures via non-fatal EXPECT_* macros; surface any
  // recorded failure as an explicit failure of this test.
  if (::testing::Test::HasFailure()) {
    FAIL();
  }
}
// Runs the templated elementwise-add kernel test with int32_t data at
// kInt32 precision (integer results are compared with EXPECT_EQ).
TEST(elementwise_add, compute_i32) {
  elementwise_add_compute<int32_t, PRECISION(kInt32)>();
  // The helper records failures via non-fatal EXPECT_* macros; surface any
  // recorded failure as an explicit failure of this test.
  if (::testing::Test::HasFailure()) {
    FAIL();
  }
}
// Runs the templated elementwise-add kernel test with int64_t data at
// kInt64 precision (integer results are compared with EXPECT_EQ).
TEST(elementwise_add, compute_i64) {
  elementwise_add_compute<int64_t, PRECISION(kInt64)>();
  // The helper records failures via non-fatal EXPECT_* macros; surface any
  // recorded failure as an explicit failure of this test.
  if (::testing::Test::HasFailure()) {
    FAIL();
  }
}
TEST(fusion_elementwise_add_activation_arm, retrive_op) { TEST(fusion_elementwise_add_activation_arm, retrive_op) {
auto fusion_elementwise_add_activation = auto fusion_elementwise_add_activation =
KernelRegistry::Global().Create("fusion_elementwise_add_activation"); KernelRegistry::Global().Create("fusion_elementwise_add_activation");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册