diff --git a/lite/backends/arm/math/gemm_prepacked_int8.cc b/lite/backends/arm/math/gemm_prepacked_int8.cc
index 343e93439d2db563e5ccd4d8c6aed681601871a0..f0c7c65c9067dabb46ad43b3a20a1b85d86d62d0 100644
--- a/lite/backends/arm/math/gemm_prepacked_int8.cc
+++ b/lite/backends/arm/math/gemm_prepacked_int8.cc
@@ -2242,19 +2242,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
       Dtype* tmp1 = nullptr;
       Dtype* tmp2 = nullptr;
       Dtype* tmp3 = nullptr;
-      float32_t scale_local[4];
+      float32_t scale_local[4] = {0, 0, 0, 0};  // entries past M stay 0
       float32_t bias_local[4] = {0, 0, 0, 0};
       if (is_bias) {
-        bias_local[0] = bias[y];
-        bias_local[1] = bias[y + 1];
-        bias_local[2] = bias[y + 2];
-        bias_local[3] = bias[y + 3];
+        if (y + 4 <= M) {  // indices y..y+3 are all below M
+          bias_local[0] = bias[y];
+          bias_local[1] = bias[y + 1];
+          bias_local[2] = bias[y + 2];
+          bias_local[3] = bias[y + 3];
+        } else {  // fewer than four entries remain: copy only M - y of them
+          switch (M - y) {
+            case 3:
+              bias_local[2] = bias[y + 2];  // intentional fall-through
+            case 2:
+              bias_local[1] = bias[y + 1];  // intentional fall-through
+            case 1:
+              bias_local[0] = bias[y + 0];  // intentional fall-through
+            default:
+              break;
+          }
+        }
       }
       if (scale) {
-        scale_local[0] = scale[y];
-        scale_local[1] = scale[y + 1];
-        scale_local[2] = scale[y + 2];
-        scale_local[3] = scale[y + 3];
+        if (y + 4 <= M) {  // indices y..y+3 are all below M
+          scale_local[0] = scale[y];
+          scale_local[1] = scale[y + 1];
+          scale_local[2] = scale[y + 2];
+          scale_local[3] = scale[y + 3];
+        } else {  // fewer than four entries remain: copy only M - y of them
+          switch (M - y) {
+            case 3:
+              scale_local[2] = scale[y + 2];  // intentional fall-through
+            case 2:
+              scale_local[1] = scale[y + 1];  // intentional fall-through
+            case 1:
+              scale_local[0] = scale[y + 0];  // intentional fall-through
+            default:
+              break;
+          }
+        }
       }
       if (y + MBLOCK_INT8_OTH > M) {
         switch (y + MBLOCK_INT8_OTH - M) {
diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index 0ab86b0f0471efe2017df32d6bea1daae8c589b9..218ee3f053fcf49f6a08ffbe0d780509f9b2cc03 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -53,7 +53,7 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_k
 add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
-
+add_kernel(group_norm_compute ARM extra SRCS group_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)  # NOTE(review): neighbouring targets end in _arm - confirm this name is intended
 ## 3. extra kernels
 add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(decode_bboxes_compute_arm ARM extra SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm)
diff --git a/lite/kernels/arm/deformable_conv_compute.cc b/lite/kernels/arm/deformable_conv_compute.cc
index 6253b661d05535d7b3b4a2ee18de7707e80b2877..dfdd27799bc1df7f403f40cb50b48aebbfb8d67a 100644
--- a/lite/kernels/arm/deformable_conv_compute.cc
+++ b/lite/kernels/arm/deformable_conv_compute.cc
@@ -235,7 +235,8 @@ typedef paddle::lite::kernels::arm::DeformableConvCompute
     DeformableConvFp32;
 
-REGISTER_LITE_KERNEL(deformconv2d, kARM, kFloat, kNCHW, DeformableConvFp32, def)
+REGISTER_LITE_KERNEL(
+    deformable_conv, kARM, kFloat, kNCHW, DeformableConvFp32, def)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
diff --git a/lite/kernels/arm/group_norm_compute.cc b/lite/kernels/arm/group_norm_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2e370414f4079f8dbbc2e5cc9af294c7b3f88718
--- /dev/null
+++ b/lite/kernels/arm/group_norm_compute.cc
@@ -0,0 +1,180 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/group_norm_compute.h"
+#include "lite/backends/arm/math/funcs.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/type_system.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void GroupNormCompute::PrepareForRun() {}
+
+// Per-(sample, group) normalization of X, then per-channel scale and bias.
+void GroupNormCompute::Run() {
+  auto& param = this->Param();
+  const float* in = param.x->data();
+  const float* scale = param.scale->data();
+  const float* bias = param.bias->data();
+  float* out = param.out->mutable_data();
+  float* saved_mean = param.saved_mean->mutable_data();
+  float* saved_variance = param.saved_variance->mutable_data();
+  float epsilon = param.epsilon;
+  int groups = param.groups;
+  int channels = param.channels;
+  auto x_dims = param.x->dims();
+  int n = x_dims[0];
+  int ch_per_group = channels / groups;
+  int in_size = 1;  // per-channel element count: product of all dims after C
+  for (int d = 2; d < static_cast<int>(x_dims.size()); ++d) {
+    in_size *= x_dims[d];
+  }
+  int spatial_size = ch_per_group * in_size;
+  int ngroup = n * groups;
+  int cnt = spatial_size >> 4;
+  int remain = spatial_size % 16;
+// pass 1: mean and reciprocal standard deviation of every (sample, group)
+#pragma omp parallel for
+  for (int g = 0; g < ngroup; ++g) {
+    const float* in_p = in + g * spatial_size;
+    float32x4_t sum0 = vdupq_n_f32(0.f);
+    float32x4_t sum1 = vdupq_n_f32(0.f);
+    float32x4_t sum2 = vdupq_n_f32(0.f);
+    float32x4_t sum3 = vdupq_n_f32(0.f);
+    float32x4_t summ0 = vdupq_n_f32(0.f);
+    float32x4_t summ1 = vdupq_n_f32(0.f);
+    float32x4_t summ2 = vdupq_n_f32(0.f);
+    float32x4_t summ3 = vdupq_n_f32(0.f);
+    for (int i = 0; i < cnt; i++) {  // 16 floats per iteration
+      float32x4_t in0 = vld1q_f32(in_p);
+      float32x4_t in1 = vld1q_f32(in_p + 4);
+      float32x4_t in2 = vld1q_f32(in_p + 8);
+      float32x4_t in3 = vld1q_f32(in_p + 12);
+      sum0 = vaddq_f32(sum0, in0);
+      summ0 = vmlaq_f32(summ0, in0, in0);
+      sum1 = vaddq_f32(sum1, in1);
+      summ1 = vmlaq_f32(summ1, in1, in1);
+      sum2 = vaddq_f32(sum2, in2);
+      summ2 = vmlaq_f32(summ2, in2, in2);
+      sum3 = vaddq_f32(sum3, in3);
+      summ3 = vmlaq_f32(summ3, in3, in3);
+      in_p += 16;
+    }
+    for (int i = 0; i < remain - 3; i += 4) {  // leftover full 4-lane vectors
+      float32x4_t in0 = vld1q_f32(in_p);
+      sum1 = vaddq_f32(sum1, in0);
+      summ1 = vmlaq_f32(summ1, in0, in0);
+      in_p += 4;
+    }
+    float sum = 0.0;
+    float summ = 0.0;
+    sum0 = vaddq_f32(sum0, sum1);
+    sum2 = vaddq_f32(sum2, sum3);
+    summ0 = vaddq_f32(summ0, summ1);
+    summ2 = vaddq_f32(summ2, summ3);
+    for (int i = 0; i < remain % 4; i++) {  // scalar tail (< 4 elements)
+      sum += *in_p;
+      summ += (*in_p) * (*in_p);
+      in_p++;
+    }
+    sum0 = vaddq_f32(sum0, sum2);
+    summ0 = vaddq_f32(summ0, summ2);
+    float32x2_t sum_low = vpadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
+    float32x2_t sum_high = vpadd_f32(vget_low_f32(summ0), vget_high_f32(summ0));
+    float32x2_t sum_mix = vpadd_f32(sum_low, sum_high);
+    sum += vget_lane_f32(sum_mix, 0);   // lane 0: sum of x
+    summ += vget_lane_f32(sum_mix, 1);  // lane 1: sum of x * x
+    float mean = sum / spatial_size;
+    // Single-pass variance as (summ - N * mean^2) / N. Rounding can leave it
+    // slightly negative for near-constant input, so clamp before the sqrt.
+    float variance = (summ - mean * mean * spatial_size) / spatial_size;
+    if (variance < 0.f) variance = 0.f;
+    float rstd = 1.f / sqrtf(variance + epsilon);
+    saved_mean[g] = mean;
+    saved_variance[g] = rstd;  // holds 1 / sqrt(var + eps), not the variance
+  }
+  const int ch_cnt = in_size >> 4;
+  const int ch_remain = in_size % 16;
+// pass 2: out = scale[ch] * (in - mean) * rstd + bias[ch], channel by channel
+#pragma omp parallel for
+  for (int i = 0; i < ngroup; ++i) {
+    const float* in_p = in + i * spatial_size;
+    float* out_p = out + i * spatial_size;
+    const int first_ch = (i % groups) * ch_per_group;  // group's first channel
+    for (int c = 0; c < ch_per_group; c++) {
+      int chin = first_ch + c;
+      const float sstd_val = scale[chin] * saved_variance[i];
+      const float bias_val = bias[chin];
+      const float mean_val = saved_mean[i];
+      const float32x4_t vsstd = vdupq_n_f32(sstd_val);
+      const float32x4_t vbias = vdupq_n_f32(bias_val);
+      const float32x4_t vmean = vdupq_n_f32(mean_val);
+      for (int k = 0; k < ch_cnt; k++) {  // 16 floats per iteration
+        float32x4_t in0 = vld1q_f32(in_p);
+        float32x4_t in1 = vld1q_f32(in_p + 4);
+        float32x4_t in2 = vld1q_f32(in_p + 8);
+        float32x4_t in3 = vld1q_f32(in_p + 12);
+        float32x4_t submean0 = vsubq_f32(in0, vmean);
+        float32x4_t submean1 = vsubq_f32(in1, vmean);
+        float32x4_t submean2 = vsubq_f32(in2, vmean);
+        float32x4_t submean3 = vsubq_f32(in3, vmean);
+        float32x4_t out0 = vmlaq_f32(vbias, submean0, vsstd);
+        float32x4_t out1 = vmlaq_f32(vbias, submean1, vsstd);
+        float32x4_t out2 = vmlaq_f32(vbias, submean2, vsstd);
+        float32x4_t out3 = vmlaq_f32(vbias, submean3, vsstd);
+        vst1q_f32(out_p, out0);
+        vst1q_f32(out_p + 4, out1);
+        vst1q_f32(out_p + 8, out2);
+        vst1q_f32(out_p + 12, out3);
+        in_p += 16;
+        out_p += 16;
+      }
+      for (int k = 0; k < ch_remain - 3; k += 4) {  // leftover 4-lane vectors
+        float32x4_t in0 = vld1q_f32(in_p);
+        in_p += 4;
+        float32x4_t submean0 = vsubq_f32(in0, vmean);
+        float32x4_t out0 = vmlaq_f32(vbias, submean0, vsstd);
+        vst1q_f32(out_p, out0);
+        out_p += 4;
+      }
+      for (int k = 0; k < ch_remain % 4; k++) {  // scalar tail (< 4 elements)
+        *out_p = (*in_p - mean_val) * sstd_val + bias_val;
+        in_p++;
+        out_p++;
+      }
+    }
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(group_norm,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::GroupNormCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/group_norm_compute.h b/lite/kernels/arm/group_norm_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..7d61b8ec8d9a1c8620c54858487b21691bef84d5
--- /dev/null
+++ b/lite/kernels/arm/group_norm_compute.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class GroupNormCompute : public KernelLite {
+ public:
+  using param_t = operators::GroupNormParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+  virtual ~GroupNormCompute() = default;
+
+ private:
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index b62795ab4ae48f8f03d24995600ad3aee720b1ca..192cffccb19040a5ab77feae4d8b6a5a5fe4ba00 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -57,6 +57,7 @@ add_operator(negative_op extra SRCS negative_op.cc DEPS ${op_DEPS})
 add_operator(crop_op extra SRCS crop_op.cc DEPS ${op_DEPS})
 add_operator(assign_op extra SRCS assign_op.cc DEPS ${op_DEPS})
 add_operator(power_op extra SRCS power_op.cc DEPS ${op_DEPS})
+add_operator(group_norm_op extra SRCS group_norm_op.cc DEPS ${op_DEPS})
 add_operator(norm_op extra SRCS norm_op.cc DEPS ${op_DEPS})
 
 # 3.extra ops
@@ -143,7 +144,7 @@ add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_o
 add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
 add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
 # for deformable-convNet
-add_operator(deformable_conv_op basic SRCS deformable_conv_op.cc DEPS ${op_DEPS})
+add_operator(deformable_conv_op extra SRCS deformable_conv_op.cc DEPS ${op_DEPS})
 
 # 4. training op
 add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/deformable_conv_op.cc b/lite/operators/deformable_conv_op.cc
index 8cc8614d00801fb033bc3f449e82f9f03e271db5..a834528f27c9d6c97e355a1a149482ad00ae79aa 100644
--- a/lite/operators/deformable_conv_op.cc
+++ b/lite/operators/deformable_conv_op.cc
@@ -84,5 +84,5 @@ bool DeformableConvOpLite::InferShapeImpl() const {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_OP(DeformableConv2d,
+REGISTER_LITE_OP(deformable_conv,
                  paddle::lite::operators::DeformableConvOpLite);
diff --git a/lite/operators/group_norm_op.cc b/lite/operators/group_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e1a6413ebb140bac4a1d7e74ef42413f489395c7
--- /dev/null
+++ b/lite/operators/group_norm_op.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/group_norm_op.h"
+#include
+#include
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+// Checks that all tensors are bound and that ranks and attributes are sane.
+bool GroupNormOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.scale);
+  CHECK_OR_FALSE(param_.bias);
+  CHECK_OR_FALSE(param_.out);
+  CHECK_OR_FALSE(param_.saved_mean);
+  CHECK_OR_FALSE(param_.saved_variance);
+  auto x_dims = param_.x->dims();
+  auto scale_dims = param_.scale->dims();
+  auto bias_dims = param_.bias->dims();
+  CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
+      << "Input X must have 2 to 5 dimensions.";
+  CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions.";
+  CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions.";
+  CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
+  CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
+  CHECK_EQ(param_.channels, x_dims[1])
+      << "Input channels must be equal input_shape[1]";
+  CHECK(param_.groups > 0 && param_.channels % param_.groups == 0)
+      << "channels must be divide groups";  // groups > 0 guards the modulo
+  return true;
+}
+// Y takes X's shape; SavedMean / SavedVariance hold one value per group.
+bool GroupNormOp::InferShapeImpl() const {
+  auto x_dims = param_.x->dims();
+  int64_t batch_size = x_dims[0];
+  // The ARM kernel writes batch_size * groups statistics, not channels / groups.
+  param_.saved_mean->Resize({batch_size * param_.groups});
+  param_.saved_variance->Resize({batch_size * param_.groups});
+  param_.out->Resize(x_dims);
+  return true;
+}
+// Binds X/Scale/Bias, Y/SavedMean/SavedVariance and the three attributes.
+bool GroupNormOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
+  param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable();
+  param_.scale =
+      scope->FindVar(op_desc.Input("Scale").front())->GetMutable();
+  param_.bias =
+      scope->FindVar(op_desc.Input("Bias").front())->GetMutable();
+  param_.saved_mean =
+      scope->FindVar(op_desc.Output("SavedMean").front())->GetMutable();
+  param_.saved_variance =
+      scope->FindVar(op_desc.Output("SavedVariance").front())
+          ->GetMutable();
+  param_.out =
+      scope->FindVar(op_desc.Output("Y").front())->GetMutable();
+  param_.epsilon = op_desc.GetAttr("epsilon");
+  param_.groups = op_desc.GetAttr("groups");
+  param_.channels = op_desc.GetAttr("channels");
+  return true;
+}
+
+} /* namespace operators */
+} /* namespace lite */
+} /* namespace paddle */
+
+REGISTER_LITE_OP(group_norm, paddle::lite::operators::GroupNormOp);
diff --git a/lite/operators/group_norm_op.h b/lite/operators/group_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2251686ea2caa89e3934e8adae69466f9c9515d
--- /dev/null
+++ b/lite/operators/group_norm_op.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+// group_norm operator: shape checking, shape inference and parameter binding.
+class GroupNormOp : public OpLite {
+ public:
+  GroupNormOp() {}
+
+  explicit GroupNormOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "group_norm"; }
+
+#ifdef LITE_WITH_PROFILE
+  void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
+    ch->input_shape = ch->DimToStr(param_.x->dims());
+    ch->output_shape = ch->DimToStr(param_.out->dims());
+    // hw comes from the element count so every rank CheckShape accepts works
+    auto x_dims = param_.x->dims();
+    auto nc = x_dims[0] * x_dims[1];
+    auto nchw = x_dims.production();
+    auto hw = nc > 0 ? nchw / nc : 0;
+    ch->macs = 5.f * nchw + 3.f * (nc + hw);
+  }
+#endif
+
+ private:
+  mutable GroupNormParam param_;
+};
+
+} /* namespace operators */
+} /* namespace lite */
+} /* namespace paddle */
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 58394ebdf9f93c11e5d60482e6e5a23449a2ad90..5594fc1590a241f5b120d12f7beee22368e3b958 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1429,6 +1429,19 @@ struct InstanceNormParam : ParamBase {
   lite::Tensor* saved_variance{};
   float epsilon;
 };
+/// --------------------- group_norm operators --------------------
+struct GroupNormParam : ParamBase {
+  lite::Tensor* x{};
+  lite::Tensor* out{};
+  lite::Tensor* bias{};
+  lite::Tensor* scale{};
+  lite::Tensor* saved_mean{};
+  lite::Tensor* saved_variance{};
+  float epsilon{1e-5f};  // defaults keep an unattached param well-defined
+  int groups{1};
+  int channels{0};
+};
+
 /// --------------------- grid sampler operators --------------------
 struct GridSamplerParam : ParamBase {
   lite::Tensor* x{};
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 03f0de291e80d821af5704727dbd30b10d2ca453..d29f88f334754720b4681042ac5693723e028ba1 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -17,6 +17,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
     lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_group_norm_compute SRCS group_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
diff --git a/lite/tests/kernels/group_norm_compute_test.cc b/lite/tests/kernels/group_norm_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a1df003850731eb4d355d01f65100d2b9d200224
--- /dev/null
+++ b/lite/tests/kernels/group_norm_compute_test.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
+
+namespace paddle {
+namespace lite {
+
+class GroupNormComputeTest : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string x_ = "x";
+  std::string y_ = "y";
+  std::string saved_mean_ = "saved_mean";
+  std::string saved_variance_ = "saved_variance";
+  std::string scale_ = "scale";
+  std::string bias_ = "bias";
+
+  DDim dims_{{4, 5, 19, 19}};
+  float epsilon_ = 1e-5f;
+  int groups_ = 1;
+  int channels_ = dims_[1];
+
+ public:
+  GroupNormComputeTest(const Place& place,
+                       const std::string& alias,
+                       DDim dims,
+                       float epsilon,
+                       int groups,
+                       int channels)
+      : TestCase(place, alias),
+        dims_(dims),
+        epsilon_(epsilon),
+        groups_(groups),
+        channels_(channels) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto x = scope->FindTensor(x_);
+    auto scale = scope->FindTensor(scale_);
+    auto bias = scope->FindTensor(bias_);
+    auto y = scope->NewTensor(y_);
+    auto saved_mean = scope->NewTensor(saved_mean_);
+    auto saved_variance = scope->NewTensor(saved_variance_);
+    CHECK(y);
+    CHECK(saved_mean);
+    CHECK(saved_variance);
+    DDim saved_dim({dims_[0] * groups_});  // one statistic per (sample, group)
+    y->Resize(dims_);
+    saved_mean->Resize(saved_dim);
+    saved_variance->Resize(saved_dim);
+
+    auto x_data = x->data();
+    auto scale_data = scale->data();
+    auto bias_data = bias->data();
+    auto y_data = y->mutable_data();
+    auto saved_mean_data = saved_mean->mutable_data();
+    auto saved_variance_data = saved_variance->mutable_data();
+
+    int n = x->dims()[0];
+    int ch_per_group = channels_ / groups_;
+    CHECK_EQ(x->dims()[1], channels_);
+    int spatial_size = ch_per_group * x->dims()[2] * x->dims()[3];
+    // compute mean
+    for (int i = 0; i < n * groups_; ++i) {
+      const float* x_ptr = x_data + i * spatial_size;
+      float sum = 0.f;
+      for (int j = 0; j < spatial_size; ++j) {
+        sum += x_ptr[j];
+      }
+      saved_mean_data[i] = sum / spatial_size;
+    }
+    // compute variance
+    for (int i = 0; i < n * groups_; ++i) {
+      const float* x_ptr = x_data + i * spatial_size;
+      float sum = 0.f;
+      for (int j = 0; j < spatial_size; ++j) {
+        sum +=
+            (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
+      }
+      saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
+    }
+    int in_size = x->dims()[2] * x->dims()[3];
+    // compute out
+    for (int i = 0; i < n * groups_; ++i) {
+      const float* x_ptr = x_data + i * spatial_size;
+      float* y_ptr = y_data + i * spatial_size;
+      int c_num = i % groups_;
+      for (int c = 0; c < ch_per_group; c++) {
+        int chin = c_num * ch_per_group + c;
+        float scale_val = scale_data[chin];
+        float bias_val = bias_data[chin];
+        const float* x_ch_ptr = x_ptr + c * in_size;
+        float* y_ch_ptr = y_ptr + c * in_size;
+        for (int j = 0; j < in_size; j++) {
+          y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
+                            saved_variance_data[i] +
+                        bias_val;
+        }
+      }
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("group_norm");
+    op_desc->SetInput("X", {x_});
+    op_desc->SetInput("Bias", {bias_});
+    op_desc->SetInput("Scale", {scale_});
+    op_desc->SetOutput("Y", {y_});
+    op_desc->SetOutput("SavedMean", {saved_mean_});
+    op_desc->SetOutput("SavedVariance", {saved_variance_});
+    op_desc->SetAttr("epsilon", epsilon_);
+    op_desc->SetAttr("groups", groups_);
+    op_desc->SetAttr("channels", channels_);
+  }
+
+  void PrepareData() override {
+    std::vector x(dims_.production());
+    fill_data_rand(x.data(), -1.f, 1.f, dims_.production());
+
+    DDim scale_bias_dims{{dims_[1]}};
+    std::vector scale(scale_bias_dims.production());
+    fill_data_rand(scale.data(), -1.f, 1.f, scale_bias_dims.production());
+    std::vector bias(scale_bias_dims.production());
+    fill_data_rand(bias.data(), -1.f, 1.f, scale_bias_dims.production());
+
+    SetCommonTensor(x_, dims_, x.data());
+    SetCommonTensor(scale_, scale_bias_dims, scale.data(), {}, true);
+    SetCommonTensor(bias_, scale_bias_dims, bias.data(), {}, true);
+  }
+};
+
+void TestGroupNorm(Place place,
+                   float abs_error = 6e-5,
+                   std::vector ignored_outs = {}) {
+  for (auto& n : {1, 3, 16}) {
+    for (auto& c : {1}) {  // NOTE(review): c == 1 skips groups 2 and 4 below
+      for (auto& h : {1, 16, 33, 56}) {
+        for (auto& w : {1, 17, 55}) {
+          for (auto& groups : {1, 2, 4}) {
+            if (c % groups != 0) {
+              continue;
+            }
+            DDim dim_in({n, c, h, w});
+            float epsilon = 1e-5f;
+            std::unique_ptr tester(new GroupNormComputeTest(
+                place, "def", dim_in, epsilon, groups, c));
+#ifdef LITE_WITH_ARM
+            if (place == TARGET(kARM)) {
+              auto& ctx = tester->context()->As();
+              ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 4);
+            }
+#endif
+            arena::Arena arena(std::move(tester), place, abs_error);
+            if (!arena.TestPrecision(ignored_outs)) {
+              LOG(ERROR) << "run n: " << n << ", c: " << c << ", h: " << h
+                         << ", w: " << w << ", groups: " << groups;
+              return;
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(GroupNorm, precision) {
+  Place place;
+  float abs_error = 6e-5;
+  std::vector ignored_outs = {};
+#ifdef LITE_WITH_ARM
+  place = TARGET(kARM);
+#else
+  return;
+#endif
+  TestGroupNorm(place, abs_error, ignored_outs);
+}
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/math/deformable_conv_compute_test.cc b/lite/tests/math/deformable_conv_compute_test.cc
index d7a06c6a104ac3ac6db5d79aced6183e8bdf5963..76cb970ffe428ed393cdbdae0d281e6a511655ac 100644
--- a/lite/tests/math/deformable_conv_compute_test.cc
+++ b/lite/tests/math/deformable_conv_compute_test.cc
@@ -342,7 +342,7 @@ TEST(TestDeformableConvRand, test_deformable_conv_rand) {
   if (FLAGS_basic_test) {
     for (auto& cin : {1, 3, 8}) {
       for (auto& cout : {1, 5, 16}) {
-        for (auto& g : {1, 2}) {
+        for (auto& g : {1}) {
           for (auto& kw : {1, 2, 3}) {
             for (auto& kh : {1, 2, 3}) {
               for (auto& stride : {1, 2}) {