reshape_op.cc 4.8 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/reshape_op.h"
16
#include "lite/backends/npu/builder.h"
Z
zhupengyang 已提交
17
#include "lite/kernels/npu/bridges/registry.h"
Y
Yan Chunwei 已提交
18 19 20

namespace paddle {
namespace lite {
Z
zhupengyang 已提交
21
namespace kernels {
Y
Yan Chunwei 已提交
22
namespace npu {
Z
zhupengyang 已提交
23
namespace bridges {
Y
Yan Chunwei 已提交
24 25 26 27 28 29

node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
                               const node_map_type& inputs_map) {
  auto scope = reshape_op->scope();
  auto op_info = reshape_op->op_info();
  auto op_type = op_info->Type();
30
  auto unique_op_type = lite::npu::UniqueName(op_type);
Y
Yan Chunwei 已提交
31 32 33 34 35 36 37 38 39 40 41
  LOG(INFO) << "Converting " + op_type + "...";

  // get input, output and op attributes
  auto x_var_name = op_info->Input("X").front();
  auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
  auto x_dims = x->dims();

  // create reshape node and set input node from inputs_map
  auto reshape_node = std::make_shared<ge::op::Reshape>(unique_op_type);
  CHECK(inputs_map.count(x_var_name));
  reshape_node->set_input_tensor(*inputs_map.at(x_var_name));
42
  lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
Y
Yan Chunwei 已提交
43 44

  // read shape from actual shape tensor as input "w" if 'Shape' is found
45
  if (lite::npu::HasInputArg(op_info, scope, "Shape")) {
Y
Yan Chunwei 已提交
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
    auto actual_shape_var_name = op_info->Input("Shape").front();
    if (!inputs_map.count(actual_shape_var_name)) {
      auto actual_shape =
          scope->FindVar(actual_shape_var_name)->GetMutable<lite::Tensor>();
      auto actual_shape_dims = actual_shape->dims();
      auto actual_shape_data = actual_shape->mutable_data<int>();
      auto shape =
          std::vector<int>(actual_shape_data,
                           actual_shape_data + actual_shape_dims.production());
      auto out_dims = operators::ValidateShape(shape, x_dims);
      auto out_shape = out_dims.Vectorize();
      if (out_shape.size() > 4) {
        LOG(WARNING)
            << "NPU DDK only supports less than 4 dimensions, but Shape has "
            << out_shape.size();
      }
      auto actual_shape_const_node =
          std::make_shared<ge::op::Const>(actual_shape_var_name);
64 65 66
      actual_shape_const_node->set_attr_value(
          lite::npu::CreateTensorAndFillData(
              std::vector<int>(out_shape.begin(), out_shape.end())));
Y
Yan Chunwei 已提交
67
      reshape_node->set_input_w(*actual_shape_const_node);
68
      lite::npu::OpList::Global().add(actual_shape_const_node);
Y
Yan Chunwei 已提交
69 70
    } else {
      reshape_node->set_input_w(*inputs_map.at(actual_shape_var_name));
71
      lite::npu::OpList::Global().add(inputs_map.at(actual_shape_var_name));
Y
Yan Chunwei 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84
    }
  } else {
    auto shape = op_info->GetAttr<std::vector<int>>("shape");
    auto out_dims = operators::ValidateShape(shape, x_dims);
    auto out_shape = out_dims.Vectorize();
    if (out_shape.size() > 4) {
      LOG(WARNING)
          << "NPU DDK only supports less than 4 dimensions, but shape has "
          << out_shape.size();
    }
    reshape_node->set_attr_shape(
        ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
  }
85
  lite::npu::OpList::Global().add(reshape_node);
Y
Yan Chunwei 已提交
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104

  node_map_type outputs_map;
  outputs_map[op_info->Output("Out").front()] = reshape_node;
  if (op_type == "reshape2") {
    // append an extra reshape node to calc XShape
    std::vector<int64_t> xshape_dims(x_dims.size() + 1, 1);
    for (size_t i = 0; i < x_dims.size(); i++) {
      xshape_dims[i + 1] = x_dims[i];
    }
    if (xshape_dims.size() > 4) {
      LOG(WARNING)
          << "NPU DDK only supports less than 4 dimensions, but XShape has "
          << xshape_dims.size();
    }
    auto xshape_node =
        std::make_shared<ge::op::Reshape>(unique_op_type + "/xshape");
    xshape_node->set_input_tensor(*inputs_map.at(x_var_name));
    xshape_node->set_attr_shape(
        ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
105
    lite::npu::OpList::Global().add(xshape_node);
Y
Yan Chunwei 已提交
106 107 108 109 110
    outputs_map[op_info->Output("XShape").front()] = xshape_node;
  }
  return outputs_map;
}

Z
zhupengyang 已提交
111
}  // namespace bridges
Y
Yan Chunwei 已提交
112
}  // namespace npu
Z
zhupengyang 已提交
113
}  // namespace kernels
Y
Yan Chunwei 已提交
114 115 116
}  // namespace lite
}  // namespace paddle

Z
zhupengyang 已提交
117 118 119 120
// Register the same converter for both op types. ReshapeConverter inspects
// the op type at run time and only emits the extra "XShape" output for
// "reshape2".
REGISTER_NPU_BRIDGE(reshape,
                    paddle::lite::kernels::npu::bridges::ReshapeConverter);
REGISTER_NPU_BRIDGE(reshape2,
                    paddle::lite::kernels::npu::bridges::ReshapeConverter);