提交 8e9ad0a9 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] reshape out of fc and add n-D unittest (#2859)

上级 4d8c0863
......@@ -95,7 +95,7 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
CHECK_GE(param.output->dims().size(), 2UL);
m_ = x_dims.Slice(0, param.in_num_col_dims).production();
k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
......
......@@ -34,27 +34,29 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto input_type = kernel->GetInputDeclType("Input");
CHECK(input_type->precision() == PRECISION(kFloat));
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input = scope->FindTensor(input_name);
auto input_dims = input->dims();
CHECK_GE(input_dims.size(), 2UL);
auto w_name = op_info->Input("W").front();
auto w_type = kernel->GetInputDeclType("W");
CHECK(w_type->precision() == PRECISION(kFloat));
CHECK(w_type->layout() == DATALAYOUT(kNCHW));
auto w = scope->FindMutableTensor(w_name);
auto w = scope->FindTensor(w_name);
auto w_dims = w->dims();
CHECK_EQ(w_dims.size(), 2UL);
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindTensor(out_name);
auto out_dims = out->dims();
int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
int m = input_dims.Slice(0, in_num_col_dims).production();
int k = input_dims.Slice(in_num_col_dims, input_dims.size()).production();
int n = w_dims[1];
CHECK_EQ(k * n, w_dims.production());
VLOG(3) << "[NPU] input dims: " << input_dims << " w dims: " << w_dims
<< " m: " << m << " k: " << k << " n: " << n;
// Create input node and reshape it to (m, k, 1, 1)
std::shared_ptr<Node> input_node = nullptr;
......@@ -76,7 +78,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
transpose_w.Resize({n, k, 1, 1});
transpose_w.set_persistable(true);
auto transpose_w_data = transpose_w.mutable_data<float>();
auto w_data = w->mutable_data<float>();
auto w_data = w->data<float>();
for (int i = 0; i < k; i++) {
for (int j = 0; j < n; j++) {
transpose_w_data[j * k + i] = w_data[i * n + j];
......@@ -85,10 +87,11 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto trans_w_node = graph->Add(w_name, transpose_w);
// FC node
auto fc_node = graph->Add<ge::op::FullConnection>(out_name + "/fc");
auto fc_node = graph->Add<ge::op::FullConnection>(out_name);
auto fc_op = fc_node->data<ge::op::FullConnection>();
fc_op->set_input_x(*reshaped_input_node->data());
fc_op->set_input_w(*trans_w_node->data());
// Add bias node if bias tensor exists
if (HasInputArg(op_info, scope, "Bias")) {
std::shared_ptr<Node> bias_node = nullptr;
......@@ -99,19 +102,23 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias = scope->FindTensor(bias_name);
auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.production(), n);
bias_node = graph->Add(bias_name, *bias, {1, n, 1, 1});
}
fc_op->set_input_b(*bias_node->data());
}
// Reshape output of FC node from (m, n, 1, 1) to (m, n)
// Reshape output of FC node from (m, n, 1, 1) to out_shape
auto reshaped_fc_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_fc_op = reshaped_fc_node->data<ge::op::Reshape>();
reshaped_fc_op->set_input_tensor(*fc_node->data());
reshaped_fc_op->set_attr_shape({m, n});
auto out_shape = out_dims.Vectorize();
reshaped_fc_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
reshaped_fc_op->set_attr_axis(0);
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
......@@ -37,15 +37,6 @@ class FcOpLite : public OpLite {
bool InferShape() const override;
/*
bool Run() override {
CHECK(kernel_);
kernel_->Run();
return true;
}
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
......@@ -192,7 +192,7 @@ class FcOPTest : public arena::TestCase {
fill_data_rand(bin.data(), -1.f, 1.f, bdims_.production());
SetCommonTensor(input_, dims_, din.data());
SetCommonTensor(weight_, wdims_, win.data());
SetCommonTensor(weight_, wdims_, win.data(), {}, true);
if (padding_weights_) {
std::vector<float> win_padding(wdims_padding_.production());
for (int64_t i = 0; i < wdims_[0]; ++i) {
......@@ -203,12 +203,12 @@ class FcOPTest : public arena::TestCase {
SetCommonTensor(weight_padding_, wdims_padding_, win_padding.data());
}
if (flag_bias) {
SetCommonTensor(bias_, bdims_, bin.data());
SetCommonTensor(bias_, bdims_, bin.data(), {}, true);
}
}
};
void TestFCMain(Place place,
void TestFC2D(Place place,
float abs_error,
bool with_relu = false,
bool padding = false) {
......@@ -242,6 +242,32 @@ void TestFCMain(Place place,
}
}
void TestFCHelper(Place place,
float abs_error,
std::vector<int64_t> xdims,
std::vector<int64_t> wdims,
std::vector<int64_t> bdims,
int in_num_col_dims) {
std::unique_ptr<arena::TestCase> tester(new FcOPTest(place,
"def",
DDim(xdims),
DDim(wdims),
DDim(bdims),
in_num_col_dims,
false,
false));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
void TestFCnD(Place place, float abs_error) {
TestFCHelper(place, abs_error, {2, 3, 4}, {4, 5}, {5}, 2);
TestFCHelper(place, abs_error, {2, 3, 4}, {12, 5}, {5}, 1);
TestFCHelper(place, abs_error, {2, 3, 4, 5}, {5, 6}, {6}, 3);
TestFCHelper(place, abs_error, {2, 3, 4, 5}, {20, 6}, {6}, 2);
TestFCHelper(place, abs_error, {2, 3, 4, 5}, {60, 6}, {6}, 1);
}
TEST(FcOP, precision) {
Place place;
float abs_error = 1e-4;
......@@ -256,7 +282,9 @@ TEST(FcOP, precision) {
#else
return;
#endif
TestFCMain(place, abs_error);
TestFC2D(place, abs_error);
TestFCnD(place, abs_error);
}
#ifdef LITE_WITH_X86
......@@ -264,7 +292,7 @@ TEST(FcOP, padding_and_parallel) {
Place place(TARGET(kX86));
float abs_error = 1e-4;
x86::SetNumThreads(4);
TestFCMain(place, abs_error, true, true);
TestFC2D(place, abs_error, true, true);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册