// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/slice_op.h"
#include <algorithm>
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

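// Verify that the input and output tensors are present and that the input
// rank is within the supported range.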
bool SliceOp::CheckShape() const {
  CHECK_OR_FALSE(param_.X);
  CHECK_OR_FALSE(param_.Out);
  CHECK_LT(param_.X->dims().size(), 7u)
      << "The rank of input X should be less than 7";
  return true;
}

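// Infer the output shape from axes/starts/ends and propagate it to Out.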
bool SliceOp::InferShapeImpl() const {
  CHECK_OR_FALSE(param_.Out);
  // TODO(Superjomn) Enable data sharing.
  auto in_dims = param_.X->dims();
  auto out_dims = in_dims;
  // CHECK_EQ(param_.starts.size(), param_.ends.size())
  //    << "for slice op starts and ends must be equal";
  int dim_value, start, end;
  auto axes = param_.axes;
  auto starts = param_.starts;
  auto ends = param_.ends;
  auto decrease_axis = param_.decrease_axis;
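  // For each sliced axis: an infer flag of -1 marks the extent as unknown
  // (-1) until runtime; otherwise negative starts/ends are normalized by
  // adding the dimension size and clamped to [0, dim] before computing the
  // sliced extent.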
  for (size_t i = 0; i < axes.size(); ++i) {
    CHECK_LT(param_.axes[i], in_dims.size()) << "The index of dimension in "
                                                "axes must be less than the "
                                                "size of input shape.";
    if (param_.infer_flags[i] == -1) {
      out_dims[axes[i]] = -1;
    } else {
      // Infer the sliced extent of this output dimension.
      dim_value = out_dims[axes[i]];
      if (dim_value > 0) {
        start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
        end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
        start = std::max(start, 0);
        end = std::max(end, 0);
        end = std::min(end, dim_value);
        out_dims[axes[i]] = end - start;
      }
    }
  }
  // Generate the final output shape by dropping the decreased axes.
  if (decrease_axis.size() > 0) {
    std::vector<int64_t> new_out_shape;
    for (size_t i = 0; i < decrease_axis.size(); ++i) {
      if (param_.infer_flags[i] != -1) {
        CHECK_EQ(out_dims[decrease_axis[i]], 1) << "decrease dim should be 1";
      }
      out_dims[decrease_axis[i]] = 0;
    }
    for (size_t i = 0; i < out_dims.size(); ++i) {
      if (out_dims[i] != 0) {
        new_out_shape.push_back(out_dims[i]);
      }
    }
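    // If every dimension was decreased, keep a single dimension of size 1.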
    if (new_out_shape.size() == 0) {
      new_out_shape.push_back(1);
    }
    DDim new_dims;
    new_dims.ConstructFrom(new_out_shape);
    out_dims = new_dims;
  }
  param_.Out->Resize(out_dims);
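  // Share the input LoD only when the first sliced axis is not the batch
  // axis (axis 0); slicing the batch dimension would invalidate the LoD
  // offsets.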
  if (axes[0] != 0) {
    param_.Out->set_lod(param_.X->lod());
  }
  return true;
}

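// Bind inputs, outputs, and attributes from the op description to param_.
// starts/ends may come from attributes, tensor lists, or single tensors.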
bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  param_.X =
      scope->FindVar(opdesc.Input("Input").front())->GetMutable<lite::Tensor>();
  param_.Out =
      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
  CHECK(param_.X);
  CHECK(param_.Out);
  param_.axes = opdesc.GetAttr<std::vector<int>>("axes");

  if (opdesc.HasAttr("infer_flags")) {
    param_.infer_flags = opdesc.GetAttr<std::vector<int>>("infer_flags");
  } else {
    // Default infer_flags to 1 so that op tests which do not set the
    // attribute still work.
    param_.infer_flags = std::vector<int>(param_.axes.size(), 1);
  }
  if (opdesc.HasAttr("decrease_axis")) {
    param_.decrease_axis = opdesc.GetAttr<std::vector<int>>("decrease_axis");
  }

  // The priority: StartsTensor > StartsTensorList > attr(starts).
  // The priority: EndsTensor > EndsTensorList > attr(ends).
  size_t starts_size, ends_size;
  if (opdesc.HasAttr("starts")) {
    param_.starts = opdesc.GetAttr<std::vector<int>>("starts");
  }
  if (opdesc.HasAttr("ends")) {
    param_.ends = opdesc.GetAttr<std::vector<int>>("ends");
  }
  starts_size = param_.starts.size();
  ends_size = param_.ends.size();

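  // A non-empty StartsTensorList provides one start tensor per sliced axis
  // and overrides attr(starts); EndsTensorList works the same way for ends.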
  if (opdesc.HasInput("StartsTensorList") &&
      !opdesc.Input("StartsTensorList").empty()) {
    LOG(INFO) << "StartsTensorList size in opdesc: "
              << opdesc.Input("StartsTensorList").size();
    LOG(INFO) << "StartsTensorList size in param before attach: "
              << param_.StartsTensorList.size();
    auto StartsTensorList = opdesc.Input("StartsTensorList");
    param_.StartsTensorList.clear();
    for (auto var : StartsTensorList) {
      param_.StartsTensorList.push_back(
          scope->FindVar(var)->GetMutable<lite::Tensor>());
    }
    CHECK_GT(param_.StartsTensorList.size(), 0u)
        << "StartsTensorList size can't be zero";
    starts_size = param_.StartsTensorList.size();
  }
  if (opdesc.HasInput("EndsTensorList") &&
      !opdesc.Input("EndsTensorList").empty()) {
    auto EndsTensorList = opdesc.Input("EndsTensorList");
    param_.EndsTensorList.clear();
    for (auto var : EndsTensorList) {
      param_.EndsTensorList.push_back(
          scope->FindVar(var)->GetMutable<lite::Tensor>());
    }
    CHECK_GT(param_.EndsTensorList.size(), 0u)
        << "EndsTensorList size can't be zero";
    ends_size = param_.EndsTensorList.size();
  }

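  // StartsTensor/EndsTensor hold all indices in a single tensor and take the
  // highest priority; without them, the collected starts/ends sizes must
  // match the number of axes.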
  if (opdesc.HasInput("StartsTensor") &&
      !opdesc.Input("StartsTensor").empty()) {
    param_.StartsTensor = scope->FindVar(opdesc.Input("StartsTensor").front())
                              ->GetMutable<lite::Tensor>();
  } else {
    CHECK_EQ(starts_size, param_.axes.size())
        << "The size of starts must be equal to the size of axes.";
  }
  if (opdesc.HasInput("EndsTensor") && !opdesc.Input("EndsTensor").empty()) {
    param_.EndsTensor = scope->FindVar(opdesc.Input("EndsTensor").front())
                            ->GetMutable<lite::Tensor>();
  } else {
    CHECK_EQ(ends_size, param_.axes.size())
        << "The size of ends must be equal to the size of axes.";
  }
  return true;
}

}  // namespace operators
}  // namespace lite
}  // namespace paddle

REGISTER_LITE_OP(slice, paddle::lite::operators::SliceOp);