Unverified Commit bc6d5adc authored by zhupengyang, committed by GitHub

[XPU] bn unit test (#2706)

test=develop
Parent a29c84a2
@@ -37,30 +37,36 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
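// Scale, Bias, Mean and Variance are per-channel fp32 tensors; the bridge
// only supports the NCHW layout, as the checks below enforce.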
auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale");
CHECK(scale_type->precision() == PRECISION(kFloat));
CHECK(scale_type->layout() == DATALAYOUT(kNCHW));
auto scale = scope->FindMutableTensor(scale_name);
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto mean_name = op_info->Input("Mean").front();
auto mean_type = kernel->GetInputDeclType("Mean");
CHECK(mean_type->precision() == PRECISION(kFloat));
CHECK(mean_type->layout() == DATALAYOUT(kNCHW));
auto mean = scope->FindMutableTensor(mean_name);
auto variance_name = op_info->Input("Variance").front();
auto variance_type = kernel->GetInputDeclType("Variance");
CHECK(variance_type->precision() == PRECISION(kFloat));
CHECK(variance_type->layout() == DATALAYOUT(kNCHW));
auto variance = scope->FindMutableTensor(variance_name);
auto y_name = op_info->Output("Y").front();
auto y_type = kernel->GetOutputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto epsilon = op_info->GetAttr<float>("epsilon");
// X node
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/batch_norm_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
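// Host reference for batch_norm inference: for each channel c,
//   y = scale[c] * (x - mean[c]) / sqrt(variance[c] + epsilon) + bias[c]
// Only the use_global_stats (inference) path is implemented, which is the
// configuration exercised by the test below.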
template <typename dtype>
void batch_norm_ref(const std::shared_ptr<operators::BatchNormOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto y = scope->FindVar(op_info->Output("Y").front())->GetMutable<Tensor>();
auto bias =
scope->FindVar(op_info->Input("Bias").front())->GetMutable<Tensor>();
auto scale =
scope->FindVar(op_info->Input("Scale").front())->GetMutable<Tensor>();
auto mean =
scope->FindVar(op_info->Input("Mean").front())->GetMutable<Tensor>();
auto variance =
scope->FindVar(op_info->Input("Variance").front())->GetMutable<Tensor>();
auto x_data = x->data<dtype>();
auto y_data = y->mutable_data<dtype>();
auto scale_data = scale->mutable_data<dtype>();
auto bias_data = bias->mutable_data<dtype>();
auto mean_data = mean->mutable_data<dtype>();
auto variance_data = variance->mutable_data<dtype>();
DDim x_dims = x->dims();
float epsilon = op_info->GetAttr<float>("epsilon");
auto data_layout = op_info->GetAttr<std::string>("data_layout");
bool global_stats = op_info->GetAttr<bool>("use_global_stats");
if (global_stats) {
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
if (data_layout == "NCHW") {
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
} else {
LOG(FATAL) << "Unknown storage order: " << data_layout;
}
auto x_ptr = x_data;
auto y_ptr = y_data;
for (int o = 0; o < outer_size; o++) {
for (int c = 0; c < channel_size; c++) {
for (int i = 0; i < inner_size; i++) {
dtype norm_x =
(*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
*y_ptr = norm_x * scale_data[c] + bias_data[c];
x_ptr++;
y_ptr++;
}
}
}
}
}
void test_batch_norm(int bs, int ic, int ih, int iw, float epsilon) {
// prepare input & output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
std::string scale_var_name = "scale";
std::string bias_var_name = "bias";
std::string mean_var_name = "mean";
std::string variance_var_name = "variance";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* scale = scope.Var(scale_var_name)->GetMutable<Tensor>();
auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
auto* mean = scope.Var(mean_var_name)->GetMutable<Tensor>();
auto* variance = scope.Var(variance_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
scale->Resize({ic});
bias->Resize({ic});
mean->Resize({ic});
variance->Resize({ic});
// initialize input & output data
FillTensor<float>(x);
FillTensor<float>(scale);
FillTensor<float>(bias);
FillTensor<float>(mean);
// variance > 0
FillTensor<float>(variance, 1.f, 5.f);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("batch_norm");
opdesc.SetInput("X", {x_var_name});
opdesc.SetInput("Scale", {scale_var_name});
opdesc.SetInput("Bias", {bias_var_name});
opdesc.SetInput("Mean", {mean_var_name});
opdesc.SetInput("Variance", {variance_var_name});
opdesc.SetOutput("Y", {out_var_name});
opdesc.SetAttr("is_test", 1);
opdesc.SetAttr("use_global_stats", true);
opdesc.SetAttr("epsilon", epsilon);
opdesc.SetAttr("momentum", 0.9f);
opdesc.SetAttr("data_layout", std::string("NCHW"));
// create the op, convert it to an XPU model, then run it on the XPU
auto op = CreateOp<operators::BatchNormOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
batch_norm_ref<float>(op);
// compare results
auto* out_data = out->data<float>();
auto* out_ref_data = out_ref->data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
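// Sweep a small grid of batch sizes, channel counts, spatial sizes and
// epsilons through the XPU bridge and compare against the host reference.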
TEST(XPUBridges, batch_norm) {
for (auto bs : {1, 3}) {
for (auto ic : {2, 3}) {
for (auto ih : {4}) {
for (auto iw : {5}) {
for (auto epsilon : {1e-5f}) {
test_batch_norm(bs, ic, ih, iw, epsilon);
}
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(batch_norm);
USE_XPU_BRIDGE(batch_norm);
@@ -31,6 +31,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class BatchNormComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "batch_norm";
std::string input_ = "x";
std::string scale_ = "scale";
std::string bias_ = "bias";
std::string mean_ = "mean";
std::string variance_ = "variance";
std::string output_ = "y";
std::string mean_out_ = "mean_out";
std::string saved_mean_ = "saved_mean";
std::string variance_out_ = "variance_out";
std::string saved_variance_ = "saved_variance";
DDim dims_{{1, 2, 3, 4}};
bool use_global_stats_ = false;
float momentum_ = 0.9;
float epsilon_ = 1e-5f;
std::string data_layout_ = "NCHW";
int is_test_ = 1;
public:
BatchNormComputeTest(const Place& place,
const std::string& alias,
DDim dims,
float epsilon)
: TestCase(place, alias), dims_(dims), epsilon_(epsilon) {}
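// Host reference mirroring batch_norm inference with the running statistics
// in NCHW layout; the statistics outputs are resized but not recomputed.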
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(input_);
auto scale = scope->FindTensor(scale_);
auto bias = scope->FindTensor(bias_);
auto mean = scope->FindTensor(mean_);
auto variance = scope->FindTensor(variance_);
auto y = scope->NewTensor(output_);
auto mean_out = scope->NewTensor(mean_out_);
auto variance_out = scope->NewTensor(variance_out_);
auto saved_mean = scope->NewTensor(saved_mean_);
auto saved_variance = scope->NewTensor(saved_variance_);
CHECK(y);
CHECK(mean_out);
CHECK(variance_out);
CHECK(saved_mean);
CHECK(saved_variance);
y->Resize(dims_);
int64_t channel_size = 0;
if (data_layout_ == "NCHW") {
channel_size = dims_[1];
} else {
LOG(FATAL) << "Unknown storage order: " << data_layout_;
}
mean_out->Resize({channel_size});
variance_out->Resize({channel_size});
saved_mean->Resize({channel_size});
saved_variance->Resize({channel_size});
auto x_data = x->data<float>();
auto y_data = y->mutable_data<float>();
auto scale_data = scale->data<float>();
auto bias_data = bias->data<float>();
auto mean_data = mean->data<float>();
auto variance_data = variance->data<float>();
int64_t outer_size = 0;
int64_t inner_size = 0;
if (data_layout_ == "NCHW") {
outer_size = dims_[0];
inner_size = dims_.Slice(2, dims_.size()).production();
} else {
LOG(FATAL) << "Unknown storage order: " << data_layout_;
}
auto x_ptr = x_data;
auto y_ptr = y_data;
for (int o = 0; o < outer_size; o++) {
for (int c = 0; c < channel_size; c++) {
for (int i = 0; i < inner_size; i++) {
float norm_x =
(*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon_);
*y_ptr = norm_x * scale_data[c] + bias_data[c];
x_ptr++;
y_ptr++;
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("X", {input_});
op_desc->SetInput("Bias", {bias_});
op_desc->SetInput("Scale", {scale_});
op_desc->SetInput("Mean", {mean_});
op_desc->SetInput("Variance", {variance_});
op_desc->SetOutput("Y", {output_});
op_desc->SetOutput("MeanOut", {mean_out_});
op_desc->SetOutput("VarianceOut", {variance_out_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
op_desc->SetAttr("epsilon", epsilon_);
op_desc->SetAttr("momentum", momentum_);
op_desc->SetAttr("use_global_stats", use_global_stats_);
op_desc->SetAttr("data_layout", data_layout_);
op_desc->SetAttr("is_test", is_test_);
}
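// Fill all inputs with random data; variance is kept non-negative so that
// sqrt(variance + epsilon) stays well-defined.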
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
DDim scale_dim({dims_[1]});
std::vector<float> scale(scale_dim.production());
fill_data_rand(scale.data(), -1.f, 1.f, scale_dim.production());
std::vector<float> bias(scale_dim.production());
fill_data_rand(bias.data(), -1.f, 1.f, scale_dim.production());
std::vector<float> mean(scale_dim.production());
fill_data_rand(mean.data(), -1.f, 1.f, scale_dim.production());
std::vector<float> variance(scale_dim.production());
fill_data_rand(variance.data(), 0.f, 1.f, scale_dim.production());
SetCommonTensor(input_, dims_, din.data());
SetCommonTensor(scale_, scale_dim, scale.data());
SetCommonTensor(bias_, scale_dim, bias.data());
SetCommonTensor(mean_, scale_dim, mean.data());
SetCommonTensor(variance_, scale_dim, variance.data());
}
};
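// Run the op on XPU and check Y against the baseline; the statistics outputs
// are excluded from the precision check since the baseline does not compute
// them.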
TEST(BatchNorm, precision) {
LOG(INFO) << "test BatchNorm op";
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {5, 6, 7, 8}}) {
for (auto epsilon : {1e-5f}) {
std::unique_ptr<arena::TestCase> tester(
new BatchNormComputeTest(place, "def", DDim(dims), epsilon));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(
{"mean_out", "saved_mean", "variance_out", "saved_variance"});
}
}
}
} // namespace lite
} // namespace paddle