未验证 提交 9a3552db 编写于 作者: Y yiicy 提交者: GitHub

[ARM] add instance norm op and ut, test=develop (#2578)

上级 c64036ca
...@@ -477,17 +477,23 @@ void nearest_interp(const float* src, ...@@ -477,17 +477,23 @@ void nearest_interp(const float* src,
float scale_h_new = (with_align) float scale_h_new = (with_align)
? (static_cast<float>(h_in - 1) / (h_out - 1)) ? (static_cast<float>(h_in - 1) / (h_out - 1))
: (static_cast<float>(h_in) / (h_out)); : (static_cast<float>(h_in) / (h_out));
if (with_align) {
#pragma omp parallel for collapse(2) schedule(static) for (int h = 0; h < h_out; ++h) {
for (int h = 0; h < h_out; ++h) { float* dst_p = dst + h * w_out;
for (int w = 0; w < w_out; ++w) { int near_y = static_cast<int>(scale_h_new * h + 0.5);
int near_x = (with_align) ? static_cast<int>(scale_w_new * w + 0.5) for (int w = 0; w < w_out; ++w) {
: static_cast<int>(scale_w_new * w); int near_x = static_cast<int>(scale_w_new * w + 0.5);
int near_y = (with_align) ? static_cast<int>(scale_h_new * h + 0.5) *dst_p++ = src[near_y * w_in + near_x];
: static_cast<int>(scale_h_new * h); }
near_x = near_x < 0 ? 0 : near_x; }
near_y = near_y < 0 ? 0 : near_y; } else {
dst[h * w_out + w] = src[near_y * w_in + near_x]; for (int h = 0; h < h_out; ++h) {
float* dst_p = dst + h * w_out;
int near_y = static_cast<int>(scale_h_new * h);
for (int w = 0; w < w_out; ++w) {
int near_x = static_cast<int>(scale_w_new * w);
*dst_p++ = src[near_y * w_in + near_x];
}
} }
} }
} }
...@@ -544,8 +550,10 @@ void interpolate(lite::Tensor* X, ...@@ -544,8 +550,10 @@ void interpolate(lite::Tensor* X,
int out_w = Out->dims()[3]; int out_w = Out->dims()[3];
int spatial_in = in_h * in_w; int spatial_in = in_h * in_w;
int spatial_out = out_h * out_w; int spatial_out = out_h * out_w;
for (int i = 0; i < count; ++i) {
if ("Bilinear" == interpolate_type) { if ("Bilinear" == interpolate_type) {
#pragma omp parallel for
for (int i = 0; i < count; ++i) {
bilinear_interp(din + spatial_in * i, bilinear_interp(din + spatial_in * i,
in_w, in_w,
in_h, in_h,
...@@ -555,7 +563,10 @@ void interpolate(lite::Tensor* X, ...@@ -555,7 +563,10 @@ void interpolate(lite::Tensor* X,
1.f / width_scale, 1.f / width_scale,
1.f / height_scale, 1.f / height_scale,
with_align); with_align);
} else if ("Nearest" == interpolate_type) { }
} else if ("Nearest" == interpolate_type) {
#pragma omp parallel for
for (int i = 0; i < count; ++i) {
nearest_interp(din + spatial_in * i, nearest_interp(din + spatial_in * i,
in_w, in_w,
in_h, in_h,
......
...@@ -41,6 +41,7 @@ add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc D ...@@ -41,6 +41,7 @@ add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc D
add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
## 2.other basic kernels: basic kernels that not used in basic models ## 2.other basic kernels: basic kernels that not used in basic models
add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/instance_norm_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void InstanceNormCompute::PrepareForRun() {}
void InstanceNormCompute::Run() {
auto& param = this->Param<param_t>();
const float* in = param.x->data<float>();
const float* scale = param.scale->data<float>();
const float* bias = param.bias->data<float>();
float* out = param.out->mutable_data<float>();
float* saved_mean = param.saved_mean->mutable_data<float>();
float* saved_variance = param.saved_variance->mutable_data<float>();
float epsilon = param.epsilon;
int n = param.x->dims()[0];
int c = param.x->dims()[1];
int nc = n * c;
int height = param.x->dims()[2];
int width = param.x->dims()[3];
int spatial_size = height * width;
// compute saved_mean and saved_variance
#pragma omp parallel for
for (int i = 0; i < nc; ++i) {
const float* in_p = in + i * spatial_size;
float sum_spatial = 0.f;
float summ_spatial = 0.f;
for (int h = 0; h < height; ++h) {
int w = width;
float32x4_t sum0 = vdupq_n_f32(0.f);
float32x4_t sum1 = vdupq_n_f32(0.f);
float32x4_t sum2 = vdupq_n_f32(0.f);
float32x4_t sum3 = vdupq_n_f32(0.f);
float32x4_t summ0 = vdupq_n_f32(0.f);
float32x4_t summ1 = vdupq_n_f32(0.f);
float32x4_t summ2 = vdupq_n_f32(0.f);
float32x4_t summ3 = vdupq_n_f32(0.f);
float32x4_t in0, in1, in2, in3;
for (; w > 15; w -= 16) {
in0 = vld1q_f32(in_p);
in1 = vld1q_f32(in_p + 4);
in2 = vld1q_f32(in_p + 8);
in3 = vld1q_f32(in_p + 12);
sum0 = vaddq_f32(sum0, in0);
sum1 = vaddq_f32(sum1, in1);
summ0 = vmlaq_f32(summ0, in0, in0);
summ1 = vmlaq_f32(summ1, in1, in1);
sum2 = vaddq_f32(sum2, in2);
sum3 = vaddq_f32(sum3, in3);
summ2 = vmlaq_f32(summ2, in2, in2);
summ3 = vmlaq_f32(summ3, in3, in3);
in_p += 16;
}
for (; w > 7; w -= 8) {
in0 = vld1q_f32(in_p);
in1 = vld1q_f32(in_p + 4);
sum0 = vaddq_f32(sum0, in0);
sum1 = vaddq_f32(sum1, in1);
summ0 = vmlaq_f32(summ0, in0, in0);
summ1 = vmlaq_f32(summ1, in1, in1);
in_p += 8;
}
for (; w > 3; w -= 4) {
in0 = vld1q_f32(in_p);
sum0 = vaddq_f32(sum0, in0);
summ0 = vmlaq_f32(summ0, in0, in0);
in_p += 4;
}
float sum = 0.f;
float summ = 0.f;
for (; w > 0; w--) {
sum += *in_p;
summ += (*in_p) * (*in_p);
in_p++;
}
sum0 = vaddq_f32(sum0, sum1);
sum2 = vaddq_f32(sum2, sum3);
summ0 = vaddq_f32(summ0, summ1);
summ2 = vaddq_f32(summ2, summ3);
sum0 = vaddq_f32(sum0, sum2);
summ0 = vaddq_f32(summ0, summ2);
float32x2_t sum_low = vpadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
float32x2_t sum_high =
vpadd_f32(vget_low_f32(summ0), vget_high_f32(summ0));
float32x2_t sum_mix = vpadd_f32(sum_low, sum_high);
sum += vget_lane_f32(sum_mix, 0);
summ += vget_lane_f32(sum_mix, 1);
sum_spatial += sum;
summ_spatial += summ;
}
float mean = sum_spatial / spatial_size;
// float variance = summ / spatial_size - mean * mean;
// the flolowing code has higher precision than above comment code
float variance = (summ_spatial - mean * mean * spatial_size) / spatial_size;
float std = 1.f / sqrtf(variance + epsilon);
saved_mean[i] = mean;
saved_variance[i] = std;
}
// compute instance_norm result: out = scale * (in - mean) / std + bias
#pragma omp parallel for
for (int i = 0; i < nc; ++i) {
const float* in_p = in + i * spatial_size;
float* out_p = out + i * spatial_size;
int j = spatial_size;
const float sstd_val = scale[i % c] * saved_variance[i];
const float bias_val = bias[i % c];
const float mean_val = saved_mean[i];
const float32x4_t vsstd = vdupq_n_f32(sstd_val);
const float32x4_t vbias = vdupq_n_f32(bias_val);
const float32x4_t vmean = vdupq_n_f32(mean_val);
float32x4_t in0, in1, submean0, submean1, out0, out1;
for (; j > 7; j -= 8) {
in0 = vld1q_f32(in_p);
in1 = vld1q_f32(in_p + 4);
submean0 = vsubq_f32(in0, vmean);
submean1 = vsubq_f32(in1, vmean);
out0 = vmlaq_f32(vbias, submean0, vsstd);
out1 = vmlaq_f32(vbias, submean1, vsstd);
vst1q_f32(out_p, out0);
vst1q_f32(out_p + 4, out1);
in_p += 8;
out_p += 8;
}
for (; j > 3; j -= 4) {
in0 = vld1q_f32(in_p);
submean0 = vsubq_f32(in0, vmean);
out0 = vmlaq_f32(vbias, submean0, vsstd);
vst1q_f32(out_p, out0);
in_p += 4;
out_p += 4;
}
for (; j > 0; j--) {
*out_p = (*in_p - mean_val) * sstd_val + bias_val;
in_p++;
out_p++;
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(instance_norm,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::InstanceNormCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class InstanceNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::InstanceNormParam;
void PrepareForRun() override;
void Run() override;
virtual ~InstanceNormCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -47,6 +47,7 @@ add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_act ...@@ -47,6 +47,7 @@ add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_act
add_operator(io_copy_once_op basic SRCS io_copy_once_op.cc DEPS io_copy_op ${op_DEPS}) add_operator(io_copy_once_op basic SRCS io_copy_once_op.cc DEPS io_copy_op ${op_DEPS})
add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS}) add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS})
add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS}) add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS})
add_operator(instance_norm_op basic SRCS instance_norm_op.cc DEPS ${op_DEPS})
add_operator(graph_op basic SRCS graph_op.cc DEPS ${op_DEPS}) add_operator(graph_op basic SRCS graph_op.cc DEPS ${op_DEPS})
# 2.basic ops not used in basic models # 2.basic ops not used in basic models
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/instance_norm_op.h"
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
bool InstanceNormOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.scale);
CHECK_OR_FALSE(param_.bias);
CHECK_OR_FALSE(param_.out);
CHECK_OR_FALSE(param_.saved_mean);
CHECK_OR_FALSE(param_.saved_variance);
auto x_dims = param_.x->dims();
auto scale_dims = param_.scale->dims();
auto bias_dims = param_.bias->dims();
CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
<< "Input X must have 2 to 5 dimensions.";
CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions.";
CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions.";
CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
return true;
}
bool InstanceNormOp::InferShape() const {
auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0];
int64_t channel_size = x_dims[1];
param_.saved_mean->Resize({batch_size * channel_size});
param_.saved_variance->Resize({batch_size * channel_size});
param_.out->Resize(x_dims);
return true;
}
bool InstanceNormOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.scale =
scope->FindVar(op_desc.Input("Scale").front())->GetMutable<Tensor>();
param_.bias =
scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
param_.saved_mean =
scope->FindVar(op_desc.Output("SavedMean").front())->GetMutable<Tensor>();
param_.saved_variance =
scope->FindVar(op_desc.Output("SavedVariance").front())
->GetMutable<Tensor>();
param_.out =
scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
param_.epsilon = op_desc.GetAttr<float>("epsilon");
return true;
}
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
REGISTER_LITE_OP(instance_norm, paddle::lite::operators::InstanceNormOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class InstanceNormOp : public OpLite {
public:
InstanceNormOp() {}
explicit InstanceNormOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "instance_norm"; }
private:
mutable InstanceNormParam param_;
};
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
...@@ -1090,6 +1090,17 @@ struct CollectFpnProposalsParam { ...@@ -1090,6 +1090,17 @@ struct CollectFpnProposalsParam {
int post_nms_topN{}; int post_nms_topN{};
}; };
/// --------------------- instance_norm operators --------------------
struct InstanceNormParam {
lite::Tensor* x{};
lite::Tensor* out{};
lite::Tensor* bias{};
lite::Tensor* scale{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_variance{};
float epsilon;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE ...@@ -14,6 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class InstanceNormComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "y";
std::string saved_mean_ = "saved_mean";
std::string saved_variance_ = "saved_variance";
std::string scale_ = "scale";
std::string bias_ = "bias";
DDim dims_{{4, 5, 19, 19}};
float epsilon_ = 1e-5f;
public:
InstanceNormComputeTest(const Place& place,
const std::string& alias,
DDim dims,
float epsilon)
: TestCase(place, alias), dims_(dims), epsilon_(epsilon) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(input_);
auto scale = scope->FindTensor(scale_);
auto bias = scope->FindTensor(bias_);
auto out = scope->NewTensor(output_);
auto saved_mean = scope->NewTensor(saved_mean_);
auto saved_variance = scope->NewTensor(saved_variance_);
CHECK(out);
CHECK(saved_mean);
CHECK(saved_variance);
DDim saved_dim({dims_[0] * dims_[1]});
out->Resize(dims_);
saved_mean->Resize(saved_dim);
saved_variance->Resize(saved_dim);
auto x_data = x->data<float>();
auto scale_data = scale->data<float>();
auto bias_data = bias->data<float>();
auto out_data = out->mutable_data<float>();
auto saved_mean_data = saved_mean->mutable_data<float>();
auto saved_variance_data = saved_variance->mutable_data<float>();
int n = x->dims()[0];
int c = x->dims()[1];
int spatial_size = x->dims()[2] * x->dims()[3];
// compute mean
for (int i = 0; i < n * c; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum += x_ptr[j];
}
saved_mean_data[i] = sum / spatial_size;
}
// compute variance
for (int i = 0; i < n * c; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum +=
(x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
}
saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
}
// compute out
for (int i = 0; i < n * c; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float* out_ptr = out_data + i * spatial_size;
float scale_val = scale_data[i % c];
float bias_val = bias_data[i % c];
for (int j = 0; j < spatial_size; ++j) {
out_ptr[j] = scale_val * (x_ptr[j] - saved_mean_data[i]) *
saved_variance_data[i] +
bias_val;
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("instance_norm");
op_desc->SetInput("X", {input_});
op_desc->SetInput("Bias", {bias_});
op_desc->SetInput("Scale", {scale_});
op_desc->SetOutput("Y", {output_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
op_desc->SetAttr("epsilon", epsilon_);
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
DDim scale_dim{{dims_[1]}};
std::vector<float> scale(scale_dim.production());
fill_data_rand(scale.data(), -1.f, 1.f, scale_dim.production());
std::vector<float> bias(scale_dim.production());
fill_data_rand(bias.data(), -1.f, 1.f, scale_dim.production());
SetCommonTensor(input_, dims_, din.data());
SetCommonTensor(scale_, scale_dim, scale.data());
SetCommonTensor(bias_, scale_dim, bias.data());
}
};
void test_instance_norm(Place place) {
for (auto& n : {1, 3, 16}) {
for (auto& c : {1, 4, 16}) {
for (auto& h : {1, 16, 33, 56}) {
for (auto& w : {1, 17, 34, 55}) {
DDim dim_in({n, c, h, w});
float epsilon = 1e-5f;
std::unique_ptr<arena::TestCase> tester(
new InstanceNormComputeTest(place, "def", dim_in, epsilon));
#ifdef LITE_WITH_ARM
auto& ctx = tester->context()->As<ARMContext>();
ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 4);
#endif
arena::Arena arena(std::move(tester), place, 6e-5);
if (!arena.TestPrecision()) {
LOG(ERROR) << "run n: " << n << ", c: " << c << ", h: " << h
<< ", w: " << w;
return;
}
}
}
}
}
}
TEST(InstanceNorm, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_instance_norm(place);
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册