diff --git a/lite/kernels/arm/fc_compute.h b/lite/kernels/arm/fc_compute.h index 2e5f2345e824b13d78a1575d3374652b8474c7fd..4f8a82a8689c1f221ee146176ff7074602cad1c9 100644 --- a/lite/kernels/arm/fc_compute.h +++ b/lite/kernels/arm/fc_compute.h @@ -95,7 +95,7 @@ class FcCompute : public KernelLite { CHECK_GE(x_dims.size(), 2UL); CHECK_EQ(w_dims.size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); + CHECK_GE(param.output->dims().size(), 2UL); m_ = x_dims.Slice(0, param.in_num_col_dims).production(); k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); diff --git a/lite/kernels/npu/bridges/fc_op.cc b/lite/kernels/npu/bridges/fc_op.cc index 3d028172154e58c1ed191b4d4eb780e9937458a5..d9d42cd8c73a321449649bca658333fdd5f57325 100644 --- a/lite/kernels/npu/bridges/fc_op.cc +++ b/lite/kernels/npu/bridges/fc_op.cc @@ -34,27 +34,29 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto input_type = kernel->GetInputDeclType("Input"); CHECK(input_type->precision() == PRECISION(kFloat)); CHECK(input_type->layout() == DATALAYOUT(kNCHW)); - auto input = scope->FindMutableTensor(input_name); + auto input = scope->FindTensor(input_name); auto input_dims = input->dims(); - CHECK_GE(input_dims.size(), 2UL); + auto w_name = op_info->Input("W").front(); auto w_type = kernel->GetInputDeclType("W"); CHECK(w_type->precision() == PRECISION(kFloat)); CHECK(w_type->layout() == DATALAYOUT(kNCHW)); - auto w = scope->FindMutableTensor(w_name); + auto w = scope->FindTensor(w_name); auto w_dims = w->dims(); CHECK_EQ(w_dims.size(), 2UL); + auto out_name = op_info->Output("Out").front(); auto out_type = kernel->GetOutputDeclType("Out"); CHECK(out_type->precision() == PRECISION(kFloat)); CHECK(out_type->layout() == DATALAYOUT(kNCHW)); + auto out = scope->FindTensor(out_name); + auto out_dims = out->dims(); + int in_num_col_dims = op_info->GetAttr("in_num_col_dims"); int m = input_dims.Slice(0, in_num_col_dims).production(); int k = input_dims.Slice(in_num_col_dims, input_dims.size()).production(); int n = w_dims[1]; CHECK_EQ(k * n, w_dims.production()); - VLOG(3) << "[NPU] input dims: " << input_dims << " w dims: " << w_dims - << " m: " << m << " k: " << k << " n: " << n; // Create input node and reshape it to (m, k, 1, 1) std::shared_ptr input_node = nullptr; @@ -76,7 +78,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { transpose_w.Resize({n, k, 1, 1}); transpose_w.set_persistable(true); auto transpose_w_data = transpose_w.mutable_data(); - auto w_data = w->mutable_data(); + auto w_data = w->data(); for (int i = 0; i < k; i++) { for (int j = 0; j < n; j++) { transpose_w_data[j * k + i] = w_data[i * n + j]; @@ -85,10 +87,11 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto trans_w_node = graph->Add(w_name, transpose_w); // FC node - auto fc_node = graph->Add(out_name + "/fc"); + auto fc_node = graph->Add(out_name); auto fc_op = fc_node->data(); fc_op->set_input_x(*reshaped_input_node->data()); fc_op->set_input_w(*trans_w_node->data()); + // Add bias node if bias tensor exists if (HasInputArg(op_info, scope, "Bias")) { std::shared_ptr bias_node = nullptr; @@ -99,19 +102,23 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto bias_type = kernel->GetInputDeclType("Bias"); CHECK(bias_type->precision() == PRECISION(kFloat)); CHECK(bias_type->layout() == DATALAYOUT(kNCHW)); - auto bias = scope->FindMutableTensor(bias_name); + auto bias = scope->FindTensor(bias_name); auto bias_dims = bias->dims(); CHECK_EQ(bias_dims.production(), n); bias_node = graph->Add(bias_name, *bias, {1, n, 1, 1}); } fc_op->set_input_b(*bias_node->data()); } - // Reshape output of FC node from (m, n, 1, 1) to (m, n) + + // Reshape output of FC node from (m, n, 1, 1) to out_shape auto reshaped_fc_node = graph->Add(out_name); auto reshaped_fc_op = reshaped_fc_node->data(); reshaped_fc_op->set_input_tensor(*fc_node->data()); - reshaped_fc_op->set_attr_shape({m, n}); + auto out_shape = out_dims.Vectorize(); + reshaped_fc_op->set_attr_shape( + ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end())); reshaped_fc_op->set_attr_axis(0); + return REBUILD_WHEN_SHAPE_CHANGED; } diff --git a/lite/operators/fc_op.h b/lite/operators/fc_op.h index 3cddde38b291f189649175a43c994d4fcfcabb9b..ec449cd4bdc33f191c33fc04f215ad672b283215 100644 --- a/lite/operators/fc_op.h +++ b/lite/operators/fc_op.h @@ -37,15 +37,6 @@ class FcOpLite : public OpLite { bool InferShape() const override; - /* - bool Run() override { - CHECK(kernel_); - kernel_->Run(); - return true; - } - */ - - // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc index a9aac49eb588e0a332b28b7ea7d1e320c2f52413..6d879385a27e834b3fa27835ee94edc599f5564c 100644 --- a/lite/tests/kernels/fc_compute_test.cc +++ b/lite/tests/kernels/fc_compute_test.cc @@ -192,7 +192,7 @@ class FcOPTest : public arena::TestCase { fill_data_rand(bin.data(), -1.f, 1.f, bdims_.production()); SetCommonTensor(input_, dims_, din.data()); - SetCommonTensor(weight_, wdims_, win.data()); + SetCommonTensor(weight_, wdims_, win.data(), {}, true); if (padding_weights_) { std::vector win_padding(wdims_padding_.production()); for (int64_t i = 0; i < wdims_[0]; ++i) { @@ -203,15 +203,15 @@ class FcOPTest : public arena::TestCase { SetCommonTensor(weight_padding_, wdims_padding_, win_padding.data()); } if (flag_bias) { - SetCommonTensor(bias_, bdims_, bin.data()); + SetCommonTensor(bias_, bdims_, bin.data(), {}, true); } } }; -void TestFCMain(Place place, - float abs_error, - bool with_relu = false, - bool padding = false) { +void TestFC2D(Place place, + float abs_error, + bool with_relu = false, + bool padding = false) { for (auto& m : {1, 3, 16}) { for (auto& n : {1, 4, 16, 128, 256, 1024}) { for (auto& k : {1, 16, 128, 1024}) { @@ -242,6 +242,32 @@ void TestFCMain(Place place, } } +void TestFCHelper(Place place, + float abs_error, + std::vector xdims, + std::vector wdims, + std::vector bdims, + int in_num_col_dims) { + std::unique_ptr tester(new FcOPTest(place, + "def", + DDim(xdims), + DDim(wdims), + DDim(bdims), + in_num_col_dims, + false, + false)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); +} + +void TestFCnD(Place place, float abs_error) { + TestFCHelper(place, abs_error, {2, 3, 4}, {4, 5}, {5}, 2); + TestFCHelper(place, abs_error, {2, 3, 4}, {12, 5}, {5}, 1); + TestFCHelper(place, abs_error, {2, 3, 4, 5}, {5, 6}, {6}, 3); + TestFCHelper(place, abs_error, {2, 3, 4, 5}, {20, 6}, {6}, 2); + TestFCHelper(place, abs_error, {2, 3, 4, 5}, {60, 6}, {6}, 1); +} + TEST(FcOP, precision) { Place place; float abs_error = 1e-4; @@ -256,7 +282,9 @@ TEST(FcOP, precision) { #else return; #endif - TestFCMain(place, abs_error); + + TestFC2D(place, abs_error); + TestFCnD(place, abs_error); } #ifdef LITE_WITH_X86 @@ -264,7 +292,7 @@ TEST(FcOP, padding_and_parallel) { Place place(TARGET(kX86)); float abs_error = 1e-4; x86::SetNumThreads(4); - TestFCMain(place, abs_error, true, true); + TestFC2D(place, abs_error, true, true); } #endif