未验证 提交 c0a3b120 编写于 作者: H HappyAngel 提交者: GitHub

[Arm]Add Group_norm OP (#3781)


* add grouup_norm

* fix format. test=develop

* fix xiaodu crash. test=develop
上级 29771f27
......@@ -2242,19 +2242,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
Dtype* tmp1 = nullptr;
Dtype* tmp2 = nullptr;
Dtype* tmp3 = nullptr;
float32_t scale_local[4];
float32_t scale_local[4] = {0, 0, 0, 0};
float32_t bias_local[4] = {0, 0, 0, 0};
if (is_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
if (y + 4 <= M) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
} else {
switch (M - y) {
case 3:
bias_local[2] = bias[y + 2];
case 2:
bias_local[1] = bias[y + 1];
case 1:
bias_local[0] = bias[y + 0];
default:
break;
}
}
}
if (scale) {
scale_local[0] = scale[y];
scale_local[1] = scale[y + 1];
scale_local[2] = scale[y + 2];
scale_local[3] = scale[y + 3];
if (y + 4 <= M) {
scale_local[0] = scale[y];
scale_local[1] = scale[y + 1];
scale_local[2] = scale[y + 2];
scale_local[3] = scale[y + 3];
} else {
switch (M - y) {
case 3:
scale_local[2] = scale[y + 2];
case 2:
scale_local[1] = scale[y + 1];
case 1:
scale_local[0] = scale[y + 0];
default:
break;
}
}
}
if (y + MBLOCK_INT8_OTH > M) {
switch (y + MBLOCK_INT8_OTH - M) {
......
......@@ -53,7 +53,7 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_k
add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(group_norm_compute ARM extra SRCS group_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
## 3. extra kernels
add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(decode_bboxes_compute_arm ARM extra SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -235,7 +235,8 @@ typedef paddle::lite::kernels::arm::DeformableConvCompute<PRECISION(kFloat),
PRECISION(kFloat)>
DeformableConvFp32;
REGISTER_LITE_KERNEL(deformconv2d, kARM, kFloat, kNCHW, DeformableConvFp32, def)
REGISTER_LITE_KERNEL(
deformable_conv, kARM, kFloat, kNCHW, DeformableConvFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/group_norm_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void GroupNormCompute::PrepareForRun() {}
void GroupNormCompute::Run() {
auto& param = this->Param<param_t>();
const float* in = param.x->data<float>();
const float* scale = param.scale->data<float>();
const float* bias = param.bias->data<float>();
float* out = param.out->mutable_data<float>();
float* saved_mean = param.saved_mean->mutable_data<float>();
float* saved_variance = param.saved_variance->mutable_data<float>();
float epsilon = param.epsilon;
int groups = param.groups;
int channels = param.channels;
int n = param.x->dims()[0];
int c = param.x->dims()[1];
int ch_per_group = channels / groups;
int height = param.x->dims()[2];
int width = param.x->dims()[3];
int spatial_size = ch_per_group * height * width;
int ngroup = n * groups;
int cnt = spatial_size >> 4;
int remain = spatial_size % 16;
// compute saved_mean and saved_variance
#pragma omp parallel for
for (int n = 0; n < ngroup; ++n) {
const float* in_p = in + n * spatial_size;
float sum_spatial = 0.f;
float summ_spatial = 0.f;
float32x4_t sum0 = vdupq_n_f32(0.f);
float32x4_t sum1 = vdupq_n_f32(0.f);
float32x4_t sum2 = vdupq_n_f32(0.f);
float32x4_t sum3 = vdupq_n_f32(0.f);
float32x4_t summ0 = vdupq_n_f32(0.f);
float32x4_t summ1 = vdupq_n_f32(0.f);
float32x4_t summ2 = vdupq_n_f32(0.f);
float32x4_t summ3 = vdupq_n_f32(0.f);
for (int i = 0; i < cnt; i++) {
float32x4_t in0 = vld1q_f32(in_p);
float32x4_t in1 = vld1q_f32(in_p + 4);
float32x4_t in2 = vld1q_f32(in_p + 8);
float32x4_t in3 = vld1q_f32(in_p + 12);
sum0 = vaddq_f32(sum0, in0);
summ0 = vmlaq_f32(summ0, in0, in0);
sum1 = vaddq_f32(sum1, in1);
summ1 = vmlaq_f32(summ1, in1, in1);
sum2 = vaddq_f32(sum2, in2);
summ2 = vmlaq_f32(summ2, in2, in2);
sum3 = vaddq_f32(sum3, in3);
summ3 = vmlaq_f32(summ3, in3, in3);
in_p += 16;
}
for (int i = 0; i < remain - 3; i += 4) {
float32x4_t in0 = vld1q_f32(in_p);
sum1 = vaddq_f32(sum1, in0);
summ1 = vmlaq_f32(summ1, in0, in0);
in_p += 4;
}
float sum = 0.0;
float summ = 0.0;
sum0 = vaddq_f32(sum0, sum1);
sum2 = vaddq_f32(sum2, sum3);
summ0 = vaddq_f32(summ0, summ1);
summ2 = vaddq_f32(summ2, summ3);
for (int i = 0; i < remain % 4; i++) {
sum += *in_p;
summ += (*in_p) * (*in_p);
in_p++;
}
sum0 = vaddq_f32(sum0, sum2);
summ0 = vaddq_f32(summ0, summ2);
float32x2_t sum_low = vpadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
float32x2_t sum_high = vpadd_f32(vget_low_f32(summ0), vget_high_f32(summ0));
float32x2_t sum_mix = vpadd_f32(sum_low, sum_high);
sum += vget_lane_f32(sum_mix, 0);
summ += vget_lane_f32(sum_mix, 1);
float mean = sum / spatial_size;
// float variance = summ / spatial_size - mean * mean;
// the flolowing code has higher precision than above comment code
float variance = (summ - mean * mean * spatial_size) / spatial_size;
float std = 1.f / sqrtf(variance + epsilon);
saved_mean[n] = mean;
saved_variance[n] = std;
}
int in_size = height * width;
cnt = in_size >> 4;
remain = in_size % 16;
// compute Group_norm result: out = scale * (in - mean) / std + bias
#pragma omp parallel for
for (int i = 0; i < ngroup; ++i) {
const float* in_p = in + i * spatial_size;
float* out_p = out + i * spatial_size;
int numc = i % groups;
numc *= ch_per_group;
for (int c = 0; c < ch_per_group; c++) {
int chin = numc + c;
const float sstd_val = scale[chin] * saved_variance[i];
const float bias_val = bias[chin];
const float mean_val = saved_mean[i];
const float32x4_t vsstd = vdupq_n_f32(sstd_val);
const float32x4_t vbias = vdupq_n_f32(bias_val);
const float32x4_t vmean = vdupq_n_f32(mean_val);
for (int k = 0; k < cnt; k++) {
float32x4_t in0 = vld1q_f32(in_p);
float32x4_t in1 = vld1q_f32(in_p + 4);
float32x4_t in2 = vld1q_f32(in_p + 8);
float32x4_t in3 = vld1q_f32(in_p + 12);
float32x4_t submean0 = vsubq_f32(in0, vmean);
float32x4_t submean1 = vsubq_f32(in1, vmean);
float32x4_t submean2 = vsubq_f32(in2, vmean);
float32x4_t submean3 = vsubq_f32(in3, vmean);
float32x4_t out0 = vmlaq_f32(vbias, submean0, vsstd);
float32x4_t out1 = vmlaq_f32(vbias, submean1, vsstd);
float32x4_t out2 = vmlaq_f32(vbias, submean2, vsstd);
float32x4_t out3 = vmlaq_f32(vbias, submean3, vsstd);
vst1q_f32(out_p, out0);
vst1q_f32(out_p + 4, out1);
vst1q_f32(out_p + 8, out2);
vst1q_f32(out_p + 12, out3);
in_p += 16;
out_p += 16;
}
for (int k = 0; k < remain - 3; k += 4) {
float32x4_t in0 = vld1q_f32(in_p);
in_p += 4;
float32x4_t submean0 = vsubq_f32(in0, vmean);
float32x4_t out0 = vmlaq_f32(vbias, submean0, vsstd);
vst1q_f32(out_p, out0);
out_p += 4;
}
for (int k = 0; k < remain % 4; k++) {
*out_p = (*in_p - mean_val) * sstd_val + bias_val;
in_p++;
out_p++;
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(group_norm,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::GroupNormCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GroupNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GroupNormParam;
void PrepareForRun() override;
void Run() override;
virtual ~GroupNormCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -57,6 +57,7 @@ add_operator(negative_op extra SRCS negative_op.cc DEPS ${op_DEPS})
add_operator(crop_op extra SRCS crop_op.cc DEPS ${op_DEPS})
add_operator(assign_op extra SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(power_op extra SRCS power_op.cc DEPS ${op_DEPS})
add_operator(group_norm_op extra SRCS group_norm_op.cc DEPS ${op_DEPS})
add_operator(norm_op extra SRCS norm_op.cc DEPS ${op_DEPS})
# 3.extra ops
......@@ -143,7 +144,7 @@ add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_o
add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
# for deformable-convNet
add_operator(deformable_conv_op basic SRCS deformable_conv_op.cc DEPS ${op_DEPS})
add_operator(deformable_conv_op extra SRCS deformable_conv_op.cc DEPS ${op_DEPS})
# 4. training op
add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
......
......@@ -84,5 +84,5 @@ bool DeformableConvOpLite::InferShapeImpl() const {
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(DeformableConv2d,
REGISTER_LITE_OP(deformable_conv,
paddle::lite::operators::DeformableConvOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/group_norm_op.h"
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
bool GroupNormOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.scale);
CHECK_OR_FALSE(param_.bias);
CHECK_OR_FALSE(param_.out);
CHECK_OR_FALSE(param_.saved_mean);
CHECK_OR_FALSE(param_.saved_variance);
auto x_dims = param_.x->dims();
auto scale_dims = param_.scale->dims();
auto bias_dims = param_.bias->dims();
CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
<< "Input X must have 2 to 5 dimensions.";
CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions.";
CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions.";
CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
CHECK_EQ(param_.channels, x_dims[1])
<< "Input channels must be equal input_shape[1]";
CHECK_EQ(param_.channels % param_.groups, 0)
<< "channels must be divide groups";
return true;
}
bool GroupNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0];
int64_t num = param_.channels / param_.groups;
param_.saved_mean->Resize({batch_size * num});
param_.saved_variance->Resize({batch_size * num});
param_.out->Resize(x_dims);
return true;
}
bool GroupNormOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.scale =
scope->FindVar(op_desc.Input("Scale").front())->GetMutable<Tensor>();
param_.bias =
scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
param_.saved_mean =
scope->FindVar(op_desc.Output("SavedMean").front())->GetMutable<Tensor>();
param_.saved_variance =
scope->FindVar(op_desc.Output("SavedVariance").front())
->GetMutable<Tensor>();
param_.out =
scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
param_.epsilon = op_desc.GetAttr<float>("epsilon");
param_.groups = op_desc.GetAttr<int>("groups");
param_.channels = op_desc.GetAttr<int>("channels");
return true;
}
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
REGISTER_LITE_OP(group_norm, paddle::lite::operators::GroupNormOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class GroupNormOp : public OpLite {
public:
GroupNormOp() {}
explicit GroupNormOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "group_norm"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
ch->input_shape = ch->DimToStr(param_.x->dims());
ch->output_shape = ch->DimToStr(param_.out->dims());
// ch->remark = "";
auto x_dims = param_.x->dims();
auto nc = x_dims[0] * x_dims[1];
auto hw = x_dims[2] * x_dims[3];
auto nchw = x_dims.production();
ch->macs = 5.f * nchw + 3.f * (nc + hw);
}
#endif
private:
mutable GroupNormParam param_;
};
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
......@@ -1429,6 +1429,19 @@ struct InstanceNormParam : ParamBase {
lite::Tensor* saved_variance{};
float epsilon;
};
/// --------------------- group_norm operators --------------------
struct GroupNormParam : ParamBase {
lite::Tensor* x{};
lite::Tensor* out{};
lite::Tensor* bias{};
lite::Tensor* scale{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_variance{};
float epsilon;
int groups;
int channels;
};
/// --------------------- grid sampler operators --------------------
struct GridSamplerParam : ParamBase {
lite::Tensor* x{};
......
......@@ -17,6 +17,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_group_norm_compute SRCS group_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class GroupNormComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "x";
std::string y_ = "y";
std::string saved_mean_ = "saved_mean";
std::string saved_variance_ = "saved_variance";
std::string scale_ = "scale";
std::string bias_ = "bias";
DDim dims_{{4, 5, 19, 19}};
float epsilon_ = 1e-5f;
int groups_ = 1;
int channels_ = dims_[1];
public:
GroupNormComputeTest(const Place& place,
const std::string& alias,
DDim dims,
float epsilon,
int groups,
int channels)
: TestCase(place, alias),
dims_(dims),
epsilon_(epsilon),
groups_(groups),
channels_(channels) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(x_);
auto scale = scope->FindTensor(scale_);
auto bias = scope->FindTensor(bias_);
auto y = scope->NewTensor(y_);
auto saved_mean = scope->NewTensor(saved_mean_);
auto saved_variance = scope->NewTensor(saved_variance_);
CHECK(y);
CHECK(saved_mean);
CHECK(saved_variance);
DDim saved_dim({dims_[0] * groups_});
y->Resize(dims_);
saved_mean->Resize(saved_dim);
saved_variance->Resize(saved_dim);
auto x_data = x->data<float>();
auto scale_data = scale->data<float>();
auto bias_data = bias->data<float>();
auto y_data = y->mutable_data<float>();
auto saved_mean_data = saved_mean->mutable_data<float>();
auto saved_variance_data = saved_variance->mutable_data<float>();
int n = x->dims()[0];
int ch_per_group = channels_ / groups_;
CHECK_EQ(x->dims()[1], channels_);
int spatial_size = ch_per_group * x->dims()[2] * x->dims()[3];
// compute mean
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum += x_ptr[j];
}
saved_mean_data[i] = sum / spatial_size;
}
// compute variance
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum +=
(x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
}
saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
}
int in_size = x->dims()[2] * x->dims()[3];
// compute out
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float* y_ptr = y_data + i * spatial_size;
int c_num = i % groups_;
for (int c = 0; c < ch_per_group; c++) {
int chin = c_num * ch_per_group + c;
float scale_val = scale_data[chin];
float bias_val = bias_data[chin];
const float* x_ch_ptr = x_ptr + c * in_size;
float* y_ch_ptr = y_ptr + c * in_size;
for (int j = 0; j < in_size; j++) {
y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
saved_variance_data[i] +
bias_val;
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("group_norm");
op_desc->SetInput("X", {x_});
op_desc->SetInput("Bias", {bias_});
op_desc->SetInput("Scale", {scale_});
op_desc->SetOutput("Y", {y_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
op_desc->SetAttr("epsilon", epsilon_);
op_desc->SetAttr("groups", groups_);
op_desc->SetAttr("channels", channels_);
}
void PrepareData() override {
std::vector<float> x(dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, dims_.production());
DDim scale_bias_dims{{dims_[1]}};
std::vector<float> scale(scale_bias_dims.production());
fill_data_rand(scale.data(), -1.f, 1.f, scale_bias_dims.production());
std::vector<float> bias(scale_bias_dims.production());
fill_data_rand(bias.data(), -1.f, 1.f, scale_bias_dims.production());
SetCommonTensor(x_, dims_, x.data());
SetCommonTensor(scale_, scale_bias_dims, scale.data(), {}, true);
SetCommonTensor(bias_, scale_bias_dims, bias.data(), {}, true);
}
};
void TestGroupNorm(Place place,
float abs_error = 6e-5,
std::vector<std::string> ignored_outs = {}) {
for (auto& n : {1, 3, 16}) {
for (auto& c : {1}) {
for (auto& h : {1, 16, 33, 56}) {
for (auto& w : {1, 17, 55}) {
for (auto& groups : {1, 2, 4}) {
if (c % groups != 0) {
continue;
}
DDim dim_in({n, c, h, w});
float epsilon = 1e-5f;
std::unique_ptr<arena::TestCase> tester(new GroupNormComputeTest(
place, "def", dim_in, epsilon, groups, c));
#ifdef LITE_WITH_ARM
if (place == TARGET(kARM)) {
auto& ctx = tester->context()->As<ARMContext>();
ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 4);
}
#endif
arena::Arena arena(std::move(tester), place, abs_error);
if (!arena.TestPrecision(ignored_outs)) {
LOG(ERROR) << "run n: " << n << ", c: " << c << ", h: " << h
<< ", w: " << w;
return;
}
}
}
}
}
}
}
TEST(GroupNorm, precision) {
Place place;
float abs_error = 6e-5;
std::vector<std::string> ignored_outs = {};
#ifdef LITE_WITH_ARM
place = TARGET(kARM);
#else
return;
#endif
TestGroupNorm(place, abs_error, ignored_outs);
}
} // namespace lite
} // namespace paddle
......@@ -342,7 +342,7 @@ TEST(TestDeformableConvRand, test_deformable_conv_rand) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8}) {
for (auto& cout : {1, 5, 16}) {
for (auto& g : {1, 2}) {
for (auto& g : {1}) {
for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册