// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

// Cyclically shifts `data` (a contiguous buffer with logical shape
// `input_dim`) by `shift` positions along axis `dim`, in place.
// A negative `dim` counts from the last axis; a negative `shift` rolls
// toward the front. Elements pushed past the end wrap around.
template <typename T>
inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
                            int64_t shift) {
  if (dim < 0) {
    dim += input_dim.size();
  }
  // An empty axis has nothing to move (and would make the modulo below UB).
  if (input_dim[dim] == 0) {
    return;
  }
  // Normalize shift into [0, input_dim[dim]).
  shift = shift % input_dim[dim];
  if (shift < 0) {
    shift += input_dim[dim];
  }

  // Use int64_t accumulators: `auto x = 1` would deduce int, and products of
  // dimension extents can exceed 32 bits for large tensors, silently
  // truncating the strides computed here.
  int64_t outer_loops = 1;
  for (int64_t i = 0; i < dim; i++) {
    outer_loops *= input_dim[i];
  }
  int64_t slice_width = 1;
  for (int64_t i = dim + 1; i < input_dim.size(); i++) {
    slice_width *= input_dim[i];
  }

  VLOG(3) << "shift_along_dim_debug: input_dim: " << input_dim
          << "; dim: " << dim << "; shift: " << shift
          << "; outer_loops: " << outer_loops
          << "; slice_width: " << slice_width;
  if (shift == 0) {
    return;
  }

  // Stash the leading (input_dim[dim] - shift) slices, slide the trailing
  // `shift` slices to the front, then append the stashed part after them.
  std::vector<T> head;
  int64_t head_size = slice_width * (input_dim[dim] - shift);
  head.resize(head_size);

  for (int64_t i = 0; i < outer_loops; i++) {
    for (int64_t j = 0; j < head_size; j++) {
      head[j] = data[i * input_dim[dim] * slice_width + j];
    }
    for (int64_t j = input_dim[dim] - shift; j < input_dim[dim]; j++) {
      int64_t dst_pos = j - input_dim[dim] + shift;
      for (int64_t k = 0; k < slice_width; k++) {
        data[(i * input_dim[dim] + dst_pos) * slice_width + k] =
            data[(i * input_dim[dim] + j) * slice_width + k];
      }
    }
    for (int64_t j = 0; j < head_size; j++) {
      data[(i * input_dim[dim] + shift) * slice_width + j] = head[j];
    }
  }
}

// Forward kernel for the roll op: rotates the elements of X along each axis
// listed in `axis` by the matching offset in `shifts` (overridden at runtime
// by ShiftsTensor when present) and writes the result to Out.
template <typename DeviceContext, typename T>
class RollKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto& in_tensor = context.InputVar("X")->Get<LoDTensor>();
    auto* out_tensor = context.OutputVar("Out")->GetMutable<LoDTensor>();

    auto shifts = context.Attr<std::vector<int64_t>>("shifts");
    if (context.HasInput("ShiftsTensor")) {
      const auto* shifts_tensor =
          context.Input<framework::Tensor>("ShiftsTensor");
      PADDLE_ENFORCE_EQ(
          shifts_tensor->dims().size(), 1,
          platform::errors::InvalidArgument(
              "The rank of ShiftsTensor is expected to be 1, got %s",
              shifts_tensor->dims().size()));
      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
    }
    auto axes = context.Attr<std::vector<int64_t>>("axis");

    // Roll on a host-side copy of the data, then write it back.
    std::vector<T> buffer;
    TensorToVector(in_tensor, context.device_context(), &buffer);

    DDim shape = in_tensor.dims();
    // axis = none: treat the input as a flat 1-D tensor.
    if (axes.empty()) {
      axes.push_back(0l);
      shape = framework::Dim<1>(buffer.size());
    }

    const size_t shift_count = shifts.size();
    for (size_t i = 0; i < shift_count; i++) {
      PADDLE_ENFORCE_EQ(
          axes[i] < shape.size() && axes[i] >= (0 - shape.size()), true,
          platform::errors::OutOfRange(
              "Attr(axis[%d]) is out of range, It's expected "
              "to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
              i, shape.size(), shape.size() - 1, i, axes[i]));
      shift_along_dim(buffer.data(), shape, axes[i], shifts[i]);
    }
    out_tensor->mutable_data<T>(context.GetPlace());
    framework::TensorFromVector(buffer, context.device_context(), out_tensor);
    // TensorFromVector leaves the tensor 1-D; restore the original shape.
    out_tensor->Resize(in_tensor.dims());
  }
};

// Gradient kernel for the roll op: rolling Out@GRAD by the negated shifts
// along the same axes reproduces X@GRAD.
template <typename DeviceContext, typename T>
class RollGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input_var = context.InputVar(framework::GradVarName("Out"));
    auto* output_var = context.OutputVar(framework::GradVarName("X"));
    auto& input = input_var->Get<LoDTensor>();
    auto* output = output_var->GetMutable<LoDTensor>();
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    if (context.HasInput("ShiftsTensor")) {
      const auto* shifts_tensor =
          context.Input<framework::Tensor>("ShiftsTensor");
      // Validate the runtime shifts exactly as the forward kernel does, so a
      // malformed ShiftsTensor fails loudly here too instead of silently
      // producing a wrong gradient.
      PADDLE_ENFORCE_EQ(
          shifts_tensor->dims().size(), 1,
          platform::errors::InvalidArgument(
              "The rank of ShiftsTensor is expected to be 1, got %s",
              shifts_tensor->dims().size()));
      shifts = GetDataFromTensor<int64_t>(shifts_tensor);
    }
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

    // Copy the incoming gradient to host, roll it backwards, write it out.
    std::vector<T> out_vec;
    TensorToVector(input, context.device_context(), &out_vec);

    size_t nums = shifts.size();
    DDim input_dim = input.dims();

    // axis = none, reshape to 1-D tensor
    if (dims.size() == 0) {
      dims.push_back(0l);
      input_dim = framework::Dim<1>(out_vec.size());
    }

    for (size_t i = 0; i < nums; i++) {
      // Negated shift undoes the forward roll.
      shift_along_dim(out_vec.data(), input_dim, dims[i], 0 - shifts[i]);
    }
    output->mutable_data<T>(context.GetPlace());
    framework::TensorFromVector(out_vec, context.device_context(), output);
    // TensorFromVector leaves the tensor 1-D; restore the original shape.
    output->Resize(input.dims());
  }
};

}  // namespace operators
}  // namespace paddle