reshape_op.cc 6.1 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/reshape_op.h"
#include "lite/core/op_registry.h"
17
#include "lite/core/tensor.h"
Y
Yan Chunwei 已提交
18 19 20 21 22 23 24 25 26 27 28 29

namespace paddle {
namespace lite {
namespace operators {

// Sanity-checks that the tensors this op needs have been attached.
// Returns true once both `param_.x` and `param_.output` pass CHECK_OR_FALSE.
// NOTE(review): CHECK_OR_FALSE is defined outside this file; presumably it
// makes the function return false when its argument is null -- confirm.
bool ReshapeOp::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  return true;
}

bool ReshapeOp::InferShape() const {
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
  auto shape_tensor_vct = param_.shape_tensor_vct;
  auto *shape_tensor = param_.shape_tensor;
  auto shape_vct = param_.shape_vct;
  std::vector<int> final_shape;

  if (shape_tensor_vct.size() > 0) {
    for (int i = 0; i < shape_tensor_vct.size(); i++) {
      final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
    }
  } else if (shape_tensor != nullptr) {
    auto *shape_tensor_data = shape_tensor->data<int>();
    final_shape = std::vector<int>(shape_tensor_data,
                                   shape_tensor_data + shape_tensor->numel());
  } else if (!shape_vct.empty()) {
    final_shape = shape_vct;
  } else {
    LOG(FATAL) << "input shape error";
  }

Y
Yan Chunwei 已提交
49
  auto x_dims = param_.x->dims();
50
  auto output_dims = ValidateShape(final_shape, x_dims);
Y
Yan Chunwei 已提交
51 52 53 54 55 56 57
  param_.output->Resize(output_dims);
  auto out_lod = param_.output->mutable_lod();
  *out_lod = param_.x->lod();
  return true;
}

// Binds this op's parameters from `opdesc`, resolving variable names through
// `scope`.
//
// Required: input "X" and output "Out". Optional shape sources, all of which
// are recorded when present (InferShape picks among them):
//   - input "ShapeTensor": a list of tensors, stored in
//     `param_.shape_tensor_vct`;
//   - input "Shape": a single tensor, stored in `param_.shape_tensor`;
//   - attribute "shape": stored in `param_.shape_vct`.
// The optional attribute "inplace" is stored in `param_.inplace`.
//
// Returns true on success. A missing "X" or "Out" variable is reported through
// CHECK_OR_FALSE instead of being dereferenced as a null pointer.
// NOTE(review): CHECK_OR_FALSE is defined outside this file; it is assumed to
// make the function return false when its argument is null -- confirm.
bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  auto *x_var = scope->FindVar(opdesc.Input("X").front());
  CHECK_OR_FALSE(x_var);
  auto *out_var = scope->FindVar(opdesc.Output("Out").front());
  CHECK_OR_FALSE(out_var);
  param_.x = x_var->GetMutable<lite::Tensor>();
  param_.output = out_var->GetMutable<lite::Tensor>();

  // priority: input(ShapeTensor) > input(Shape) > attr(shape)
  // Start from an empty list so that attaching the same desc again does not
  // append the same tensors a second time.
  param_.shape_tensor_vct.clear();
  if (opdesc.HasInput("ShapeTensor")) {
    const auto &shape_tensor_args = opdesc.Input("ShapeTensor");
    if (!shape_tensor_args.empty()) {
      for (const auto &arg : shape_tensor_args) {
        auto *var = scope->FindVar(arg);
        // Names that cannot be resolved in the scope are skipped.
        if (var != nullptr) {
          param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
        }
      }
      CHECK_GT(param_.shape_tensor_vct.size(), 0)
          << "ShapeError: When `shape` in ReshapeOp is a list or tuple "
             "which contains Tensor, the shape's size can't be zero. "
             "But received shape's size is "
          << param_.shape_tensor_vct.size();
    }
  }
  if (opdesc.HasInput("Shape") && !opdesc.Input("Shape").empty()) {
    auto *var = scope->FindVar(opdesc.Input("Shape").front());
    if (var != nullptr) {
      param_.shape_tensor = var->GetMutable<lite::Tensor>();
    }
  }
  if (opdesc.HasAttr("shape")) {
    param_.shape_vct = opdesc.GetAttr<std::vector<int>>("shape");
  }
  if (opdesc.HasAttr("inplace")) {
    param_.inplace = opdesc.GetAttr<bool>("inplace");
  }
  return true;
}

// Verifies the tensors required by reshape2: everything ReshapeOp::CheckShape()
// requires, plus `param_.xshape`.
// Returns false as soon as the base-class check fails, so that a failure
// there is no longer masked by a successful xshape check.
bool Reshape2Op::CheckShape() const {
  if (!ReshapeOp::CheckShape()) {
    return false;
  }
  CHECK_OR_FALSE(param_.xshape);
  return true;
}

bool Reshape2Op::InferShape() const {
  ReshapeOp::InferShape();
  auto x_dims = param_.x->dims();
102
  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
Y
Yan Chunwei 已提交
103 104 105
  for (size_t i = 0; i < x_dims.size(); i++) {
    xshape_dims[i + 1] = x_dims[i];
  }
106
  param_.xshape->Resize(xshape_dims);
107 108
  auto xshape_lod = param_.xshape->mutable_lod();
  *xshape_lod = param_.x->lod();
Y
Yan Chunwei 已提交
109 110 111 112 113 114 115 116 117 118 119
  return true;
}

// Binds the reshape2 parameters: everything ReshapeOp::AttachImpl() binds,
// plus the extra output "XShape", stored in `param_.xshape`.
// Returns false if the base-class attach fails or the "XShape" variable cannot
// be found in `scope`; true otherwise.
// NOTE(review): CHECK_OR_FALSE is defined outside this file; it is assumed to
// make the function return false when its argument is null -- confirm.
bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  if (!ReshapeOp::AttachImpl(opdesc, scope)) {
    return false;
  }
  auto *xshape_var = scope->FindVar(opdesc.Output("XShape").front());
  // Guard against a missing variable instead of dereferencing a null pointer.
  CHECK_OR_FALSE(xshape_var);
  param_.xshape = xshape_var->GetMutable<lite::Tensor>();
  return true;
}

// Resolves a requested reshape target against the dims of the input.
//
// Each entry of `shape` is one of:
//   -1  : "unknown"; its extent is inferred from the remaining element count.
//         At most one entry may be -1.
//    0  : "copy"; the extent is taken from `input_dims` at the same index,
//         which therefore has to exist.
//   >0  : the extent itself.
// Any other value, or a shape whose element count cannot match the input,
// trips one of the CHECK_* macros below.
//
// Returns the resolved dims as a DDim.
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
  using dim_t = lite::DDim::value_type;

  const dim_t input_size = input_dims.production();
  auto input_shape = input_dims.Vectorize();
  const bool every_input_dim_positive =
      std::all_of(input_shape.cbegin(), input_shape.cend(), [](dim_t extent) {
        return extent > 0;
      });

  // Markers recognised inside `shape`; only one dimension can be marked as
  // inferred.
  const int infer_marker = -1;
  const int copy_marker = 0;

  std::vector<dim_t> output_shape(shape.size(), 0);
  // Product of all resolved entries. An inferred entry contributes its raw -1,
  // so `capacity` is negative whenever an inferred dimension is present.
  dim_t capacity = 1;
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == infer_marker) {
      CHECK_EQ(unk_dim_idx, -1)
          << "Only one input dimension of Attr(shape) can be unknown.";
      unk_dim_idx = i;
    } else if (shape[i] == copy_marker) {
      CHECK_LT(static_cast<int>(i), input_shape.size())
          << "The index of dimension to copy from input shape must be less "
             "than the size of input shape.";
    } else {
      CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not "
                               "be negtive except one unknown dimension.";
    }

    // A zero entry is replaced by the input extent at the same position;
    // everything else (including -1) is used as given.
    const dim_t resolved =
        shape[i] ? static_cast<dim_t>(shape[i]) : input_shape[i];
    capacity *= resolved;
    output_shape[i] = resolved;
  }

  // No inferred dimension: the element counts simply have to agree.
  if (unk_dim_idx == -1) {
    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
    return lite::DDim(output_shape);
  }

  // Some input extent is not positive, so the element count is not
  // determinate and the divisibility check below would fail. For example,
  // input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8] gives capacity = -24,
  // input_size = -8 and an inferred extent of 0. Leave the dimension as -1.
  if (!every_input_dim_positive) {
    output_shape[unk_dim_idx] = -1;
    return lite::DDim(output_shape);
  }

  // `capacity` carries the sign of the -1 marker, hence the negated
  // input_size on both sides.
  output_shape[unk_dim_idx] = -input_size / capacity;
  CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
      << "Invalid shape is given.";
  return lite::DDim(output_shape);
}

}  // namespace operators
}  // namespace lite
}  // namespace paddle

// Register both operator classes under the op type names "reshape" and
// "reshape2". These statements sit outside the paddle::lite::operators
// namespace, hence the fully qualified class names.
// NOTE(review): REGISTER_LITE_OP is defined outside this file (presumably in
// lite/core/op_registry.h) -- confirm its exact effect there.
REGISTER_LITE_OP(reshape, paddle::lite::operators::ReshapeOp);
REGISTER_LITE_OP(reshape2, paddle::lite::operators::Reshape2Op);