//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Returns the name of the operator attribute that carries the constant
// values for the given data type (e.g. FP32 -> "fp32_values").
// Throws Unimplemented for any data type the SetValue operator does not
// support.
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
  std::string value_name;
  switch (data_type) {
    case framework::proto::VarType::INT32:
      value_name = "int32_values";
      break;
    case framework::proto::VarType::INT64:
      value_name = "int64_values";
      break;
    case framework::proto::VarType::FP32:
      value_name = "fp32_values";
      break;
    case framework::proto::VarType::FP64:
      value_name = "fp64_values";
      break;
    case framework::proto::VarType::BOOL:
      value_name = "bool_values";
      break;
    default:
      // NOTE: FP64 is handled above, so the message must list float64 too
      // (the previous message omitted it).
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type(code %d) for SetValue operator, only "
          "supports bool, int32, int64, float32 and float64.",
          data_type));
  }
  return value_name;
}

inline void CheckAndUpdateSlice(const framework::DDim in_dims,
                                const std::vector<int64_t> axes,
                                std::vector<int64_t>* starts,
                                std::vector<int64_t>* ends,
                                std::vector<int64_t>* steps) {
  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t axis = axes[i];
    int64_t dim_value = in_dims[axis];

    int64_t start =
        (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
    int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
    start = std::max(start, static_cast<int64_t>(0));
    end = std::min(end, dim_value);

    int64_t step = (*steps)[i];
    PADDLE_ENFORCE_NE(
        step, 0, platform::errors::InvalidArgument(
                     "Step should not be 0, but received step = %d.", step));
    if (step > 0) {
      start = std::min(start, dim_value);
      end = std::max(end, static_cast<int64_t>(0));
      PADDLE_ENFORCE_GT(
          end, start,
          platform::errors::InvalidArgument(
              "When step > 0, end should be greater than start, but "
              "received end = %d, start = %d.",
              end, start));
    } else {
      // NOTE(liym27): When step < 0, start should less and equal to dim_value-1
      // "end is -1" means contain the 0-th element of this axis.
      start = std::min(start, dim_value - 1);
      end = std::max(end, static_cast<int64_t>(-1));
      PADDLE_ENFORCE_GT(
          start, end,
          platform::errors::InvalidArgument(
              "When step < 0, start should be greater than end, but "
              "received start = %d, end = %d.",
              start, end));
    }

    (*starts)[i] = start;
    (*ends)[i] = end;
  }
}

// Computes the shape of the slice selected by (axes, starts, ends, steps)
// out of a tensor with shape `in_dims`. The starts/ends are assumed to be
// already normalized by CheckAndUpdateSlice, so each axis yields at least
// one element. `in_dims` is taken by const reference to avoid a copy.
inline framework::DDim GetSliceDims(const framework::DDim& in_dims,
                                    const std::vector<int64_t>& axes,
                                    const std::vector<int64_t>& starts,
                                    const std::vector<int64_t>& ends,
                                    const std::vector<int64_t>& steps) {
  framework::DDim slice_dims(in_dims);

  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t axis = axes[i];
    int64_t start = starts[i];
    int64_t end = ends[i];
    int64_t step = steps[i];

    if (step > 0) {
      // ceil((end - start) / step) for forward slicing.
      slice_dims[axis] = (end - start + step - 1) / step;
    } else {
      // Same ceiling division when step < 0 (numerator and denominator
      // are both negative, so +1 instead of -1 in the adjustment).
      slice_dims[axis] = (end - start + step + 1) / step;
    }
  }
  return slice_dims;
}

// Removes the axes listed in `decrease_axes` from `slice_dims` (each such
// axis must have size 1). If every axis is removed, returns [1], because
// Paddle does not support rank-0 tensors.
inline framework::DDim GetDecreasedDims(
    const framework::DDim& slice_dims,
    const std::vector<int64_t>& decrease_axes) {
  // Get dims after decreasing axes.
  framework::DDim decreased_dims(slice_dims);
  if (decrease_axes.size() > 0) {
    for (size_t i = 0; i < decrease_axes.size(); ++i) {
      int64_t axis = decrease_axes[i];
      PADDLE_ENFORCE_EQ(
          decreased_dims[axis], 1,
          platform::errors::InvalidArgument(
              "The size of the decreased dimension should be 1, but "
              "received %d.",
              decreased_dims[axis]));
      // Mark the axis for removal; 0 can never be a valid slice size here.
      decreased_dims[axis] = 0;
    }

    std::vector<int64_t> new_shape;
    for (int i = 0; i < decreased_dims.size(); ++i) {
      if (decreased_dims[i] != 0) {
        new_shape.push_back(decreased_dims[i]);
      }
    }

    // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
    // uses [1] instead.
    if (new_shape.size() == 0) {
      new_shape.push_back(1);
    }

    decreased_dims = framework::make_ddim(new_shape);
  }
  return decreased_dims;
}

// Kernel for the set_value operator: Out = Input, then
// Out[starts:ends:steps along axes] = value (from ValueTensor or attrs).
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
 public:
  // Dispatches to SetValueCompute<D> on the runtime rank of "Input",
  // because the Eigen strided-slice below needs the rank as a
  // compile-time constant.
  void Compute(const framework::ExecutionContext& ctx) const {
    const int rank = ctx.Input<framework::LoDTensor>("Input")->dims().size();

    // TODO(liym27): A more elegant code to do this. C++ has to make template
    //  integer as constant, but we had better have alternative writing in the
    //  future.
    switch (rank) {
      case 1:
        SetValueCompute<1>(ctx);
        break;
      case 2:
        SetValueCompute<2>(ctx);
        break;
      case 3:
        SetValueCompute<3>(ctx);
        break;
      case 4:
        SetValueCompute<4>(ctx);
        break;
      case 5:
        SetValueCompute<5>(ctx);
        break;
      case 6:
        SetValueCompute<6>(ctx);
        break;
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "The rank of input should be less than 7, but received %d.", rank));
    }
  }

 private:
  template <size_t D>
  void SetValueCompute(const framework::ExecutionContext& ctx) const {
    auto* in = ctx.Input<framework::LoDTensor>("Input");
    // Optional: when absent, the values come from the typed attributes.
    auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");
    auto* out = ctx.Output<framework::LoDTensor>("Out");

    auto starts_tensor_list =
        ctx.MultiInput<framework::Tensor>("StartsTensorList");
    auto ends_tensor_list = ctx.MultiInput<framework::Tensor>("EndsTensorList");
    auto steps_tensor_list =
        ctx.MultiInput<framework::Tensor>("StepsTensorList");

    auto axes = ctx.Attr<std::vector<int64_t>>("axes");
    auto starts = ctx.Attr<std::vector<int64_t>>("starts");
    auto ends = ctx.Attr<std::vector<int64_t>>("ends");
    auto steps = ctx.Attr<std::vector<int64_t>>("steps");
    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
    auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");

    auto dtype = in->type();
    // Tensor inputs, when provided, override the corresponding attributes.
    if (!starts_tensor_list.empty()) {
      starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
    }
    if (!ends_tensor_list.empty()) {
      ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
    }
    if (!steps_tensor_list.empty()) {
      steps = GetDataFromTensorList<int64_t>(steps_tensor_list);
    }

    auto in_dims = in->dims();
    CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps);
    auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps);
    auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);

    auto place = ctx.GetPlace();
    auto& eigen_place =
        *ctx.template device_context<DeviceContext>().eigen_device();

    // Here copy data from input to avoid data loss at PE and Graph level.
    // TODO(liym27): Speed up in the future version.
    // - Q: Why don't call ShareDataWith to speed up?
    // - A: Because it's not supported to ShareDataWith on OP's input and output
    // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
    // - Q: Why don't delete Input, after all, the input and output are the same
    // Tensor at program level?
    // - A: If deleting Input, the graph will be complex, such as there will
    // be two ops points to the output in graph: op1 -> output <- set_value.
    // In this case, we have to find a way to handle the running order of
    // set_value is what we want.
    TensorCopy(*in, place, out);

    Tensor slice_tensor(dtype), pad_tensor(dtype);
    slice_tensor.mutable_data<T>(slice_dims, place);
    pad_tensor.mutable_data<T>(in_dims, place);

    auto pad_e = framework::EigenTensor<T, D>::From(pad_tensor, in_dims);
    auto out_e = framework::EigenTensor<T, D>::From(*out);
    auto slice_e = framework::EigenTensor<T, D>::From(slice_tensor, slice_dims);

    // Step 1: Set the value of out at `_index` to zero
    slice_e.device(eigen_place) = slice_e.constant(T(0));

    auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
    auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
    auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();

    // Default: the strided slice covers every axis fully.
    for (size_t i = 0; i < D; ++i) {
      starts_indices[i] = 0;
      ends_indices[i] = slice_dims[i];
      strides_indices[i] = 1;
    }
    // Narrow the sliced axes to the requested range/step.
    // NOTE: use int64_t for the axis index to avoid a narrowing conversion
    // from the int64_t elements of `axes`.
    for (size_t i = 0; i < axes.size(); i++) {
      int64_t axis_index = axes[i];
      starts_indices[axis_index] = starts[i];
      ends_indices[axis_index] = ends[i];
      strides_indices[axis_index] = steps[i];
    }

    out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
        .device(eigen_place) = slice_e;

    // Step 2: Set a tensor with the same shape as out tensor. And its data at
    // '_index' is the same as value_tensor, and data out of '_index' to zero

    // - Step 2.1 Set slice tensor with value

    // NOTE(liym27): [ Why resize slice_tensor here? ]
    // A: When do broadcasting on slice_tensor and value_tensor, the shape of
    // slice_tensor should be decreased dims.
    // e.g.
    //  x[:,0] = value_tensor
    // x's shape = [3, 4], value_tensor's shape = [3]
    // We get slice_dims = [3, 1],  decrease_slice_dims = [3]
    // If do broadcasting on Tensor with shape [3, 1] and [3], the result's
    // shape is [3, 3], which cross the border;
    // If do broadcasting on Tensor with shape [3] and [3], the result's shape
    // is [3], which is right.

    slice_tensor.Resize(decrease_slice_dims);
    // slice_tensor currently holds zeros, so `slice - value == -value`;
    // the final subtraction in Step 3 turns this back into `+value`.
    if (value_tensor != nullptr) {
      // ElementwiseComputeEx can do broadcasting
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
    } else {
      // No ValueTensor input: materialize the values from the typed
      // "*_values" attribute selected by the input dtype.
      Tensor value_t(dtype);
      auto value_dims = framework::make_ddim(shape);
      value_t.mutable_data<T>(value_dims, place);
      auto value_name = GetValueName(dtype);
      CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
      value_t.Resize(value_dims);
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
    }
    slice_tensor.Resize(slice_dims);

    // - Step 2.2 Pad slice tensor with 0
    pad_e.device(eigen_place) = pad_e.constant(T(0));
    pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
        .device(eigen_place) = slice_e;

    // Step 3: Set out tensor with value_tensor
    // out was zeroed at the slice in Step 1 and pad holds -value there,
    // so out - pad writes +value into the slice and leaves the rest intact.
    out_e.device(eigen_place) = out_e - pad_e;
  }
};

}  // namespace operators
}  // namespace paddle