提交 684c9197 编写于 作者: Z zhangwen31

[arm][kernel]test: add elementwise compute test for sub mul div test=develop

上级 266965e2
......@@ -106,6 +106,20 @@ void elementwise_compute_ref(const operators::ElementwiseParam& param,
}
}
}
} else if (elt_type == "div") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr / diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (elt_type == "max") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
......@@ -254,11 +268,14 @@ template void elementwise_imod_compute_ref<int32_t>(
template void elementwise_imod_compute_ref<int64_t>(
const operators::ElementwiseParam& param, const std::string act_type);
template <typename T, PrecisionType PType>
void elementwise_add_compute() {
ElementwiseAddCompute<T, PType> elementwise_add;
template <template <class, PrecisionType> class ElementWiseComputeTemplate,
typename T,
PrecisionType PType>
void do_elementwise_compute(const char* op_type_str) {
ElementWiseComputeTemplate<T, PType> elementwise_add;
operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref;
unsigned int rand_seed = 1;
#if 1
for (auto n : {1, 3, 4}) {
......@@ -311,10 +328,12 @@ void elementwise_add_compute() {
T* output_data = output.mutable_data<T>();
T* output_ref_data = output_ref.mutable_data<T>();
for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i;
x_data[i] = 1.0 * rand_r(&rand_seed) * rand_r(&rand_seed) /
(rand_r(&rand_seed) + 1);
}
for (int i = 0; i < y_dim.production(); i++) {
y_data[i] = i;
y_data[i] = 1.0 * rand_r(&rand_seed) * rand_r(&rand_seed) /
(rand_r(&rand_seed) + 1);
}
param.X = &x;
param.Y = &y;
......@@ -323,15 +342,15 @@ void elementwise_add_compute() {
elementwise_add.SetParam(param);
elementwise_add.Run();
param.Out = &output_ref;
elementwise_compute_ref<T>(param, "add", "");
elementwise_compute_ref<T>(param, op_type_str, "");
if (std::is_floating_point<T>::value) {
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5)
ASSERT_NEAR(output_data[i], output_ref_data[i], 1e-5)
<< "Value differ at index " << i;
}
} else {
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_EQ(output_data[i], output_ref_data[i])
ASSERT_EQ(output_data[i], output_ref_data[i])
<< "Value differ at index " << i;
}
}
......@@ -344,21 +363,42 @@ void elementwise_add_compute() {
}
TEST(elementwise_add, compute_fp32) {
elementwise_add_compute<float, PRECISION(kFloat)>();
do_elementwise_compute<ElementwiseAddCompute, float, PRECISION(kFloat)>(
"add");
do_elementwise_compute<ElementwiseSubCompute, float, PRECISION(kFloat)>(
"sub");
do_elementwise_compute<ElementwiseMulCompute, float, PRECISION(kFloat)>(
"mul");
do_elementwise_compute<ElementwiseDivCompute, float, PRECISION(kFloat)>(
"div");
if (::testing::Test::HasFailure()) {
FAIL();
}
}
TEST(elementwise_add, compute_i32) {
elementwise_add_compute<int32_t, PRECISION(kInt32)>();
do_elementwise_compute<ElementwiseAddCompute, int32_t, PRECISION(kInt32)>(
"add");
do_elementwise_compute<ElementwiseSubCompute, int32_t, PRECISION(kInt32)>(
"sub");
do_elementwise_compute<ElementwiseMulCompute, int32_t, PRECISION(kInt32)>(
"mul");
do_elementwise_compute<ElementwiseDivCompute, int32_t, PRECISION(kInt32)>(
"div");
if (::testing::Test::HasFailure()) {
FAIL();
}
}
TEST(elementwise_add, compute_i64) {
elementwise_add_compute<int64_t, PRECISION(kInt64)>();
do_elementwise_compute<ElementwiseAddCompute, int64_t, PRECISION(kInt64)>(
"add");
do_elementwise_compute<ElementwiseSubCompute, int64_t, PRECISION(kInt64)>(
"sub");
do_elementwise_compute<ElementwiseMulCompute, int64_t, PRECISION(kInt64)>(
"mul");
do_elementwise_compute<ElementwiseDivCompute, int64_t, PRECISION(kInt64)>(
"div");
if (::testing::Test::HasFailure()) {
FAIL();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册