From 9a3552dbb86c1874a4dc3b4b85070a4aec97f185 Mon Sep 17 00:00:00 2001 From: yiicy Date: Tue, 10 Dec 2019 13:35:53 +0800 Subject: [PATCH] [ARM] add instance norm op and ut, test=develop (#2578) --- lite/backends/arm/math/interpolate.cc | 39 ++-- lite/kernels/arm/CMakeLists.txt | 1 + lite/kernels/arm/instance_norm_compute.cc | 179 ++++++++++++++++++ lite/kernels/arm/instance_norm_compute.h | 40 ++++ lite/operators/CMakeLists.txt | 1 + lite/operators/instance_norm_op.cc | 77 ++++++++ lite/operators/instance_norm_op.h | 47 +++++ lite/operators/op_params.h | 11 ++ lite/tests/kernels/CMakeLists.txt | 1 + .../kernels/instance_norm_compute_test.cc | 164 ++++++++++++++++ 10 files changed, 546 insertions(+), 14 deletions(-) create mode 100644 lite/kernels/arm/instance_norm_compute.cc create mode 100644 lite/kernels/arm/instance_norm_compute.h create mode 100644 lite/operators/instance_norm_op.cc create mode 100644 lite/operators/instance_norm_op.h create mode 100644 lite/tests/kernels/instance_norm_compute_test.cc diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc index e9e18043df..34d9a20433 100644 --- a/lite/backends/arm/math/interpolate.cc +++ b/lite/backends/arm/math/interpolate.cc @@ -477,17 +477,23 @@ void nearest_interp(const float* src, float scale_h_new = (with_align) ? (static_cast(h_in - 1) / (h_out - 1)) : (static_cast(h_in) / (h_out)); - -#pragma omp parallel for collapse(2) schedule(static) - for (int h = 0; h < h_out; ++h) { - for (int w = 0; w < w_out; ++w) { - int near_x = (with_align) ? static_cast(scale_w_new * w + 0.5) - : static_cast(scale_w_new * w); - int near_y = (with_align) ? static_cast(scale_h_new * h + 0.5) - : static_cast(scale_h_new * h); - near_x = near_x < 0 ? 0 : near_x; - near_y = near_y < 0 ? 0 : near_y; - dst[h * w_out + w] = src[near_y * w_in + near_x]; + if (with_align) { + for (int h = 0; h < h_out; ++h) { + float* dst_p = dst + h * w_out; + int near_y = static_cast(scale_h_new * h + 0.5); + for (int w = 0; w < w_out; ++w) { + int near_x = static_cast(scale_w_new * w + 0.5); + *dst_p++ = src[near_y * w_in + near_x]; + } + } + } else { + for (int h = 0; h < h_out; ++h) { + float* dst_p = dst + h * w_out; + int near_y = static_cast(scale_h_new * h); + for (int w = 0; w < w_out; ++w) { + int near_x = static_cast(scale_w_new * w); + *dst_p++ = src[near_y * w_in + near_x]; + } } } } @@ -544,8 +550,10 @@ void interpolate(lite::Tensor* X, int out_w = Out->dims()[3]; int spatial_in = in_h * in_w; int spatial_out = out_h * out_w; - for (int i = 0; i < count; ++i) { - if ("Bilinear" == interpolate_type) { + + if ("Bilinear" == interpolate_type) { +#pragma omp parallel for + for (int i = 0; i < count; ++i) { bilinear_interp(din + spatial_in * i, in_w, in_h, @@ -555,7 +563,10 @@ void interpolate(lite::Tensor* X, 1.f / width_scale, 1.f / height_scale, with_align); - } else if ("Nearest" == interpolate_type) { + } + } else if ("Nearest" == interpolate_type) { +#pragma omp parallel for + for (int i = 0; i < count; ++i) { nearest_interp(din + spatial_in * i, in_w, in_h, diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 49c220f0c9..6f92983bc3 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -41,6 +41,7 @@ add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc D add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) ## 2.other basic kernels: basic kernels that not used in basic models add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/instance_norm_compute.cc b/lite/kernels/arm/instance_norm_compute.cc new file mode 100644 index 0000000000..e3e82c53ac --- /dev/null +++ b/lite/kernels/arm/instance_norm_compute.cc @@ -0,0 +1,179 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/instance_norm_compute.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void InstanceNormCompute::PrepareForRun() {} + +void InstanceNormCompute::Run() { + auto& param = this->Param(); + const float* in = param.x->data(); + const float* scale = param.scale->data(); + const float* bias = param.bias->data(); + float* out = param.out->mutable_data(); + float* saved_mean = param.saved_mean->mutable_data(); + float* saved_variance = param.saved_variance->mutable_data(); + float epsilon = param.epsilon; + + int n = param.x->dims()[0]; + int c = param.x->dims()[1]; + int nc = n * c; + int height = param.x->dims()[2]; + int width = param.x->dims()[3]; + int spatial_size = height * width; +// compute saved_mean and saved_variance +#pragma omp parallel for + for (int i = 0; i < nc; ++i) { + const float* in_p = in + i * spatial_size; + float sum_spatial = 0.f; + float summ_spatial = 0.f; + for (int h = 0; h < height; ++h) { + int w = width; + float32x4_t sum0 = vdupq_n_f32(0.f); + float32x4_t sum1 = vdupq_n_f32(0.f); + float32x4_t sum2 = vdupq_n_f32(0.f); + float32x4_t sum3 = vdupq_n_f32(0.f); + float32x4_t summ0 = vdupq_n_f32(0.f); + float32x4_t summ1 = vdupq_n_f32(0.f); + float32x4_t summ2 = vdupq_n_f32(0.f); + float32x4_t summ3 = vdupq_n_f32(0.f); + float32x4_t in0, in1, in2, in3; + for (; w > 15; w -= 16) { + in0 = vld1q_f32(in_p); + in1 = vld1q_f32(in_p + 4); + in2 = vld1q_f32(in_p + 8); + in3 = vld1q_f32(in_p + 12); + sum0 = vaddq_f32(sum0, in0); + sum1 = vaddq_f32(sum1, in1); + summ0 = vmlaq_f32(summ0, in0, in0); + summ1 = vmlaq_f32(summ1, in1, in1); + sum2 = vaddq_f32(sum2, in2); + sum3 = vaddq_f32(sum3, in3); + summ2 = vmlaq_f32(summ2, in2, in2); + summ3 = vmlaq_f32(summ3, in3, in3); + in_p += 16; + } + for (; w > 7; w -= 8) { + in0 = vld1q_f32(in_p); + in1 = vld1q_f32(in_p + 4); + sum0 = vaddq_f32(sum0, in0); + sum1 = vaddq_f32(sum1, in1); + summ0 = vmlaq_f32(summ0, in0, in0); + summ1 = vmlaq_f32(summ1, in1, in1); + in_p += 8; + } + for (; w > 3; w -= 4) { + in0 = vld1q_f32(in_p); + sum0 = vaddq_f32(sum0, in0); + summ0 = vmlaq_f32(summ0, in0, in0); + in_p += 4; + } + float sum = 0.f; + float summ = 0.f; + for (; w > 0; w--) { + sum += *in_p; + summ += (*in_p) * (*in_p); + in_p++; + } + sum0 = vaddq_f32(sum0, sum1); + sum2 = vaddq_f32(sum2, sum3); + summ0 = vaddq_f32(summ0, summ1); + summ2 = vaddq_f32(summ2, summ3); + sum0 = vaddq_f32(sum0, sum2); + summ0 = vaddq_f32(summ0, summ2); + float32x2_t sum_low = vpadd_f32(vget_low_f32(sum0), vget_high_f32(sum0)); + float32x2_t sum_high = + vpadd_f32(vget_low_f32(summ0), vget_high_f32(summ0)); + float32x2_t sum_mix = vpadd_f32(sum_low, sum_high); + sum += vget_lane_f32(sum_mix, 0); + summ += vget_lane_f32(sum_mix, 1); + sum_spatial += sum; + summ_spatial += summ; + } + float mean = sum_spatial / spatial_size; + // float variance = summ / spatial_size - mean * mean; + // the flolowing code has higher precision than above comment code + float variance = (summ_spatial - mean * mean * spatial_size) / spatial_size; + float std = 1.f / sqrtf(variance + epsilon); + + saved_mean[i] = mean; + saved_variance[i] = std; + } +// compute instance_norm result: out = scale * (in - mean) / std + bias +#pragma omp parallel for + for (int i = 0; i < nc; ++i) { + const float* in_p = in + i * spatial_size; + float* out_p = out + i * spatial_size; + int j = spatial_size; + const float sstd_val = scale[i % c] * saved_variance[i]; + const float bias_val = bias[i % c]; + const float mean_val = saved_mean[i]; + const float32x4_t vsstd = vdupq_n_f32(sstd_val); + const float32x4_t vbias = vdupq_n_f32(bias_val); + const float32x4_t vmean = vdupq_n_f32(mean_val); + float32x4_t in0, in1, submean0, submean1, out0, out1; + for (; j > 7; j -= 8) { + in0 = vld1q_f32(in_p); + in1 = vld1q_f32(in_p + 4); + submean0 = vsubq_f32(in0, vmean); + submean1 = vsubq_f32(in1, vmean); + out0 = vmlaq_f32(vbias, submean0, vsstd); + out1 = vmlaq_f32(vbias, submean1, vsstd); + vst1q_f32(out_p, out0); + vst1q_f32(out_p + 4, out1); + in_p += 8; + out_p += 8; + } + for (; j > 3; j -= 4) { + in0 = vld1q_f32(in_p); + submean0 = vsubq_f32(in0, vmean); + out0 = vmlaq_f32(vbias, submean0, vsstd); + vst1q_f32(out_p, out0); + in_p += 4; + out_p += 4; + } + for (; j > 0; j--) { + *out_p = (*in_p - mean_val) * sstd_val + bias_val; + in_p++; + out_p++; + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(instance_norm, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::InstanceNormCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/instance_norm_compute.h b/lite/kernels/arm/instance_norm_compute.h new file mode 100644 index 0000000000..3fc056a372 --- /dev/null +++ b/lite/kernels/arm/instance_norm_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class InstanceNormCompute : public KernelLite { + public: + using param_t = operators::InstanceNormParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~InstanceNormCompute() = default; + + private: +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 59f7e7124d..1cbd67112f 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -47,6 +47,7 @@ add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_act add_operator(io_copy_once_op basic SRCS io_copy_once_op.cc DEPS io_copy_op ${op_DEPS}) add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS}) add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS}) +add_operator(instance_norm_op basic SRCS instance_norm_op.cc DEPS ${op_DEPS}) add_operator(graph_op basic SRCS graph_op.cc DEPS ${op_DEPS}) # 2.basic ops not used in basic models diff --git a/lite/operators/instance_norm_op.cc b/lite/operators/instance_norm_op.cc new file mode 100644 index 0000000000..510402ba1f --- /dev/null +++ b/lite/operators/instance_norm_op.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/instance_norm_op.h" +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool InstanceNormOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.scale); + CHECK_OR_FALSE(param_.bias); + CHECK_OR_FALSE(param_.out); + CHECK_OR_FALSE(param_.saved_mean); + CHECK_OR_FALSE(param_.saved_variance); + auto x_dims = param_.x->dims(); + auto scale_dims = param_.scale->dims(); + auto bias_dims = param_.bias->dims(); + CHECK(x_dims.size() >= 2 && x_dims.size() <= 5) + << "Input X must have 2 to 5 dimensions."; + CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions."; + CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions."; + CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f"; + CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f"; + return true; +} + +bool InstanceNormOp::InferShape() const { + auto x_dims = param_.x->dims(); + int64_t batch_size = x_dims[0]; + int64_t channel_size = x_dims[1]; + param_.saved_mean->Resize({batch_size * channel_size}); + param_.saved_variance->Resize({batch_size * channel_size}); + param_.out->Resize(x_dims); + return true; +} + +bool InstanceNormOp::AttachImpl(const cpp::OpDesc& op_desc, + lite::Scope* scope) { + param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable(); + param_.scale = + scope->FindVar(op_desc.Input("Scale").front())->GetMutable(); + param_.bias = + scope->FindVar(op_desc.Input("Bias").front())->GetMutable(); + param_.saved_mean = + scope->FindVar(op_desc.Output("SavedMean").front())->GetMutable(); + param_.saved_variance = + scope->FindVar(op_desc.Output("SavedVariance").front()) + ->GetMutable(); + param_.out = + scope->FindVar(op_desc.Output("Y").front())->GetMutable(); + param_.epsilon = op_desc.GetAttr("epsilon"); + return true; +} + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ + +REGISTER_LITE_OP(instance_norm, paddle::lite::operators::InstanceNormOp); diff --git a/lite/operators/instance_norm_op.h b/lite/operators/instance_norm_op.h new file mode 100644 index 0000000000..d128345805 --- /dev/null +++ b/lite/operators/instance_norm_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class InstanceNormOp : public OpLite { + public: + InstanceNormOp() {} + + explicit InstanceNormOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "instance_norm"; } + + private: + mutable InstanceNormParam param_; +}; + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 3534998663..0bca9b3df4 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1090,6 +1090,17 @@ struct CollectFpnProposalsParam { int post_nms_topN{}; }; +/// --------------------- instance_norm operators -------------------- +struct InstanceNormParam { + lite::Tensor* x{}; + lite::Tensor* out{}; + lite::Tensor* bias{}; + lite::Tensor* scale{}; + lite::Tensor* saved_mean{}; + lite::Tensor* saved_variance{}; + float epsilon; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 712e7dd306..48004337e3 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -14,6 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/instance_norm_compute_test.cc b/lite/tests/kernels/instance_norm_compute_test.cc new file mode 100644 index 0000000000..d2f7e964ee --- /dev/null +++ b/lite/tests/kernels/instance_norm_compute_test.cc @@ -0,0 +1,164 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +class InstanceNormComputeTest : public arena::TestCase { + protected: + // common attributes for this op. + std::string input_ = "x"; + std::string output_ = "y"; + std::string saved_mean_ = "saved_mean"; + std::string saved_variance_ = "saved_variance"; + std::string scale_ = "scale"; + std::string bias_ = "bias"; + + DDim dims_{{4, 5, 19, 19}}; + float epsilon_ = 1e-5f; + + public: + InstanceNormComputeTest(const Place& place, + const std::string& alias, + DDim dims, + float epsilon) + : TestCase(place, alias), dims_(dims), epsilon_(epsilon) {} + + void RunBaseline(Scope* scope) override { + auto x = scope->FindTensor(input_); + auto scale = scope->FindTensor(scale_); + auto bias = scope->FindTensor(bias_); + auto out = scope->NewTensor(output_); + auto saved_mean = scope->NewTensor(saved_mean_); + auto saved_variance = scope->NewTensor(saved_variance_); + CHECK(out); + CHECK(saved_mean); + CHECK(saved_variance); + DDim saved_dim({dims_[0] * dims_[1]}); + out->Resize(dims_); + saved_mean->Resize(saved_dim); + saved_variance->Resize(saved_dim); + + auto x_data = x->data(); + auto scale_data = scale->data(); + auto bias_data = bias->data(); + auto out_data = out->mutable_data(); + auto saved_mean_data = saved_mean->mutable_data(); + auto saved_variance_data = saved_variance->mutable_data(); + + int n = x->dims()[0]; + int c = x->dims()[1]; + int spatial_size = x->dims()[2] * x->dims()[3]; + + // compute mean + for (int i = 0; i < n * c; ++i) { + const float* x_ptr = x_data + i * spatial_size; + float sum = 0.f; + for (int j = 0; j < spatial_size; ++j) { + sum += x_ptr[j]; + } + saved_mean_data[i] = sum / spatial_size; + } + // compute variance + for (int i = 0; i < n * c; ++i) { + const float* x_ptr = x_data + i * spatial_size; + float sum = 0.f; + for (int j = 0; j < spatial_size; ++j) { + sum += + (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]); + } + saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_); + } + // compute out + for (int i = 0; i < n * c; ++i) { + const float* x_ptr = x_data + i * spatial_size; + float* out_ptr = out_data + i * spatial_size; + float scale_val = scale_data[i % c]; + float bias_val = bias_data[i % c]; + for (int j = 0; j < spatial_size; ++j) { + out_ptr[j] = scale_val * (x_ptr[j] - saved_mean_data[i]) * + saved_variance_data[i] + + bias_val; + } + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("instance_norm"); + op_desc->SetInput("X", {input_}); + op_desc->SetInput("Bias", {bias_}); + op_desc->SetInput("Scale", {scale_}); + op_desc->SetOutput("Y", {output_}); + op_desc->SetOutput("SavedMean", {saved_mean_}); + op_desc->SetOutput("SavedVariance", {saved_variance_}); + op_desc->SetAttr("epsilon", epsilon_); + } + + void PrepareData() override { + std::vector din(dims_.production()); + fill_data_rand(din.data(), -1.f, 1.f, dims_.production()); + + DDim scale_dim{{dims_[1]}}; + std::vector scale(scale_dim.production()); + fill_data_rand(scale.data(), -1.f, 1.f, scale_dim.production()); + + std::vector bias(scale_dim.production()); + fill_data_rand(bias.data(), -1.f, 1.f, scale_dim.production()); + + SetCommonTensor(input_, dims_, din.data()); + SetCommonTensor(scale_, scale_dim, scale.data()); + SetCommonTensor(bias_, scale_dim, bias.data()); + } +}; + +void test_instance_norm(Place place) { + for (auto& n : {1, 3, 16}) { + for (auto& c : {1, 4, 16}) { + for (auto& h : {1, 16, 33, 56}) { + for (auto& w : {1, 17, 34, 55}) { + DDim dim_in({n, c, h, w}); + float epsilon = 1e-5f; + std::unique_ptr tester( + new InstanceNormComputeTest(place, "def", dim_in, epsilon)); +#ifdef LITE_WITH_ARM + auto& ctx = tester->context()->As(); + ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 4); +#endif + arena::Arena arena(std::move(tester), place, 6e-5); + if (!arena.TestPrecision()) { + LOG(ERROR) << "run n: " << n << ", c: " << c << ", h: " << h + << ", w: " << w; + return; + } + } + } + } + } +} + +TEST(InstanceNorm, precision) { +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_instance_norm(place); +#endif +} + +} // namespace lite +} // namespace paddle -- GitLab