// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/reshape_op.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"

namespace paddle {
namespace lite {
Z
zhupengyang 已提交
21
namespace kernels {
Y
Yan Chunwei 已提交
22
namespace npu {
Z
zhupengyang 已提交
23
namespace bridges {
Y
Yan Chunwei 已提交
24 25 26 27 28 29

node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
                               const node_map_type& inputs_map) {
  auto scope = reshape_op->scope();
  auto op_info = reshape_op->op_info();
  auto op_type = op_info->Type();
30
  auto unique_op_type = lite::npu::UniqueName(op_type);
31
  LOG(INFO) << "[NPU] Converting " + op_type + "...";
Y
Yan Chunwei 已提交
32 33 34 35 36 37 38 39 40 41

  // get input, output and op attributes
  auto x_var_name = op_info->Input("X").front();
  auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
  auto x_dims = x->dims();

  // create reshape node and set input node from inputs_map
  auto reshape_node = std::make_shared<ge::op::Reshape>(unique_op_type);
  CHECK(inputs_map.count(x_var_name));
  reshape_node->set_input_tensor(*inputs_map.at(x_var_name));
42
  lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
Y
Yan Chunwei 已提交
43

44 45 46 47
  // read shape from "ShapeTensor"(input), or "Shape"(input), or "shape"(attr)
  if (lite::npu::HasInputArg(op_info, scope, "ShapeTensor")) {
    LOG(FATAL) << "[NPU] not support \"Shape\" from more than one Tensor.";
  } else if (lite::npu::HasInputArg(op_info, scope, "Shape")) {
Y
Yan Chunwei 已提交
48 49 50 51 52 53 54 55 56 57 58 59
    auto actual_shape_var_name = op_info->Input("Shape").front();
    if (!inputs_map.count(actual_shape_var_name)) {
      auto actual_shape =
          scope->FindVar(actual_shape_var_name)->GetMutable<lite::Tensor>();
      auto actual_shape_dims = actual_shape->dims();
      auto actual_shape_data = actual_shape->mutable_data<int>();
      auto shape =
          std::vector<int>(actual_shape_data,
                           actual_shape_data + actual_shape_dims.production());
      auto out_dims = operators::ValidateShape(shape, x_dims);
      auto out_shape = out_dims.Vectorize();
      if (out_shape.size() > 4) {
60 61 62
        LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
                        "but Shape has "
                     << out_shape.size();
Y
Yan Chunwei 已提交
63 64 65
      }
      auto actual_shape_const_node =
          std::make_shared<ge::op::Const>(actual_shape_var_name);
66 67 68
      actual_shape_const_node->set_attr_value(
          lite::npu::CreateTensorAndFillData(
              std::vector<int>(out_shape.begin(), out_shape.end())));
Y
Yan Chunwei 已提交
69
      reshape_node->set_input_w(*actual_shape_const_node);
70
      lite::npu::OpList::Global().add(actual_shape_const_node);
Y
Yan Chunwei 已提交
71 72
    } else {
      reshape_node->set_input_w(*inputs_map.at(actual_shape_var_name));
73
      lite::npu::OpList::Global().add(inputs_map.at(actual_shape_var_name));
Y
Yan Chunwei 已提交
74 75 76 77 78 79
    }
  } else {
    auto shape = op_info->GetAttr<std::vector<int>>("shape");
    auto out_dims = operators::ValidateShape(shape, x_dims);
    auto out_shape = out_dims.Vectorize();
    if (out_shape.size() > 4) {
80 81 82
      LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
                      "but shape has "
                   << out_shape.size();
Y
Yan Chunwei 已提交
83 84 85 86
    }
    reshape_node->set_attr_shape(
        ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
  }
87
  lite::npu::OpList::Global().add(reshape_node);
Y
Yan Chunwei 已提交
88 89 90 91 92 93 94 95 96 97

  node_map_type outputs_map;
  outputs_map[op_info->Output("Out").front()] = reshape_node;
  if (op_type == "reshape2") {
    // append an extra reshape node to calc XShape
    std::vector<int64_t> xshape_dims(x_dims.size() + 1, 1);
    for (size_t i = 0; i < x_dims.size(); i++) {
      xshape_dims[i + 1] = x_dims[i];
    }
    if (xshape_dims.size() > 4) {
98 99 100
      LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
                      "but XShape has "
                   << xshape_dims.size();
Y
Yan Chunwei 已提交
101 102 103 104 105 106
    }
    auto xshape_node =
        std::make_shared<ge::op::Reshape>(unique_op_type + "/xshape");
    xshape_node->set_input_tensor(*inputs_map.at(x_var_name));
    xshape_node->set_attr_shape(
        ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
107
    lite::npu::OpList::Global().add(xshape_node);
Y
Yan Chunwei 已提交
108 109 110 111 112
    outputs_map[op_info->Output("XShape").front()] = xshape_node;
  }
  return outputs_map;
}

Z
zhupengyang 已提交
113
}  // namespace bridges
Y
Yan Chunwei 已提交
114
}  // namespace npu
Z
zhupengyang 已提交
115
}  // namespace kernels
Y
Yan Chunwei 已提交
116 117 118
}  // namespace lite
}  // namespace paddle

// Register ReshapeConverter for both reshape variants via REGISTER_NPU_BRIDGE
// (from lite/kernels/npu/bridges/registry.h). A single converter serves both:
// it checks the op type itself and emits the extra XShape node only for
// reshape2.
REGISTER_NPU_BRIDGE(reshape,
                    paddle::lite::kernels::npu::bridges::ReshapeConverter);
REGISTER_NPU_BRIDGE(reshape2,
                    paddle::lite::kernels::npu::bridges::ReshapeConverter);