// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"

namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {

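// Converts the paddle fc op (Out = Input * W + Bias) into NPU IR nodes.
// ge::op::FullConnection works on 4-D tensors, so the input is reshaped
// to (m, k, 1, 1), W is transposed into a (n, k, 1, 1) const node, and
// the (m, n, 1, 1) FC output is reshaped back to the fc output shape.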
int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  CHECK(ctx != nullptr);
  CHECK(op != nullptr);
  auto graph = static_cast<Graph*>(ctx);
  auto op_info = op->op_info();
  auto op_type = op_info->Type();
  auto scope = op->scope();
  VLOG(3) << "[NPU] Converting " + op_type + "...";

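  // Get the Input, W and Out vars of the fc op, checking the precision
  // and layout declared by the kernel for each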
  auto input_name = op_info->Input("Input").front();
  auto input_type = kernel->GetInputDeclType("Input");
  CHECK(input_type->precision() == PRECISION(kFloat));
  CHECK(input_type->layout() == DATALAYOUT(kNCHW));
  auto input = scope->FindTensor(input_name);
  auto input_dims = input->dims();

  auto w_name = op_info->Input("W").front();
  auto w_type = kernel->GetInputDeclType("W");
  CHECK(w_type->precision() == PRECISION(kFloat));
  CHECK(w_type->layout() == DATALAYOUT(kNCHW));
  auto w = scope->FindTensor(w_name);
  auto w_dims = w->dims();
  CHECK_EQ(w_dims.size(), 2UL);

  auto out_name = op_info->Output("Out").front();
  auto out_type = kernel->GetOutputDeclType("Out");
  CHECK(out_type->precision() == PRECISION(kFloat));
  CHECK(out_type->layout() == DATALAYOUT(kNCHW));
  auto out = scope->FindTensor(out_name);
  auto out_dims = out->dims();

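  // Flatten the input to a (m, k) matrix: dims before in_num_col_dims
  // form m, and the remaining dims form k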
  int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
  int m = input_dims.Slice(0, in_num_col_dims).production();
  int k = input_dims.Slice(in_num_col_dims, input_dims.size()).production();
  int n = w_dims[1];
  CHECK_EQ(k * n, w_dims.production());

  // Create input node and reshape it to (m, k, 1, 1)
  std::shared_ptr<Node> input_node = nullptr;
  if (graph->Has(input_name)) {
    input_node = graph->Get(input_name);
  } else {
    input_node = graph->Add(input_name, *input);
  }
  auto reshaped_input_node =
      graph->Add<ge::op::Reshape>(input_name + "/reshape");
  auto reshaped_input_op = reshaped_input_node->data<ge::op::Reshape>();
  reshaped_input_op->set_input_tensor(*input_node->data());
  reshaped_input_op->set_attr_shape({m, k, 1, 1});
  reshaped_input_op->set_attr_axis(0);

  // Create w const node, set its shape to (n, k, 1, 1) and fill with
  // the transposed w tensor
  Tensor transpose_w;
  transpose_w.Resize({n, k, 1, 1});
  transpose_w.set_persistable(true);
  auto transpose_w_data = transpose_w.mutable_data<float>();
  auto w_data = w->data<float>();
  for (int i = 0; i < k; i++) {
    for (int j = 0; j < n; j++) {
      transpose_w_data[j * k + i] = w_data[i * n + j];
    }
  }
  auto trans_w_node = graph->Add(w_name, transpose_w);

  // FC node: maps the reshaped (m, k, 1, 1) input through the
  // (n, k, 1, 1) weight to a (m, n, 1, 1) output
  auto fc_node = graph->Add<ge::op::FullConnection>(out_name);
  auto fc_op = fc_node->data<ge::op::FullConnection>();
  fc_op->set_input_x(*reshaped_input_node->data());
  fc_op->set_input_w(*trans_w_node->data());

  // Add bias node if bias tensor exists
  if (HasInputArg(op_info, scope, "Bias")) {
    std::shared_ptr<Node> bias_node = nullptr;
    auto bias_name = op_info->Input("Bias").front();
    if (graph->Has(bias_name)) {
      bias_node = graph->Get(bias_name);
    } else {
      auto bias_type = kernel->GetInputDeclType("Bias");
      CHECK(bias_type->precision() == PRECISION(kFloat));
      CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
      auto bias = scope->FindTensor(bias_name);
      auto bias_dims = bias->dims();
      CHECK_EQ(bias_dims.production(), n);
      bias_node = graph->Add(bias_name, *bias, {1, n, 1, 1});
    }
    fc_op->set_input_b(*bias_node->data());
  }

  // Reshape output of FC node from (m, n, 1, 1) to out_shape
  auto reshaped_fc_node = graph->Add<ge::op::Reshape>(out_name);
  auto reshaped_fc_op = reshaped_fc_node->data<ge::op::Reshape>();
  reshaped_fc_op->set_input_tensor(*fc_node->data());
  auto out_shape = out_dims.Vectorize();
  reshaped_fc_op->set_attr_shape(
      ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
  reshaped_fc_op->set_attr_axis(0);

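  // The reshape attrs bake in concrete dims (m, k, n), so the bridge asks
  // for a rebuild of the subgraph whenever the input shape changes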
  return REBUILD_WHEN_SHAPE_CHANGED;
}

}  // namespace npu
}  // namespace subgraph
}  // namespace lite
}  // namespace paddle

REGISTER_SUBGRAPH_BRIDGE(fc, kNPU, paddle::lite::subgraph::npu::FCConverter);