提交 684c9197 编写于 作者: Z zhangwen31

[arm][kernel]test: add elementwise compute test for sub mul div test=develop

上级 266965e2
...@@ -106,6 +106,20 @@ void elementwise_compute_ref(const operators::ElementwiseParam& param, ...@@ -106,6 +106,20 @@ void elementwise_compute_ref(const operators::ElementwiseParam& param,
} }
} }
} }
} else if (elt_type == "div") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr / diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (elt_type == "max") { } else if (elt_type == "max") {
for (int i = 0; i < batch; ++i) { for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) { for (int j = 0; j < channels; ++j) {
...@@ -254,11 +268,14 @@ template void elementwise_imod_compute_ref<int32_t>( ...@@ -254,11 +268,14 @@ template void elementwise_imod_compute_ref<int32_t>(
template void elementwise_imod_compute_ref<int64_t>( template void elementwise_imod_compute_ref<int64_t>(
const operators::ElementwiseParam& param, const std::string act_type); const operators::ElementwiseParam& param, const std::string act_type);
template <typename T, PrecisionType PType> template <template <class, PrecisionType> class ElementWiseComputeTemplate,
void elementwise_add_compute() { typename T,
ElementwiseAddCompute<T, PType> elementwise_add; PrecisionType PType>
void do_elementwise_compute(const char* op_type_str) {
ElementWiseComputeTemplate<T, PType> elementwise_add;
operators::ElementwiseParam param; operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref; lite::Tensor x, y, output, output_ref;
unsigned int rand_seed = 1;
#if 1 #if 1
for (auto n : {1, 3, 4}) { for (auto n : {1, 3, 4}) {
...@@ -311,10 +328,12 @@ void elementwise_add_compute() { ...@@ -311,10 +328,12 @@ void elementwise_add_compute() {
T* output_data = output.mutable_data<T>(); T* output_data = output.mutable_data<T>();
T* output_ref_data = output_ref.mutable_data<T>(); T* output_ref_data = output_ref.mutable_data<T>();
for (int i = 0; i < x_dim.production(); i++) { for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i; x_data[i] = 1.0 * rand_r(&rand_seed) * rand_r(&rand_seed) /
(rand_r(&rand_seed) + 1);
} }
for (int i = 0; i < y_dim.production(); i++) { for (int i = 0; i < y_dim.production(); i++) {
y_data[i] = i; y_data[i] = 1.0 * rand_r(&rand_seed) * rand_r(&rand_seed) /
(rand_r(&rand_seed) + 1);
} }
param.X = &x; param.X = &x;
param.Y = &y; param.Y = &y;
...@@ -323,15 +342,15 @@ void elementwise_add_compute() { ...@@ -323,15 +342,15 @@ void elementwise_add_compute() {
elementwise_add.SetParam(param); elementwise_add.SetParam(param);
elementwise_add.Run(); elementwise_add.Run();
param.Out = &output_ref; param.Out = &output_ref;
elementwise_compute_ref<T>(param, "add", ""); elementwise_compute_ref<T>(param, op_type_str, "");
if (std::is_floating_point<T>::value) { if (std::is_floating_point<T>::value) {
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5) ASSERT_NEAR(output_data[i], output_ref_data[i], 1e-5)
<< "Value differ at index " << i; << "Value differ at index " << i;
} }
} else { } else {
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
EXPECT_EQ(output_data[i], output_ref_data[i]) ASSERT_EQ(output_data[i], output_ref_data[i])
<< "Value differ at index " << i; << "Value differ at index " << i;
} }
} }
...@@ -344,21 +363,42 @@ void elementwise_add_compute() { ...@@ -344,21 +363,42 @@ void elementwise_add_compute() {
} }
TEST(elementwise_add, compute_fp32) { TEST(elementwise_add, compute_fp32) {
elementwise_add_compute<float, PRECISION(kFloat)>(); do_elementwise_compute<ElementwiseAddCompute, float, PRECISION(kFloat)>(
"add");
do_elementwise_compute<ElementwiseSubCompute, float, PRECISION(kFloat)>(
"sub");
do_elementwise_compute<ElementwiseMulCompute, float, PRECISION(kFloat)>(
"mul");
do_elementwise_compute<ElementwiseDivCompute, float, PRECISION(kFloat)>(
"div");
if (::testing::Test::HasFailure()) { if (::testing::Test::HasFailure()) {
FAIL(); FAIL();
} }
} }
TEST(elementwise_add, compute_i32) { TEST(elementwise_add, compute_i32) {
elementwise_add_compute<int32_t, PRECISION(kInt32)>(); do_elementwise_compute<ElementwiseAddCompute, int32_t, PRECISION(kInt32)>(
"add");
do_elementwise_compute<ElementwiseSubCompute, int32_t, PRECISION(kInt32)>(
"sub");
do_elementwise_compute<ElementwiseMulCompute, int32_t, PRECISION(kInt32)>(
"mul");
do_elementwise_compute<ElementwiseDivCompute, int32_t, PRECISION(kInt32)>(
"div");
if (::testing::Test::HasFailure()) { if (::testing::Test::HasFailure()) {
FAIL(); FAIL();
} }
} }
TEST(elementwise_add, compute_i64) { TEST(elementwise_add, compute_i64) {
elementwise_add_compute<int64_t, PRECISION(kInt64)>(); do_elementwise_compute<ElementwiseAddCompute, int64_t, PRECISION(kInt64)>(
"add");
do_elementwise_compute<ElementwiseSubCompute, int64_t, PRECISION(kInt64)>(
"sub");
do_elementwise_compute<ElementwiseMulCompute, int64_t, PRECISION(kInt64)>(
"mul");
do_elementwise_compute<ElementwiseDivCompute, int64_t, PRECISION(kInt64)>(
"div");
if (::testing::Test::HasFailure()) { if (::testing::Test::HasFailure()) {
FAIL(); FAIL();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册