unstack_op.h 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14 15 16

#pragma once

17
#include <memory>
D
dzhwinter 已提交
18
#include "paddle/fluid/framework/op_registry.h"
19 20
#include "paddle/fluid/platform/for_range.h"

21
#if defined(__NVCC__) || defined(__HIPCC__)
22 23 24
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#endif
D
dzhwinter 已提交
25 26 27 28

namespace paddle {
namespace operators {

29 30 31 32
// Element-wise gather used to interleave `n` source buffers into one
// destination buffer.
//
// The destination is addressed as a flat [outer, n, post] array: for a flat
// destination index `idx`, the functor works out which source buffer the
// element belongs to and where it sits inside that source (which is addressed
// as a flat [outer, post] array), then copies exactly one element.
//
// Template parameters:
//   VecXType - indexable collection of source buffers; `x[k][j]` must yield
//              an element convertible to T.
//   T        - element type of the destination buffer.
template <typename VecXType, typename T>
struct StackFunctor {
  // x:    collection of the `n` source buffers.
  // y:    destination buffer.
  // n:    number of source buffers.
  // post: number of contiguous elements taken from one source per outer step.
  HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
      : x_(x), y_(y), n_(n), post_(post) {}

  // Copies the single element that belongs at flat destination index `idx`.
  HOSTDEVICE void operator()(int idx) {
    // One outer step of the destination holds n_ * post_ elements.
    const int outer = idx / (n_ * post_);
    // Position of the element inside its contiguous run of length post_.
    const int inner = idx % post_;
    // Which source buffer feeds this destination element.
    const int source = idx / post_ - outer * n_;
    // Flat offset of the element inside that source buffer.
    const int source_offset = outer * post_ + inner;
    y_[idx] = x_[source][source_offset];
  }

 private:
  VecXType x_;
  T *y_;
  int n_;
  int post_;
};

48 49 50 51
// Element-wise scatter: the inverse of the gather performed by a stack.
//
// The source buffer `dy` is addressed as a flat [outer, n, post] array. For a
// flat source index `idx`, the functor works out which of the `n` destination
// buffers receives the element and at which offset (each destination being
// addressed as a flat [outer, post] array), then copies exactly one element.
//
// Template parameters:
//   VecDxType - indexable collection of destination buffers; `dx[k][j]` must
//               be assignable from a T.
//   T         - element type of the source buffer.
template <typename VecDxType, typename T>
struct StackGradFunctor {
  // dx:   collection of the `n` destination buffers.
  // dy:   source buffer.
  // n:    number of destination buffers.
  // post: number of contiguous elements written to one destination per outer
  //       step.
  HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
      : dx_(dx), dy_(dy), n_(n), post_(post) {}

  // Routes the element at flat source index `idx` to its destination buffer.
  HOSTDEVICE void operator()(int idx) {
    // One outer step of the source holds n_ * post_ elements.
    const int outer = idx / (n_ * post_);
    // Position of the element inside its contiguous run of length post_.
    const int inner = idx % post_;
    // Which destination buffer receives this element.
    const int target = idx / post_ - outer * n_;
    // Flat offset of the element inside that destination buffer.
    const int target_offset = outer * post_ + inner;
    dx_[target][target_offset] = dy_[idx];
  }

 private:
  VecDxType dx_;
  const T *dy_;
  int n_;
  int post_;
};

67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
template <typename DeviceContext, typename VecXType, typename T>
static inline void StackFunctorForRange(const DeviceContext &ctx,
                                        const VecXType &x, T *y, int total_num,
                                        int n, int post) {
  platform::ForRange<DeviceContext> for_range(ctx, total_num);
  for_range(StackFunctor<VecXType, T>(x, y, n, post));
}

template <typename DeviceContext, typename VecDxType, typename T>
static inline void StackGradFunctorForRange(const DeviceContext &ctx,
                                            const VecDxType &dx, const T *dy,
                                            int total_num, int n, int post) {
  platform::ForRange<DeviceContext> for_range(ctx, total_num);
  for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
}

// Gradient kernel of the unstack operator.
//
// Reads the list of tensors bound to GradVarName("Y") and interleaves them
// along `axis` into the single tensor bound to GradVarName("X"), i.e. it
// stacks the incoming gradients back together.
//
// When compiled by NVCC/HIPCC the copy is done element-wise through
// StackFunctorForRange on the device context; otherwise it is done on the
// host with contiguous memcpy calls.
template <typename DeviceContext, typename T>
class UnStackGradKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto dy_list = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    const int n = static_cast<int>(dy_list.size());
    auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());

    // Without any incoming gradient there is nothing to copy, and
    // dy_list[0] below would be an out-of-bounds access.
    if (n == 0) return;

    const auto &dims = dy_list[0]->dims();
    const int rank = dims.size();

    // The stacked result has one more dimension than each input, so a
    // negative axis is normalised against rank + 1.
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += rank + 1;

    std::vector<const T *> dy_datas(n);
    for (int i = 0; i < n; ++i) dy_datas[i] = dy_list[i]->data<T>();

    // pre:  product of the input dimensions before `axis`.
    // post: product of the input dimensions from `axis` onwards.
    int pre = 1;
    int post = 1;
    for (int i = 0; i < axis; ++i) pre *= dims[i];
    for (int i = axis; i < rank; ++i) post *= dims[i];

    // Nothing to move for zero-sized tensors. Returning here avoids a
    // device allocation, a zero-length range launch and a synchronisation on
    // the GPU path, and zero-byte memcpy calls on possibly-null pointers on
    // the CPU path.
    const int total_num = pre * n * post;
    if (total_num == 0) return;

#if defined(__NVCC__) || defined(__HIPCC__)
    auto &dev_ctx = ctx.template device_context<DeviceContext>();

    // The functor runs on the device, so the table of input pointers has to
    // live in device memory as well.
    thrust::device_vector<const T *> device_dy_vec(dy_datas);
    auto dy_data_arr = device_dy_vec.data().get();

    StackFunctorForRange(dev_ctx, dy_data_arr, dx_data, total_num, n, post);

    // Wait() must be called because device_dy_vec may be destructed before
    // the kernel ends.
    dev_ctx.Wait();
#else
    auto dy_data_arr = dy_datas.data();

    // For every outer step, append one run of `post` elements from each
    // input in turn.
    size_t src_offset = 0;
    size_t dst_offset = 0;
    for (int i = 0; i < pre; ++i) {
      for (int j = 0; j < n; ++j) {
        std::memcpy(dx_data + dst_offset, dy_data_arr[j] + src_offset,
                    post * sizeof(T));
        dst_offset += post;
      }
      src_offset += post;
    }
#endif
  }
};

135 136 137
// Forward kernel of the unstack operator.
//
// Splits the tensor bound to input "X" along `axis` into dims[axis] slices
// and writes slice k into the k-th tensor bound to output "Y". The copy is
// done element-wise by StackGradFunctorForRange on the device context.
//
// NOTE(review): the number of tensors bound to "Y" is assumed to be at least
// dims[axis]; nothing in this kernel checks it - confirm it is enforced
// elsewhere (e.g. in shape inference).
template <typename DeviceContext, typename T>
class UnStackKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<Tensor>("X");
    auto outs = ctx.MultiOutput<Tensor>("Y");

    const auto &x_dims = x->dims();
    const int rank = x_dims.size();

    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += rank;

    // One output tensor per index along the unstacked axis.
    const int n = static_cast<int>(x_dims[axis]);
    std::vector<T *> out_datas(n);  // NOLINT
    for (int i = 0; i < n; ++i) {
      out_datas[i] = outs[i]->mutable_data<T>(ctx.GetPlace());
    }
    auto x_data = x->data<T>();

    // post: number of contiguous elements behind `axis`. It is computed as a
    // product of the trailing dimensions rather than as
    // numel / (n * pre), which divides by zero as soon as any dimension up
    // to and including `axis` is 0.
    int post = 1;
    for (int i = axis + 1; i < rank; ++i) post *= x_dims[i];

    // Outputs are already allocated above; with no elements to copy there is
    // no reason to build the device pointer table or launch anything.
    const int total_num = x->numel();
    if (total_num == 0) return;

    auto &dev_ctx = ctx.template device_context<DeviceContext>();
#if defined(__NVCC__) || defined(__HIPCC__)
    // The functor runs on the device, so the table of output pointers has to
    // live in device memory as well.
    thrust::device_vector<T *> device_out_vec(out_datas);
    auto out_data_arr = device_out_vec.data().get();
#else
    auto out_data_arr = out_datas.data();
#endif
    StackGradFunctorForRange(dev_ctx, out_data_arr, x_data, total_num, n,
                             post);
#if defined(__NVCC__) || defined(__HIPCC__)
    // Wait() must be called because device_out_vec may be destructed before
    // the kernel ends.
    dev_ctx.Wait();
#endif
  }
};

}  // namespace operators
}  // namespace paddle