// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
21
namespace subgraph {
Z
zhupengyang 已提交
22 23
namespace xpu {

24 25 26 27
int MulConverter(void* ctx, OpLite* op) {
  CHECK(ctx != nullptr);
  CHECK(op != nullptr);
  auto graph = static_cast<Graph*>(ctx);
Z
zhupengyang 已提交
28 29
  auto op_info = op->op_info();
  auto op_type = op_info->Type();
30 31
  auto scope = op->scope();
  VLOG(3) << "[XPU] Converting " + op_type + "...";
Z
zhupengyang 已提交
32

33
  // Get input, and attributes
Z
zhupengyang 已提交
34 35
  auto x_var_name = op_info->Input("X").front();
  auto y_var_name = op_info->Input("Y").front();
36 37 38
  auto out_var_name = op_info->Output("Out").front();
  auto y = scope->FindMutableTensor(y_var_name);
  auto y_dims = y->dims();
Z
zhupengyang 已提交
39 40 41 42 43 44 45
  CHECK_EQ(y_dims.size(), 2) << "xpu now only support y_dims.size() == 2";

  auto x_num_col_dims = op_info->GetAttr<int>("x_num_col_dims");
  CHECK_EQ(x_num_col_dims, 1) << "xpu now only support x_num_col_dims == 1";
  auto y_num_col_dims = op_info->GetAttr<int>("x_num_col_dims");
  CHECK_EQ(y_num_col_dims, 1) << "xpu now only support y_num_col_dims == 1";

46 47 48 49
  // Flatten x node
  auto x_node = graph->AddNode(
      x_var_name + "/flatten",
      graph->builder_.CreateBatchFlatten(*graph->GetNode(x_var_name)));
Z
zhupengyang 已提交
50

51 52 53 54 55 56 57 58 59 60
  // Transpose y data and create y node
  Tensor transpose_y;
  DDim transpose_y_dims(std::vector<int64_t>{y_dims[1], y_dims[0]});
  transpose_y.Resize(transpose_y_dims);
  auto transpose_y_data = transpose_y.mutable_data<float>();
  auto y_data = y->mutable_data<float>();
  for (int i = 0; i < transpose_y_dims[0]; i++) {
    for (int j = 0; j < transpose_y_dims[1]; j++) {
      transpose_y_data[i * transpose_y_dims[1] + j] =
          y_data[j * transpose_y_dims[0] + i];
Z
zhupengyang 已提交
61 62
    }
  }
63
  auto y_const_node = graph->AddNode(y_var_name + "/transpose", transpose_y);
Z
zhupengyang 已提交
64

65 66 67 68 69 70 71 72
  // Create mul node and set params from op
  graph->AddNode(
      out_var_name,
      graph->builder_.CreateDense(*x_node,
                                  static_cast<int>(y_dims[1]),
                                  ::xtcl::NullValue<::xtcl::DataType>(),
                                  *y_const_node));
  return REBUILD_WHEN_SHAPE_CHANGED;
Z
zhupengyang 已提交
73 74 75
}

}  // namespace xpu
76
}  // namespace subgraph
Z
zhupengyang 已提交
77 78 79
}  // namespace lite
}  // namespace paddle

80
REGISTER_SUBGRAPH_BRIDGE(XPU, mul, paddle::lite::subgraph::xpu::MulConverter);