diff --git a/lite/kernels/arm/elementwise_compute_test.cc b/lite/kernels/arm/elementwise_compute_test.cc index bf454f10a874f6ad1d65887f5199e75f9afce284..cece331e25add2e3503700677390b7b96694f480 100644 --- a/lite/kernels/arm/elementwise_compute_test.cc +++ b/lite/kernels/arm/elementwise_compute_test.cc @@ -254,8 +254,9 @@ template void elementwise_imod_compute_ref( template void elementwise_imod_compute_ref( const operators::ElementwiseParam& param, const std::string act_type); -TEST(elementwise_add, compute) { - ElementwiseAddCompute elementwise_add; +template +void elementwise_add_compute() { + ElementwiseAddCompute elementwise_add; operators::ElementwiseParam param; lite::Tensor x, y, output, output_ref; @@ -305,10 +306,10 @@ TEST(elementwise_add, compute) { y.Resize(y_dim); output.Resize(x_dim); output_ref.Resize(x_dim); - auto* x_data = x.mutable_data(); - auto* y_data = y.mutable_data(); - auto* output_data = output.mutable_data(); - auto* output_ref_data = output_ref.mutable_data(); + T* x_data = x.mutable_data(); + T* y_data = y.mutable_data(); + T* output_data = output.mutable_data(); + T* output_ref_data = output_ref.mutable_data(); for (int i = 0; i < x_dim.production(); i++) { x_data[i] = i; } @@ -322,9 +323,17 @@ TEST(elementwise_add, compute) { elementwise_add.SetParam(param); elementwise_add.Run(); param.Out = &output_ref; - elementwise_compute_ref(param, "add", ""); - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + elementwise_compute_ref(param, "add", ""); + if (std::is_floating_point::value) { + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5) + << "Value differ at index " << i; + } + } else { + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_EQ(output_data[i], output_ref_data[i]) + << "Value differ at index " << i; + } } } } @@ -334,6 +343,27 @@ TEST(elementwise_add, compute) { } } +TEST(elementwise_add, compute_fp32) { + elementwise_add_compute(); + if (::testing::Test::HasFailure()) { + FAIL(); + } +} + +TEST(elementwise_add, compute_i32) { + elementwise_add_compute(); + if (::testing::Test::HasFailure()) { + FAIL(); + } +} + +TEST(elementwise_add, compute_i64) { + elementwise_add_compute(); + if (::testing::Test::HasFailure()) { + FAIL(); + } +} + TEST(fusion_elementwise_add_activation_arm, retrive_op) { auto fusion_elementwise_add_activation = KernelRegistry::Global().Create("fusion_elementwise_add_activation");