reshape_op.cc 5.9 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/reshape_op.h"
#include "lite/core/op_registry.h"
17
#include "lite/core/tensor.h"
Y
Yan Chunwei 已提交
18 19 20 21 22 23 24 25 26 27 28 29

namespace paddle {
namespace lite {
namespace operators {

// Verifies that AttachImpl() bound both the input tensor `X` (param_.x) and
// the output tensor `Out` (param_.output).
// NOTE(review): CHECK_OR_FALSE presumably makes this function return false
// when its argument is null (per the macro's name) — confirm in the macro
// definition.
bool ReshapeOp::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  return true;
}

bool ReshapeOp::InferShape() const {
30
  const auto &shape_tensor_vct = param_.shape_tensor_vct;
31
  auto *shape_tensor = param_.shape_tensor;
32
  const auto &shape_vct = param_.shape_vct;
33

34
  std::vector<int> final_shape;
35
  if (shape_tensor_vct.size() > 0) {
36
    final_shape.resize(shape_tensor_vct.size());
37
    for (int i = 0; i < shape_tensor_vct.size(); i++) {
38
      final_shape[i] = shape_tensor_vct[i]->data<int>()[0];
39 40 41 42 43 44 45 46 47 48 49
    }
  } else if (shape_tensor != nullptr) {
    auto *shape_tensor_data = shape_tensor->data<int>();
    final_shape = std::vector<int>(shape_tensor_data,
                                   shape_tensor_data + shape_tensor->numel());
  } else if (!shape_vct.empty()) {
    final_shape = shape_vct;
  } else {
    LOG(FATAL) << "input shape error";
  }

50
  const auto &x_dims = param_.x->dims();
51
  auto output_dims = ValidateShape(final_shape, x_dims);
Y
Yan Chunwei 已提交
52 53 54 55 56 57 58
  param_.output->Resize(output_dims);
  auto out_lod = param_.output->mutable_lod();
  *out_lod = param_.x->lod();
  return true;
}

// Binds the op's tensors and shape sources from `opdesc` and `scope`.
//
// Always binds input `X` and output `Out`. The target shape may come from any
// of: input `ShapeTensor` (a list of tensors), input `Shape` (one tensor), or
// attribute `shape`; all present sources are recorded here and InferShape()
// picks one by priority. The optional `inplace` attribute is also read.
bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  auto *x_var = scope->FindVar(opdesc.Input("X").front());
  CHECK(x_var) << "Cannot find the input variable X of reshape.";
  param_.x = x_var->GetMutable<lite::Tensor>();
  auto *out_var = scope->FindVar(opdesc.Output("Out").front());
  CHECK(out_var) << "Cannot find the output variable Out of reshape.";
  param_.output = out_var->GetMutable<lite::Tensor>();

  // Reset shape sources so a repeated Attach does not keep stale entries
  // (push_back below would otherwise append to the previous list).
  param_.shape_tensor_vct.clear();
  param_.shape_tensor = nullptr;

  // priority: input(ShapeTensor) > input(Shape) > attr(shape)
  if (opdesc.HasInput("ShapeTensor") && !opdesc.Input("ShapeTensor").empty()) {
    const auto &args = opdesc.Input("ShapeTensor");
    for (const auto &arg : args) {
      auto *var = scope->FindVar(arg);
      // Missing variables are skipped; the CHECK below catches the case
      // where none of them were found.
      if (var != nullptr) {
        param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
      }
    }
    CHECK_GT(param_.shape_tensor_vct.size(), 0)
        << "ShapeError: When `shape` in ReshapeOp is a list or tuple "
           "which contains Tensor, the shape's size can't be zero. "
           "But received shape's size is "
        << param_.shape_tensor_vct.size();
  }
  if (opdesc.HasInput("Shape") && !opdesc.Input("Shape").empty()) {
    auto *var = scope->FindVar(opdesc.Input("Shape").front());
    if (var != nullptr) {
      param_.shape_tensor = var->GetMutable<lite::Tensor>();
    }
  }
  if (opdesc.HasAttr("shape")) {
    param_.shape_vct = opdesc.GetAttr<std::vector<int>>("shape");
  }
  if (opdesc.HasAttr("inplace")) {
    param_.inplace = opdesc.GetAttr<bool>("inplace");
  }
  return true;
}

// Runs the reshape checks (X and Out bound) and additionally requires the
// reshape2-specific `XShape` output to be bound.
bool Reshape2Op::CheckShape() const {
  // Propagate a failure from the base check; previously its result was
  // discarded, so a missing X/Out still reported success.
  if (!ReshapeOp::CheckShape()) {
    return false;
  }
  CHECK_OR_FALSE(param_.xshape);
  return true;
}

bool Reshape2Op::InferShape() const {
  ReshapeOp::InferShape();
102
  const auto &x_dims = param_.x->dims();
103 104
  DDim xshape_dims;
  xshape_dims.resize(x_dims.size() + 1);
105
  xshape_dims[0] = 0;
Y
Yan Chunwei 已提交
106 107 108
  for (size_t i = 0; i < x_dims.size(); i++) {
    xshape_dims[i + 1] = x_dims[i];
  }
109
  param_.xshape->Resize(xshape_dims);
110 111
  auto xshape_lod = param_.xshape->mutable_lod();
  *xshape_lod = param_.x->lod();
Y
Yan Chunwei 已提交
112 113 114 115 116 117 118 119 120 121 122
  return true;
}

// Binds everything ReshapeOp binds, plus the reshape2-specific `XShape`
// output tensor.
bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  if (!ReshapeOp::AttachImpl(opdesc, scope)) {
    return false;
  }
  auto *xshape_var = scope->FindVar(opdesc.Output("XShape").front());
  CHECK(xshape_var) << "Cannot find the output variable XShape of reshape2.";
  param_.xshape = xshape_var->GetMutable<lite::Tensor>();
  return true;
}

// Resolves a requested reshape target `shape` against `input_dims`.
//
// Each entry of `shape` is interpreted as:
//   -1  : the single unknown dimension, inferred so the element count matches
//         the input (at most one entry may be -1);
//    0  : copy the dimension at the same index from `input_dims` (the index
//         must be < input_dims.size());
//   > 0 : used as-is.
// Any other value fails a CHECK. Without a -1 entry, the product of the
// output dims must equal input_dims.production().
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
  const DDim::value_type input_size = input_dims.production();

  // Only one dimension can be set to -1, whose size will be automatically
  // inferred.
  const int unk_dim_val = -1;
  const int copy_dim_val = 0;

  DDim output_dims;
  output_dims.resize(shape.size());
  // Running product of the output dims. The unknown dim contributes its -1
  // here, which is why the division below negates input_size.
  DDim::value_type capacity = 1;
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == unk_dim_val) {
      CHECK_EQ(unk_dim_idx, -1)
          << "Only one input dimension of Attr(shape) can be unknown.";
      unk_dim_idx = static_cast<int>(i);
    } else if (shape[i] == copy_dim_val) {
      CHECK_LT(static_cast<int>(i), input_dims.size())
          << "The index of dimension to copy from input shape must be less "
             "than the size of input shape.";
    } else {
      CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not "
                               "be negtive except one unknown dimension.";
    }

    DDim::value_type output_dim_i =
        shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_dims[i];
    output_dims[i] = output_dim_i;
    capacity *= output_dim_i;
  }

  if (unk_dim_idx != -1) {
    if (input_dims.CheckPositive()) {
      // `capacity` includes the -1 of the unknown dim, so -input_size /
      // capacity yields the size of that dim; verify it divides evenly.
      output_dims[unk_dim_idx] = -input_size / capacity;
      CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size)
          << "Invalid shape is given.";
    } else {
      // input_size < 0 and is un-determinate in compile time, so skip the
      // check. For example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8]
      // gives capacity = -24, input_size = -8 and output_dims[0] = 0, and the
      // divisibility check above would fail. Leave the dim unknown instead.
      output_dims[unk_dim_idx] = -1;
    }
  } else {
    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
  }
  return output_dims;
}

}  // namespace operators
}  // namespace lite
}  // namespace paddle

// Register both operator types via REGISTER_LITE_OP (lite/core/op_registry.h):
// `reshape` is implemented by ReshapeOp; `reshape2` by Reshape2Op, which also
// binds and sizes the extra XShape output.
REGISTER_LITE_OP(reshape, paddle::lite::operators::ReshapeOp);
REGISTER_LITE_OP(reshape2, paddle::lite::operators::Reshape2Op);