From b059fb33f8621b2befe8c7b4c86e4c5da92552b8 Mon Sep 17 00:00:00 2001 From: zhangwen31 Date: Mon, 14 Sep 2020 05:22:08 +0000 Subject: [PATCH] [arm][kernel]refactor: elementwise_add's compute test uses template now test=develop add: i32 and i64 test case --- lite/kernels/arm/elementwise_compute_test.cc | 48 ++++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/lite/kernels/arm/elementwise_compute_test.cc b/lite/kernels/arm/elementwise_compute_test.cc index bf454f10a8..cece331e25 100644 --- a/lite/kernels/arm/elementwise_compute_test.cc +++ b/lite/kernels/arm/elementwise_compute_test.cc @@ -254,8 +254,9 @@ template void elementwise_imod_compute_ref( template void elementwise_imod_compute_ref( const operators::ElementwiseParam& param, const std::string act_type); -TEST(elementwise_add, compute) { - ElementwiseAddCompute elementwise_add; +template +void elementwise_add_compute() { + ElementwiseAddCompute elementwise_add; operators::ElementwiseParam param; lite::Tensor x, y, output, output_ref; @@ -305,10 +306,10 @@ TEST(elementwise_add, compute) { y.Resize(y_dim); output.Resize(x_dim); output_ref.Resize(x_dim); - auto* x_data = x.mutable_data(); - auto* y_data = y.mutable_data(); - auto* output_data = output.mutable_data(); - auto* output_ref_data = output_ref.mutable_data(); + T* x_data = x.mutable_data(); + T* y_data = y.mutable_data(); + T* output_data = output.mutable_data(); + T* output_ref_data = output_ref.mutable_data(); for (int i = 0; i < x_dim.production(); i++) { x_data[i] = i; } @@ -322,9 +323,17 @@ TEST(elementwise_add, compute) { elementwise_add.SetParam(param); elementwise_add.Run(); param.Out = &output_ref; - elementwise_compute_ref(param, "add", ""); - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + elementwise_compute_ref(param, "add", ""); + if (std::is_floating_point::value) { + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5) + << "Value differ at index " << i; + } + } else { + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_EQ(output_data[i], output_ref_data[i]) + << "Value differ at index " << i; + } } } } @@ -334,6 +343,27 @@ TEST(elementwise_add, compute) { } } +TEST(elementwise_add, compute_fp32) { + elementwise_add_compute(); + if (::testing::Test::HasFailure()) { + FAIL(); + } +} + +TEST(elementwise_add, compute_i32) { + elementwise_add_compute(); + if (::testing::Test::HasFailure()) { + FAIL(); + } +} + +TEST(elementwise_add, compute_i64) { + elementwise_add_compute(); + if (::testing::Test::HasFailure()) { + FAIL(); + } +} + TEST(fusion_elementwise_add_activation_arm, retrive_op) { auto fusion_elementwise_add_activation = KernelRegistry::Global().Create("fusion_elementwise_add_activation"); -- GitLab