reshape_op.cc 6.2 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/reshape_op.h"
#include "lite/core/op_registry.h"
17
#include "lite/core/tensor.h"
Y
Yan Chunwei 已提交
18 19 20 21 22 23 24 25 26 27 28

namespace paddle {
namespace lite {
namespace operators {

// Checks that the input tensor (`param_.x`) and the output tensor
// (`param_.output`) have both been set on the op parameters.
// Returns true when both checks pass. CHECK_OR_FALSE is defined outside this
// file; presumably it returns false from this function when its argument is
// null/false -- TODO confirm.
bool ReshapeOp::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  return true;
}

29
bool ReshapeOp::InferShapeImpl() const {
30
  const auto &shape_tensor_vct = param_.shape_tensor_vct;
31
  auto *shape_tensor = param_.shape_tensor;
32
  const auto &shape_vct = param_.shape_vct;
33

34
  std::vector<int> final_shape;
35
  if (shape_tensor_vct.size() > 0) {
36 37 38
    final_shape.resize(shape_tensor_vct.size());
    for (size_t i = 0; i < shape_tensor_vct.size(); i++) {
      final_shape[i] = shape_tensor_vct[i]->data<int>()[0];
39 40 41 42 43 44 45 46 47 48 49
    }
  } else if (shape_tensor != nullptr) {
    auto *shape_tensor_data = shape_tensor->data<int>();
    final_shape = std::vector<int>(shape_tensor_data,
                                   shape_tensor_data + shape_tensor->numel());
  } else if (!shape_vct.empty()) {
    final_shape = shape_vct;
  } else {
    LOG(FATAL) << "input shape error";
  }

50
  const auto &x_dims = param_.x->dims();
51
  auto output_dims = ValidateShape(final_shape, x_dims);
Y
Yan Chunwei 已提交
52 53 54 55 56 57 58
  param_.output->Resize(output_dims);
  auto out_lod = param_.output->mutable_lod();
  *out_lod = param_.x->lod();
  return true;
}

// Binds the op parameters to the variables and attributes named in `opdesc`.
//
// - `param_.x` / `param_.output` are looked up in `scope` through the first
//   "X" input and the first "Out" output.
//   NOTE(review): the FindVar() results are dereferenced unchecked, so both
//   variables are assumed to exist in `scope`.
// - The target shape may come from three places; all that are present are
//   stored here and InferShapeImpl() picks one by priority.
// - The optional "inplace" attribute is copied into `param_.inplace`.
//
// Always returns true.
bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  param_.x =
      scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
  param_.output =
      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();

  // priority: input(ShapeTensor) > input(Shape) > attr(shape)
  if (opdesc.HasInput("ShapeTensor") && !opdesc.Input("ShapeTensor").empty()) {
    // Start from an empty list so that attaching the same op again does not
    // append the shape tensors a second time.
    param_.shape_tensor_vct.clear();
    // Bind by const reference to avoid copying the name list and each name.
    const auto &args = opdesc.Input("ShapeTensor");
    for (const auto &arg : args) {
      auto *var = scope->FindVar(arg);
      if (var != nullptr) {
        param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
      }
    }
    CHECK_GT(param_.shape_tensor_vct.size(), 0)
        << "ShapeError: When `shape` in ReshapeOp is a list or tuple "
           "which contains Tensor, the shape's size can't be zero. "
           "But received shape's size is "
        << param_.shape_tensor_vct.size();
  }
  if (opdesc.HasInput("Shape") && !opdesc.Input("Shape").empty()) {
    auto var = scope->FindVar(opdesc.Input("Shape").front());
    if (var != nullptr) {
      param_.shape_tensor = var->GetMutable<lite::Tensor>();
    }
  }
  if (opdesc.HasAttr("shape")) {
    param_.shape_vct = opdesc.GetAttr<std::vector<int>>("shape");
  }
  if (opdesc.HasAttr("inplace")) {
    param_.inplace = opdesc.GetAttr<bool>("inplace");
  }
  return true;
}

// Checks the parameters required by reshape2: everything ReshapeOp checks,
// plus the extra `xshape` output tensor.
// Returns false as soon as one check fails, true otherwise.
bool Reshape2Op::CheckShape() const {
  // The base-class result must be propagated; dropping it would report
  // success even when the input or output tensor is missing.
  if (!ReshapeOp::CheckShape()) {
    return false;
  }
  CHECK_OR_FALSE(param_.xshape);
  return true;
}

100 101
bool Reshape2Op::InferShapeImpl() const {
  ReshapeOp::InferShapeImpl();
102 103 104
  const auto &x_dims = param_.x->dims();
  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1);
  xshape_dims[0] = 0;
Y
Yan Chunwei 已提交
105 106 107
  for (size_t i = 0; i < x_dims.size(); i++) {
    xshape_dims[i + 1] = x_dims[i];
  }
108
  param_.xshape->Resize(xshape_dims);
109 110
  auto xshape_lod = param_.xshape->mutable_lod();
  *xshape_lod = param_.x->lod();
Y
Yan Chunwei 已提交
111 112 113 114 115 116 117 118 119 120
  return true;
}

// Binds the reshape2 parameters: everything ReshapeOp::AttachImpl() binds,
// plus `param_.xshape`, looked up in `scope` through the first "XShape"
// output of `opdesc`.
// Returns false if the base-class attach fails or the XShape variable is not
// present in `scope`; true otherwise.
bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  if (!ReshapeOp::AttachImpl(opdesc, scope)) {
    return false;
  }
  auto *xshape_var = scope->FindVar(opdesc.Output("XShape").front());
  // ReshapeOp::AttachImpl() null-checks its optional FindVar() results; do the
  // same here instead of dereferencing a missing variable.
  if (xshape_var == nullptr) {
    return false;
  }
  param_.xshape = xshape_var->GetMutable<lite::Tensor>();
  return true;
}

121 122 123 124 125 126 127 128 129 130 131 132 133
// Returns true only when every entry of `dims` is strictly greater than zero
// (and therefore also for an empty `dims`).
static bool CheckPositive(const DDim &dims) {
  bool all_positive = true;
  // Stop scanning at the first non-positive entry.
  for (size_t axis = 0; all_positive && axis < dims.size(); ++axis) {
    all_positive = dims[axis] > 0;
  }
  return all_positive;
}

std::vector<DDim::value_type> ValidateShape(const std::vector<int> &shape,
                                            const DDim &input_dims) {
  const DDim::value_type input_size = input_dims.production();

Y
Yan Chunwei 已提交
134 135 136 137 138
  // only one dimension can be set to -1, whose size will be automatically
  // infered.
  const int unk_dim_val = -1;
  const int copy_dim_val = 0;

139 140
  std::vector<DDim::value_type> output_dims(shape.size());
  DDim::value_type capacity = 1;
Y
Yan Chunwei 已提交
141 142 143 144 145 146 147
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == unk_dim_val) {
      CHECK_EQ(unk_dim_idx, -1)
          << "Only one input dimension of Attr(shape) can be unknown.";
      unk_dim_idx = i;
    } else if (shape[i] == copy_dim_val) {
148
      CHECK_LT(static_cast<int>(i), input_dims.size())
Y
Yan Chunwei 已提交
149 150 151 152 153 154 155
          << "The index of dimension to copy from input shape must be less "
             "than the size of input shape.";
    } else {
      CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not "
                               "be negtive except one unknown dimension.";
    }

156 157 158 159
    DDim::value_type output_dim_i =
        shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_dims[i];
    output_dims[i] = output_dim_i;
    capacity *= output_dim_i;
Y
Yan Chunwei 已提交
160 161 162
  }

  if (unk_dim_idx != -1) {
163
    if (CheckPositive(input_dims)) {
Y
Yan Chunwei 已提交
164 165 166 167
      // input_size < 0 and is un-determinate in compile time, skip the check,
      // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
      // capacity = -24, input_size = -8, output_shape[0] = 0
      // the following check will fail.
168 169
      output_dims[unk_dim_idx] = -input_size / capacity;
      CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size)
Y
Yan Chunwei 已提交
170 171
          << "Invalid shape is given.";
    } else {
172
      output_dims[unk_dim_idx] = -1;
Y
Yan Chunwei 已提交
173 174 175 176
    }
  } else {
    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
  }
177
  return output_dims;
Y
Yan Chunwei 已提交
178 179 180 181 182 183 184 185
}

}  // namespace operators
}  // namespace lite
}  // namespace paddle

// Register the two operator classes defined above under the names `reshape`
// and `reshape2`. REGISTER_LITE_OP is defined outside this file, presumably
// in lite/core/op_registry.h -- TODO confirm.
REGISTER_LITE_OP(reshape, paddle::lite::operators::ReshapeOp);
REGISTER_LITE_OP(reshape2, paddle::lite::operators::Reshape2Op);