未验证 提交 59d079e8 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] enhance unittest for bn, transpose (#2716)

test=develop
上级 a1527e80
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/batch_norm_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void batch_norm_ref(const std::shared_ptr<operators::BatchNormOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto y = scope->FindVar(op_info->Output("Y").front())->GetMutable<Tensor>();
auto bias =
scope->FindVar(op_info->Input("Bias").front())->GetMutable<Tensor>();
auto scale =
scope->FindVar(op_info->Input("Scale").front())->GetMutable<Tensor>();
auto mean =
scope->FindVar(op_info->Input("Mean").front())->GetMutable<Tensor>();
auto variance =
scope->FindVar(op_info->Input("Variance").front())->GetMutable<Tensor>();
auto x_data = x->data<dtype>();
auto y_data = y->mutable_data<dtype>();
auto scale_data = scale->mutable_data<dtype>();
auto bias_data = bias->mutable_data<dtype>();
auto mean_data = mean->mutable_data<dtype>();
auto variance_data = variance->mutable_data<dtype>();
DDim x_dims = x->dims();
float epsilon = op_info->GetAttr<float>("epsilon");
float momentum = op_info->GetAttr<float>("momentum");
auto data_layout = op_info->GetAttr<std::string>("data_layout");
bool global_stats = op_info->GetAttr<bool>("use_global_stats");
if (global_stats) {
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
if (data_layout == "NCHW") {
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
} else {
LOG(FATAL) << "Unknown storage order: " << data_layout;
}
auto x_ptr = x_data;
auto y_ptr = y_data;
for (int o = 0; o < outer_size; o++) {
for (int c = 0; c < channel_size; c++) {
for (int i = 0; i < inner_size; i++) {
dtype norm_x =
(*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
*y_ptr = norm_x * scale_data[c] + bias_data[c];
x_ptr++;
y_ptr++;
}
}
}
}
}
void test_batch_norm(
int bs, int ic, int ih, int iw, float epsilon, float momentum) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
std::string scale_var_name = "scale";
std::string bias_var_name = "bias";
std::string mean_var_name = "mean";
std::string variance_var_name = "variance";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* scale = scope.Var(scale_var_name)->GetMutable<Tensor>();
auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
auto* mean = scope.Var(mean_var_name)->GetMutable<Tensor>();
auto* variance = scope.Var(variance_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
scale->Resize({ic});
bias->Resize({ic});
mean->Resize({ic});
variance->Resize({ic});
// initialize input&output data
FillTensor<float, int>(x);
FillTensor<float, int>(scale);
FillTensor<float, int>(bias);
FillTensor<float, int>(mean);
// variance > 0
FillTensor<float, int>(variance, 1.f, 5.f);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("batch_norm");
opdesc.SetInput("X", {x_var_name});
opdesc.SetInput("Scale", {scale_var_name});
opdesc.SetInput("Bias", {bias_var_name});
opdesc.SetInput("Mean", {mean_var_name});
opdesc.SetInput("Variance", {variance_var_name});
opdesc.SetOutput("Y", {out_var_name});
opdesc.SetAttr("is_test", 1);
opdesc.SetAttr("use_global_stats", true);
opdesc.SetAttr("epsilon", epsilon);
opdesc.SetAttr("momentum", momentum);
opdesc.SetAttr("data_layout", std::string("NCHW"));
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::BatchNormOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
batch_norm_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, batch_norm) {
for (auto bs : {1, 4, 7}) {
for (auto ic : {1, 4, 7}) {
for (auto ih : {1, 4, 7}) {
for (auto iw : {1, 4, 7}) {
for (auto epsilon : {1e-4f, 1e-5f}) {
for (auto momentum : {0.9f, 0.99f}) {
test_batch_norm(bs, ic, ih, iw, epsilon, momentum);
}
}
}
}
}
}
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(batch_norm);
USE_NPU_BRIDGE(batch_norm);
......@@ -37,7 +37,7 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Input("Out").front();
auto out_name = op_info->Output("Out").front();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/transpose_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
int data_index(std::vector<int> pos, DDimLite dims) {
int d1 = dims[1];
int d2 = dims[2];
int d3 = dims[3];
return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
}
std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
std::vector<int> out_pos(in_pos.size());
for (int i = 0; i < axis.size(); i++) {
out_pos[axis[i]] = in_pos[i];
}
return out_pos;
}
void transpose_ref(const std::shared_ptr<operators::TransposeOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto input =
scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto output =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_dims = input->dims();
auto y_dims = output->dims();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
auto* input_data = input->data<float>();
auto* output_data = output->mutable_data<float>();
int input_n = x_dims[0];
int input_c = x_dims[1];
int input_h = x_dims[2];
int input_w = x_dims[3];
int output_n = y_dims[0];
int output_c = y_dims[1];
int output_h = y_dims[2];
int output_w = y_dims[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
std::vector<int> in_pos{n, c, h, w};
std::vector<int> out_pos = pos_trans(in_pos, axis);
int in_index = data_index(in_pos, x_dims);
int out_index = data_index(out_pos, y_dims);
output_data[out_index] = input_data[in_index];
}
}
}
}
}
void test_transpose(int bs, int ic, int ih, int iw, std::vector<int> axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("transpose");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::TransposeOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
transpose_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, transpose) {
#if 0
for (auto bs : {1, 4, 7}) {
for (auto ic : {1, 4, 7}) {
for (auto ih : {1, 4, 7}) {
for (auto iw : {1, 4, 7}) {
for (auto axis : {std::vector<int>{0, 1, 2, 3},
std::vector<int>{0, 1, 3, 2},
std::vector<int>{0, 3, 1, 2},
std::vector<int>{1, 2, 3, 0},
std::vector<int>{3, 2, 1, 0},
std::vector<int>{2, 3, 1, 0}}) {
test_transpose(bs, ic, ih, iw, axis);
}
}
}
}
}
#endif
test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 3, 2});
// test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 2, 3});
// test_transpose(2, 2, 2, 2, std::vector<int>{0,1,3,2});
// test_transpose(1, 1, 2, 2, std::vector<int>{0,1,3,2});
// test_transpose(1, 1, 1, 2, std::vector<int>{0,1,2,3});
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(transpose);
USE_NPU_BRIDGE(transpose);
USE_LITE_OP(transpose2);
USE_NPU_BRIDGE(transpose2);
......@@ -25,13 +25,13 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
#lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${npu_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${npu_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -159,6 +159,8 @@ TEST(BatchNorm, precision) {
Place place;
#if defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#elif defined(LITE_WITH_NPU)
place = TARGET(kNPU);
#else
return;
#endif
......
......@@ -16,6 +16,7 @@
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
......@@ -24,13 +25,13 @@ int data_index(std::vector<int> pos, DDimLite dims) {
int d1 = dims[1];
int d2 = dims[2];
int d3 = dims[3];
return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
return pos[0] * d1 * d2 * d3 + pos[1] * d2 * d3 + pos[2] * d3 + pos[3];
}
std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
std::vector<int> out_pos(in_pos.size());
for (int i = 0; i < axis.size(); i++) {
out_pos[axis[i]] = in_pos[i];
out_pos[i] = in_pos[axis[i]];
}
return out_pos;
}
......@@ -42,35 +43,34 @@ class TransposeComputeTester : public arena::TestCase {
std::string input_ = "x";
std::string output_ = "out";
std::string xshape_ = "xshape";
DDim x_dims_;
DDim dims_;
std::vector<int> axis_;
public:
TransposeComputeTester(const Place& place,
const std::string& alias,
DDim x_dims,
DDim dims,
std::vector<int> axis)
: TestCase(place, alias), x_dims_(x_dims), axis_(axis) {}
: TestCase(place, alias), dims_(dims), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
auto* x = scope->FindTensor(input_);
auto x_dims = x->dims();
std::vector<int64_t> out_shape(x_dims.size(), 0);
for (size_t i = 0; i < x_dims.size(); i++) {
out_shape[i] = x_dims[axis_[i]];
std::vector<int64_t> out_shape(dims_.size(), 0);
for (size_t i = 0; i < dims_.size(); i++) {
out_shape[i] = dims_[axis_[i]];
}
out->Resize(out_shape);
auto y_dims = out->dims();
int input_n = x_dims[0];
int input_c = x_dims[1];
int input_h = x_dims[2];
int input_w = x_dims[3];
int input_n = dims_[0];
int input_c = dims_[1];
int input_h = dims_[2];
int input_w = dims_[3];
auto input_data = x->data<float>();
auto output_data = out->mutable_data<float>();
......@@ -81,7 +81,7 @@ class TransposeComputeTester : public arena::TestCase {
for (int w = 0; w < input_w; ++w) {
std::vector<int> in_pos{n, c, h, w};
std::vector<int> out_pos = pos_trans(in_pos, axis_);
int in_index = data_index(in_pos, x_dims);
int in_index = data_index(in_pos, dims_);
int out_index = data_index(out_pos, y_dims);
output_data[out_index] = input_data[in_index];
}
......@@ -91,7 +91,7 @@ class TransposeComputeTester : public arena::TestCase {
if (op_type_ == "transpose2") {
auto* xshape = scope->NewTensor(xshape_);
auto xshape_dims = x_dims.Vectorize();
auto xshape_dims = dims_.Vectorize();
xshape_dims.insert(xshape_dims.begin(), 0);
xshape->Resize(xshape_dims);
}
......@@ -108,11 +108,9 @@ class TransposeComputeTester : public arena::TestCase {
}
void PrepareData() override {
std::vector<float> data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(input_, x_dims_, data.data());
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(input_, dims_, din.data());
}
};
......@@ -122,14 +120,16 @@ TEST(Transpose, precision) {
Place place;
#ifdef LITE_WITH_XPU
place = TARGET(kXPU);
#elif defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#else
return;
#endif
DDim x_dims{{2, 3, 4, 5}};
// [XPU]: {3, 1, 0, 2} is unsupported
std::vector<std::vector<int>> axes{
{0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {3, 1, 2, 0}};
{0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {3, 1, 2, 0}, {3, 1, 0, 2}};
for (auto axis : axes) {
std::unique_ptr<arena::TestCase> tester(
new TransposeComputeTester(place, "def", x_dims, axis));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册