// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <limits>
#include <vector>
#include "paddle/fluid/memory/memory.h"  // for memory::Alloc and memory::Copy
#include "paddle/fluid/operators/stack_op.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace plat = paddle::platform;
namespace ops = paddle::operators;

namespace paddle {
namespace operators {

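// Stacks `cols / split_size` inputs, each viewed as a (rows x split_size)
// matrix, side by side into a single (rows x cols) output. A 2-D grid-stride
// loop is used: threads along x walk the output columns, threads along y walk
// the rows. `input_ptrs` is a device-side array of device pointers.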
template <typename T, typename IntType>
__global__ void StackCUDAKernel(T** input_ptrs, int split_size, int rows,
                                int cols, T* __restrict__ output) {
  IntType grid_x = blockIdx.x * blockDim.x + threadIdx.x;

  for (; grid_x < cols; grid_x += blockDim.x * gridDim.x) {
    IntType grid_y = blockIdx.y * blockDim.y + threadIdx.y;

    IntType split = grid_x / split_size;
    const T* input_ptr = input_ptrs[split];
    IntType col_offset = grid_x % split_size;
#pragma unroll
    for (; grid_y < rows; grid_y += blockDim.y * gridDim.y) {
      output[grid_y * cols + grid_x] =
          input_ptr[grid_y * split_size + col_offset];
    }
  }
}

template <typename T>
class StackGPUKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto x = ctx.MultiInput<Tensor>("X");
    auto* y = ctx.Output<Tensor>("Y");

    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += (x[0]->dims().size() + 1);

    int n = static_cast<int>(x.size());
    auto* y_data = y->mutable_data<T>(ctx.GetPlace());
    std::vector<const T*> x_datas(n);
    for (int i = 0; i < n; i++) {
      x_datas[i] = x[i]->data<T>();
    }

    auto& dev_ctx = ctx.template device_context<plat::CUDADeviceContext>();
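    // Stage the host-side array of input data pointers in device memory so the
    // kernel can index it; the copy is issued asynchronously on the compute
    // stream.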
    auto tmp_x_data = memory::Alloc(dev_ctx, x_datas.size() * sizeof(T*));
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
                 tmp_x_data->ptr(), platform::CPUPlace(),
                 reinterpret_cast<void*>(x_datas.data()),
                 x_datas.size() * sizeof(T*), dev_ctx.stream());

    // Flatten each input into a 2-D view split at `axis`: x_row rows (product
    // of the dims before axis) by x_col columns; the stacked output then has
    // out_col = x_col * n columns.
    int x_row = 1, x_col = 1;
    for (int i = 0; i < axis; ++i) {
      x_row *= x[0]->dims()[i];
    }
    x_col = x[0]->numel() / x_row;
    int out_col = x_col * n;

    auto config = GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);

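    // Use 32-bit indexing when the output element count fits in int32_t, and
    // fall back to 64-bit indices otherwise.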
    if (y->numel() < std::numeric_limits<int32_t>::max()) {
      StackCUDAKernel<T,
                      int32_t><<<config.block_per_grid, config.thread_per_block,
                                 0, dev_ctx.stream()>>>(
          reinterpret_cast<T**>(tmp_x_data->ptr()), x_col, x_row, out_col,
          y_data);
    } else {
      StackCUDAKernel<T,
                      int64_t><<<config.block_per_grid, config.thread_per_block,
                                 0, dev_ctx.stream()>>>(
          reinterpret_cast<T**>(tmp_x_data->ptr()), x_col, x_row, out_col,
          y_data);
    }
  }
};

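// Inverse of StackCUDAKernel: scatters the stacked gradient `input`, viewed as
// a (pre_dim_size x split_dim_size x suf_dim_size) tensor, into `num_split`
// outputs. A 1-D grid-stride loop decomposes each flattened offset into
// (i, j, k) and routes the element to output_ptrs[j / each_dim_size]; null
// output pointers (gradients that are not needed) are skipped.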
template <typename T, typename IntType>
__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
                                        int pre_dim_size, int split_dim_size,
                                        int suf_dim_size, int num_split,
                                        T** output_ptrs) {
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  // As called from StackGradGPUKernel, split_dim_size == num_split, so
  // each_dim_size below is 1.
  assert(split_dim_size % num_split == 0);

  IntType size = pre_dim_size * split_dim_size * suf_dim_size;
  IntType each_dim_size = split_dim_size / num_split;

  for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
       offset += blockDim.x * gridDim.x) {
    IntType i = offset / (split_dim_size * suf_dim_size);
    IntType j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
    IntType k = offset % suf_dim_size;

    T* output = output_ptrs[j / each_dim_size];
    if (output == nullptr) {
      // This gradient output is not requested; skip the element instead of
      // returning, which would also drop elements belonging to other outputs.
      continue;
    }
    IntType output_ind = i * each_dim_size * suf_dim_size +
                         (j % each_dim_size) * suf_dim_size + k;
    *(output + output_ind) = input[offset];
  }
}

template <typename T>
class StackGradGPUKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += dy->dims().size();

    int n = dy->dims()[axis];
    PADDLE_ENFORCE_EQ(n, dx.size(),
                      platform::errors::InvalidArgument(
                          "Output dx size should be equal to n, but"
                          " received n is %d and dx size is %d.",
                          n, dx.size()));

    // dx are the outputs: collect each gradient tensor's data pointer so the
    // kernel can scatter dy into them.
    std::vector<T*> outputs(n);
    auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
    for (size_t j = 0; j < dx.size(); ++j) {
      if (dx[j] == nullptr) {
        outputs[j] = nullptr;
        continue;  // dx[j] must not be dereferenced below
      }
      if (out_var_names[j] != framework::kEmptyVarName &&
          dx[j]->numel() != 0UL) {
        T* ptr = dx[j]->mutable_data<T>(ctx.GetPlace());
        outputs[j] = ptr;
      } else {
        outputs[j] = nullptr;
      }
    }
    auto dy_data = dy->data<T>();
    // Every dx has the same shape, so dy can be viewed as a
    // (dy_pre x split_dim x dy_suf) tensor.
    int dy_pre = 1, dy_suf = 1;
    auto dy_dims = dy->dims();
    int split_dim = n;
    for (int i = 0; i < axis; ++i) {
      dy_pre *= dy_dims[i];
    }
    dy_suf = dy->numel() / (split_dim * dy_pre);

    auto& dev_ctx = ctx.template device_context<plat::CUDADeviceContext>();
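    // As in the forward kernel, stage the array of output data pointers in
    // device memory so the kernel can dereference them.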
    auto tmp_out_data = memory::Alloc(dev_ctx, outputs.size() * sizeof(T*));
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
                 tmp_out_data->ptr(), platform::CPUPlace(),
                 reinterpret_cast<void*>(outputs.data()),
                 outputs.size() * sizeof(T*), dev_ctx.stream());

    auto config = GetGpuLaunchConfig1D(dev_ctx, dy_pre * split_dim * dy_suf);

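    // Same index-width dispatch as the forward pass: 32-bit indices when dy
    // fits in the int32_t range, 64-bit otherwise.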
    if (dy->numel() < std::numeric_limits<int32_t>::max()) {
      UnStackHelperCUDAKernel<
          T, int32_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                        dev_ctx.stream()>>>(
          dy_data, dy_pre, split_dim, dy_suf, split_dim,
          reinterpret_cast<T**>(tmp_out_data->ptr()));
    } else {
      UnStackHelperCUDAKernel<
          T, int64_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                        dev_ctx.stream()>>>(
          dy_data, dy_pre, split_dim, dy_suf, split_dim,
          reinterpret_cast<T**>(tmp_out_data->ptr()));
    }
  }
};

}  // namespace operators
}  // namespace paddle

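// Register the forward and backward GPU kernels for the supported element
// types.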
REGISTER_OP_CUDA_KERNEL(stack, ops::StackGPUKernel<float>,
                        ops::StackGPUKernel<double>, ops::StackGPUKernel<int>,
                        ops::StackGPUKernel<int64_t>,
                        ops::StackGPUKernel<plat::float16>);

REGISTER_OP_CUDA_KERNEL(stack_grad, ops::StackGradGPUKernel<float>,
                        ops::StackGradGPUKernel<double>,
                        ops::StackGradGPUKernel<int>,
                        ops::StackGradGPUKernel<int64_t>,
                        ops::StackGradGPUKernel<plat::float16>);