// slice_compute.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <Eigen/Core>
#include <algorithm>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/relu_op.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

/// Collects one int32 scalar from each tensor in `list_new_data_tensor`.
///
/// Every tensor in the list must have shape [1] (enforced with CHECK_EQ); its
/// single element becomes the corresponding entry of the returned vector, so
/// the output has the same length and order as the input list.
///
/// @param list_new_data_tensor  list of shape-[1] int32 tensors
/// @return the scalar values, one per tensor
inline std::vector<int> get_new_data_from_tensorlist(
    const std::vector<lite::Tensor*>& list_new_data_tensor) {
  std::vector<int> vec_new_data;
  // The output size is known up front; avoid incremental reallocation.
  vec_new_data.reserve(list_new_data_tensor.size());
  for (auto* tensor : list_new_data_tensor) {
    CHECK_EQ(tensor->dims(), DDim({1})) << "shape of dim tensor should be [1]";
    vec_new_data.push_back(*tensor->data<int32_t>());
  }
  return vec_new_data;
}

/// Copies all `numel()` int elements of `new_data_tensor`, in storage order,
/// into a freshly built vector.
///
/// @param new_data_tensor  tensor whose data is read as int
/// @return a vector holding a copy of the tensor's elements
inline std::vector<int> get_new_data_from_tensor(
    const Tensor* new_data_tensor) {
  const int* first = new_data_tensor->data<int>();
  const int* last = first + new_data_tensor->numel();
  return std::vector<int>(first, last);
}

/// Slices `in` into `out` for a rank-`D` input, operating on float data.
///
/// The slice begins at `starts[i]` and ends before `ends[i]` along axis
/// `axes[i]`. Negative starts/ends count from the end of that axis.
///
/// If any of `StartsTensor`, `EndsTensor`, `StartsTensorList` or
/// `EndsTensorList` is provided, the bounds are read from those tensors at run
/// time and the output shape is recomputed here. A single tensor takes priority
/// over a tensor list. Otherwise the shape already set on `out` is used.
///
/// Axes listed in `decrease_axis` must have extent 1 and are dropped from the
/// final output shape. The copy itself is done on the un-decreased rank-D
/// shape, and `out` is resized to the decreased shape at the end.
///
/// @param in                input tensor; its rank is expected to equal D (the
///                          rank dispatcher selects D from the input rank)
/// @param out               output tensor; resized and allocated as float
/// @param axes              axes to slice along
/// @param starts            per-axis start indices
/// @param ends              per-axis end indices (exclusive)
/// @param decrease_axis     axes to squeeze out of the result
/// @param StartsTensor      optional int tensor overriding `starts`
/// @param EndsTensor        optional int tensor overriding `ends`
/// @param StartsTensorList  optional list of shape-[1] int32 tensors
///                          overriding `starts`
/// @param EndsTensorList    optional list of shape-[1] int32 tensors
///                          overriding `ends`
/// @param infer_flags       per-axis flags; only consulted for the
///                          start == -1 / end == 0 case on decreased axes, and
///                          may be shorter than `axes` (including empty)
template <size_t D>
void slice_compute(const lite::Tensor* in,
                   lite::Tensor* out,
                   std::vector<int> axes,
                   std::vector<int> starts,
                   std::vector<int> ends,
                   std::vector<int> decrease_axis,
                   lite::Tensor* StartsTensor,
                   lite::Tensor* EndsTensor,
                   std::vector<lite::Tensor*> StartsTensorList,
                   std::vector<lite::Tensor*> EndsTensorList,
                   std::vector<int> infer_flags) {
  auto out_dims = out->dims();
  auto in_dims = in->dims();

  // Bounds supplied through tensors are only known at run time, so the
  // output shape has to be re-derived here.
  const bool need_infer = StartsTensor || EndsTensor ||
                          !StartsTensorList.empty() || !EndsTensorList.empty();

  if (need_infer) {
    if (StartsTensor) {
      starts = get_new_data_from_tensor(StartsTensor);
    } else if (!StartsTensorList.empty()) {
      starts = get_new_data_from_tensorlist(StartsTensorList);
    }
    CHECK_EQ(starts.size(), axes.size())
        << "The size of starts must be equal to the size of axes.";
    if (EndsTensor) {
      ends = get_new_data_from_tensor(EndsTensor);
    } else if (!EndsTensorList.empty()) {
      ends = get_new_data_from_tensorlist(EndsTensorList);
    }
    CHECK_EQ(ends.size(), axes.size())
        << "The size of ends must be equal to the size of axes.";
    out_dims = in_dims;
    for (size_t i = 0; i < axes.size(); ++i) {
      const int dim_value = static_cast<int>(out_dims[axes[i]]);
      if (dim_value <= 0) {
        continue;
      }
      // start == -1 with end == 0 on a decreased axis encodes
      // "end = start + 1" wrapped around; widen end so the clamp below turns
      // it into the last element. infer_flags may be shorter than axes, so
      // bounds-check before indexing it.
      if (starts[i] == -1 && ends[i] == 0 && i < infer_flags.size() &&
          infer_flags[i] == -1) {
        if (std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]) !=
            decrease_axis.end()) {
          ends[i] = 10000000;
        }
      }

      int start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
      int end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
      start = (std::max)(start, 0);
      end = (std::max)(end, 0);
      end = (std::min)(end, dim_value);

      CHECK_GT(end, start) << "end should greater than start";
      out_dims[axes[i]] = end - start;
    }
    out->Resize(out_dims);

    // Squeeze the decreased axes: mark them with 0, then keep only the
    // non-zero extents (a fully squeezed shape becomes [1]).
    if (!decrease_axis.empty()) {
      for (int axis : decrease_axis) {
        CHECK_EQ(out_dims[axis], 1) << "decrease dim should be 1";
        out_dims[axis] = 0;
      }

      std::vector<int64_t> new_out_shape;
      for (size_t i = 0; i < out_dims.size(); ++i) {
        if (out_dims[i] != 0) {
          new_out_shape.push_back(out_dims[i]);
        }
      }
      if (new_out_shape.empty()) {
        new_out_shape.push_back(1);
      }

      DDim new_dims;
      new_dims.ConstructFrom(new_out_shape);
      out_dims = new_dims;
    }
  }

  // The Eigen copy below works on the un-decreased rank-D shape, so
  // temporarily re-insert the decreased axes as extent-1 dimensions.
  if (!decrease_axis.empty()) {
    if (decrease_axis.size() == static_cast<size_t>(in_dims.size())) {
      // Every axis is decreased: the rank-D shape is all ones.
      std::vector<int64_t> vec_origin_out_shape(decrease_axis.size(), 1);
      out->Resize(vec_origin_out_shape);
    } else {
      std::vector<int64_t> vec_origin_out_shape(
          out_dims.size() + decrease_axis.size(), -1);
      for (int axis : decrease_axis) {
        vec_origin_out_shape[axis] = 1;
      }
      // Fill the remaining slots, in order, from the decreased shape.
      int index = 0;
      for (auto& extent : vec_origin_out_shape) {
        if (extent == -1) {
          extent = out_dims[index];
          ++index;
        }
      }
      out->Resize(vec_origin_out_shape);
    }
  }

  out->mutable_data<float>();

  // Extents are the (rank-D) output shape; offsets default to 0 and are set
  // to the normalized start on each sliced axis.
  auto new_out_dims = out->dims();
  auto offsets = Eigen::array<int, D>();
  auto extents = Eigen::array<int, D>();
  for (size_t i = 0; i < D; ++i) {
    offsets[i] = 0;
    extents[i] = new_out_dims[i];
  }
  for (size_t i = 0; i < axes.size(); ++i) {
    int start = starts[i];
    if (start < 0) {
      start = (start + in_dims[axes[i]]);
    }
    offsets[axes[i]] = (std::max)(start, 0);
  }

  auto in_t =
      lite::fluid::EigenTensor<float, D, Eigen::RowMajor, Eigen::DenseIndex>::
          From(*in, in->dims());
  auto out_t =
      lite::fluid::EigenTensor<float, D, Eigen::RowMajor, Eigen::DenseIndex>::
          From(*out, new_out_dims);
  out_t = in_t.slice(offsets, extents);

  // Restore the final (possibly decreased) output shape.
  out->Resize(out_dims);
}

/// Dispatches to the rank-specialized `slice_compute<D>` based on the rank of
/// `Input`. All arguments are forwarded unchanged.
///
/// Ranks 1 through 6 are supported. Any other rank is a fatal error, rather
/// than a silent no-op that would leave `Out` unwritten.
///
/// NOTE(review): `T` is not forwarded. `slice_compute<D>` always reads and
/// writes float data, so `T` currently has no effect on the element type.
template <typename T>
void slice_compute_(const lite::Tensor* Input,
                    lite::Tensor* Out,
                    std::vector<int> axes,
                    std::vector<int> starts,
                    std::vector<int> ends,
                    std::vector<int> decrease_axis,
                    lite::Tensor* StartsTensor,
                    lite::Tensor* EndsTensor,
                    std::vector<lite::Tensor*> StartsTensorList,
                    std::vector<lite::Tensor*> EndsTensorList,
                    std::vector<int> infer_flags) {
  int rank = Input->dims().size();
  switch (rank) {
    case 1:
      slice_compute<1>(Input,
                       Out,
                       axes,
                       starts,
                       ends,
                       decrease_axis,
                       StartsTensor,
                       EndsTensor,
                       StartsTensorList,
                       EndsTensorList,
                       infer_flags);
      break;
    case 2:
      slice_compute<2>(Input,
                       Out,
                       axes,
                       starts,
                       ends,
                       decrease_axis,
                       StartsTensor,
                       EndsTensor,
                       StartsTensorList,
                       EndsTensorList,
                       infer_flags);
      break;
    case 3:
      slice_compute<3>(Input,
                       Out,
                       axes,
                       starts,
                       ends,
                       decrease_axis,
                       StartsTensor,
                       EndsTensor,
                       StartsTensorList,
                       EndsTensorList,
                       infer_flags);
      break;
    case 4:
      slice_compute<4>(Input,
                       Out,
                       axes,
                       starts,
                       ends,
                       decrease_axis,
                       StartsTensor,
                       EndsTensor,
                       StartsTensorList,
                       EndsTensorList,
                       infer_flags);
      break;
    case 5:
      slice_compute<5>(Input,
                       Out,
                       axes,
                       starts,
                       ends,
                       decrease_axis,
                       StartsTensor,
                       EndsTensor,
                       StartsTensorList,
                       EndsTensorList,
                       infer_flags);
      break;
    case 6:
      slice_compute<6>(Input,
                       Out,
                       axes,
                       starts,
                       ends,
                       decrease_axis,
                       StartsTensor,
                       EndsTensor,
                       StartsTensorList,
                       EndsTensorList,
                       infer_flags);
      break;
    default:
      // Without this, an unsupported rank would return with Out untouched.
      LOG(FATAL) << "slice: unsupported input rank " << rank
                 << ", only ranks 1 to 6 are supported";
      break;
  }
}

/// x86 slice kernel. It reads its inputs and outputs from the attached
/// operators::SliceParam and hands them to `slice_compute_<T>`.
template <typename T>
class SliceCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::SliceParam;

  void Run() override {
    param_t& p = *param_.get_mutable<param_t>();
    slice_compute_<T>(p.X, p.Out,
                      p.axes, p.starts, p.ends, p.decrease_axis,
                      p.StartsTensor, p.EndsTensor,
                      p.StartsTensorList, p.EndsTensorList,
                      p.infer_flags);
  }

  virtual ~SliceCompute() = default;
};

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle