/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/reverse_iterator.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>

#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"

// Global aliases for the framework tensor types.
using Tensor = paddle::framework::Tensor;
using LoDTensor = paddle::framework::LoDTensor;

namespace paddle {
namespace operators {

// Copies `valid_item` consecutive elements of `idata`, starting at
// `src_base`, into `odata` in reverse order: idata[src_base + i] is written
// to odata[dst_base - i] for every i in [0, valid_item). Thread i (threadIdx.x)
// moves element i, so the block must have at least `valid_item` threads in x.
// BLOCK_SIZE is the tile width chosen by the caller (1024 in this file); it no
// longer sizes any storage but is kept so existing instantiations compile.
//
// Why there is no shared memory: the previous version staged the tile
// through a __shared__ buffer, yet every thread only read back the slot it
// had written itself, so the staging bought nothing. It was also racy: for a
// partial tile (valid_item < BLOCK_SIZE) idle thread i zeroed slot i, the
// same slot active thread BLOCK_SIZE - 1 - i may be writing its data into,
// and the slot writes of one call were not separated by a barrier from the
// reads of the previous call. A direct copy is race-free and still accesses
// one contiguous range per warp (loads ascending, stores descending).
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(const T* idata, T* odata, int src_base,
                             int dst_base, int valid_item) {
  const int offset = threadIdx.x;
  if (offset < valid_item) {
    odata[dst_base - offset] = idata[src_base + offset];
  }
}

// Reverses each contiguous run of `reverse_size` elements of `matrix_data`
// into the same position of `reverse_data`. The callers in this file always
// pass two distinct buffers.
//
// Launch layout: 1-D thread blocks, one block per run. Block (bx, by)
// handles the run that starts at element (bx + by * inner_size) *
// reverse_size, so both an (inner_size, outer_size) grid and a 1-D grid of
// outer_size * inner_size blocks cover every run. `outer_size` itself is not
// read; it is kept for signature compatibility.
//
// Each run is processed in tiles of blockDim.x elements (at most 1024, the
// BLOCK_SIZE of the BlockReverse instantiation below). The tile width used
// to be hard-coded to 1024, so a launch with fewer threads silently skipped
// the tail of every tile.
template <typename T>
__global__ void MatrixRowReverse(const T* matrix_data, T* reverse_data,
                                 int reverse_size, int outer_size,
                                 int inner_size) {
  const int item_per_block = static_cast<int>(blockDim.x);
  // First element of the run owned by this block (loop-invariant).
  const int run_base = (static_cast<int>(blockIdx.x) +
                        static_cast<int>(blockIdx.y) * inner_size) *
                       reverse_size;

  for (int block_offset = 0; block_offset < reverse_size;
       block_offset += item_per_block) {
    const int remaining = reverse_size - block_offset;
    const int valid_item =
        remaining < item_per_block ? remaining : item_per_block;
    // Tile [block_offset, block_offset + valid_item) of the run is written to
    // the mirrored position, counting back from the run's last element.
    BlockReverse<T, 1024>(matrix_data, reverse_data, run_base + block_offset,
                          run_base + reverse_size - 1 - block_offset,
                          valid_item);
  }
}

// Running-prefix functor handed to cub::BlockScan so that one thread block
// can scan a sequence tile by tile. Each call returns the total of all
// previously scanned tiles (used to seed the current tile's scan) and then
// folds the current tile's aggregate into that total.
//
// cub enters operator() from the first warp of the block; the value lane 0
// returns becomes the seed of the block-wide scan.
template <typename T>
struct BlockPrefixCallbackOp {
  // Sum of every element scanned so far.
  T running_total;

  __device__ BlockPrefixCallbackOp(T running_total)
      : running_total(running_total) {}

  // Yields the prefix for the current tile, then advances the running total
  // by `block_aggregate`.
  __device__ T operator()(T block_aggregate) {
    const T prefix = running_total;
    running_total = prefix + block_aggregate;
    return prefix;
  }
};
// Out-of-place transpose of the row-major `height` x `width` matrix `idata`
// into the row-major `width` x `height` matrix `odata`.
//
// Launch layout: blockDim = (TILE_DIM, BLOCK_ROWS); grid.x tiles the input
// columns and grid.y the input rows, TILE_DIM at a time. Each block stages
// one TILE_DIM x TILE_DIM tile in shared memory; the extra padding column
// keeps the column-wise reads of the write phase free of bank conflicts.
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata, const T* idata, size_t height,
                                size_t width) {
  __shared__ T tile[TILE_DIM][TILE_DIM + 1];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;

  // Phase 1: consecutive threads read consecutive input columns. Cells that
  // fall outside the matrix are zero-filled.
  const int in_col = blockIdx.x * TILE_DIM + tx;
  const int in_row0 = blockIdx.y * TILE_DIM + ty;
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    const int row = in_row0 + j;
    const bool inside = in_col < width && row < height;
    tile[ty + j][tx] =
        inside ? idata[row * width + in_col] : static_cast<T>(0);
  }

  __syncthreads();

  // Phase 2: the block's roles swap -- grid.y now picks the output columns
  // and grid.x the output rows -- and the tile is read column-wise.
  const int out_col = blockIdx.y * TILE_DIM + tx;
  const int out_row0 = blockIdx.x * TILE_DIM + ty;
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    const int row = out_row0 + j;
    if (out_col < height && row < width) {
      odata[row * height + out_col] = tile[tx][ty + j];
    }
  }
}
// Prefix-sums contiguous runs of `scan_size` elements of `d_in` into `d_out`,
// one run per thread block.
//
// Launch layout: blockDim.x == BLOCK_THREADS. Block (bx, by) scans the run
// starting at element (bx + by * inner_size) * scan_size (the callers in this
// file pass their outer extent in the `inner_size` slot). `outer_size` is not
// read. Runs longer than BLOCK_THREADS * ITEMS_PER_THREAD are processed tile
// by tile, with BlockPrefixCallbackOp carrying the running total across
// tiles.
//
//   exclusive == false: d_out[i] = d_in[0] + ... + d_in[i]   (within a run)
//   exclusive == true:  d_out[i] = d_in[0] + ... + d_in[i-1], d_out[0] = 0
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size,
                                int outer_size, int scan_size, bool exclusive) {
  // Specialize the BlockLoad, BlockStore and BlockScan collective types.
  typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
                         cub::BLOCK_LOAD_TRANSPOSE>
      BlockLoadT;
  typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStoreT;
  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
  // The three collectives share one block of shared memory, so a
  // __syncthreads() is required every time it passes from one collective to
  // the next -- including from Store back to the next tile's Load.
  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
  } temp_storage;

  // First element of the run owned by this block.
  const int run_offset = (static_cast<int>(blockIdx.x) +
                          static_cast<int>(blockIdx.y) * inner_size) *
                         scan_size;

  // Carries the total of all previously scanned tiles of this run.
  BlockPrefixCallbackOp<T> prefix_op(0);

  constexpr int kItemsPerTile = BLOCK_THREADS * ITEMS_PER_THREAD;
  for (int block_offset = 0; block_offset < scan_size;
       block_offset += kItemsPerTile) {
    const int remaining = scan_size - block_offset;
    const int valid_item =
        remaining < kItemsPerTile ? remaining : kItemsPerTile;
    const int offset = run_offset + block_offset;

    // Slots past the end of the run are loaded as 0, the identity of the
    // sum, so they do not disturb the scan of the valid items.
    T thread_keys[ITEMS_PER_THREAD];
    BlockLoadT(temp_storage.load)
        .Load(d_in + offset, thread_keys, valid_item, 0);
    __syncthreads();

    if (exclusive) {
      BlockScanT(temp_storage.scan)
          .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
    } else {
      BlockScanT(temp_storage.scan)
          .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
    }
    __syncthreads();

    BlockStoreT(temp_storage.store)
        .Store(d_out + offset, thread_keys, valid_item);
    // The next tile's Load reuses temp_storage; without this barrier a fast
    // thread could overwrite staging data that slower threads are still
    // reading inside Store.
    __syncthreads();
  }
}

template <typename DeviceContext, typename T>
class CumCUDAKernel : public framework::OpKernel<T> {
 public:
  // Computes the cumulative sum of input "X" along Attr(axis) into "Out".
  //   Attr(exclusive): each output omits its own input element (first is 0).
  //   Attr(reverse):   accumulate from the end of the axis towards the start.
  // All work is enqueued on the device context's stream.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");

    int axis = context.Attr<int>("axis");
    bool exclusive = context.Attr<bool>("exclusive");
    bool reverse = context.Attr<bool>("reverse");
    auto out_dims = out->dims();
    auto size = in->numel();

    PADDLE_ENFORCE_EQ(
        axis < out_dims.size() && axis >= (0 - out_dims.size()), true,
        platform::errors::OutOfRange(
            "Attr(axis) is out of range, It's expected "
            "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
            out_dims.size(), out_dims.size() - 1, axis));
    if (axis < 0) {
      axis += out_dims.size();
    }

    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* in_data = in->data<T>();
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto stream = dev_ctx.stream();

    // An empty tensor has nothing to scan. Returning here also keeps the
    // general path from launching kernels on an empty grid, which CUDA
    // rejects as an invalid configuration.
    if (size == 0) {
      return;
    }

    // Use thrust when the tensor is a single run, i.e. the input size equals
    // the length of the 'axis' dimension.
    if (size == out_dims[axis]) {
      // Run on the device context's stream rather than the legacy default
      // stream so the scan is ordered with the rest of the op's work.
#ifdef __HIPCC__
      const auto& policy = thrust::hip::par.on(stream);
#else
      const auto& policy = thrust::cuda::par.on(stream);
#endif
      if (reverse) {
        // Scanning through reverse iterators walks the run back-to-front
        // directly: no temporary device copy and no extra reversal pass.
        thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
            thrust::device_pointer_cast(in_data) + size);
        thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
            thrust::device_pointer_cast(out_data) + size);
        if (exclusive) {
          thrust::exclusive_scan(policy, reversed_in, reversed_in + size,
                                 reversed_out);
        } else {
          thrust::inclusive_scan(policy, reversed_in, reversed_in + size,
                                 reversed_out);
        }
      } else {
        if (exclusive) {
          thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
        } else {
          thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
        }
      }
      return;
    }

    // View the tensor as a row-major height x width matrix: height folds
    // every dimension up to and including `axis`, width folds the rest.
    size_t height = 1;
    for (int i = 0; i <= axis; ++i) {
      height *= out_dims[i];
    }
    size_t width = 1;
    for (int i = axis + 1; i < out_dims.size(); ++i) {
      width *= out_dims[i];
    }
    const int scan_size = static_cast<int>(out_dims[axis]);
    const int outer_size = static_cast<int>(height / scan_size);
    const int inner_size = static_cast<int>(width);

    // With width == 1 (every dimension after `axis` has extent 1) the scan
    // runs are already contiguous, so the transposes would only copy data.
    const bool transpose = width > 1;

    // After the optional transpose the data is outer_size * inner_size
    // contiguous runs of scan_size elements. The reverse and scan kernels
    // get one block per run on a 1-D grid (blockIdx.y == 0, so the run index
    // is just blockIdx.x). gridDim.x is not limited to 65535 the way the
    // y extent of the former (outer_size, inner_size) grid was.
    const dim3 run_grid(static_cast<unsigned int>(size / scan_size));

    constexpr int kTileDim = 32;
    constexpr int kBlockRows = 8;
    const dim3 transpose_block(kTileDim, kBlockRows);
    // NOTE(review): the transpose grids still put ceil(rows / 32) in
    // gridDim.y, which CUDA caps at 65535.
    const dim3 forward_grid((width + kTileDim - 1) / kTileDim,
                            (height + kTileDim - 1) / kTileDim);
    const dim3 backward_grid((height + kTileDim - 1) / kTileDim,
                             (width + kTileDim - 1) / kTileDim);

    // Ping-pong buffer for the multi-stage paths; the scan-only path writes
    // straight into out_data and does not need it.
    framework::Tensor tmp;
    T* tmp_data = nullptr;
    if (transpose || reverse) {
      tmp.Resize(out_dims);
      tmp_data = tmp.mutable_data<T>(context.GetPlace());
    }

    // Each stage reads `src` and writes `dst`; afterwards its output becomes
    // the next stage's input and `dst` flips to the other buffer. The first
    // stage reads in_data and writes out_data, and every path runs an odd
    // number of stages (1, 3 or 5), so the final stage writes out_data.
    const T* src = in_data;
    T* dst = out_data;
    auto next_stage = [&]() {
      src = dst;
      dst = (dst == out_data) ? tmp_data : out_data;
    };

    if (transpose) {
      // height x width -> width x height: every scan run becomes contiguous.
      MatrixTranspose<T, kTileDim, kBlockRows>
          <<<forward_grid, transpose_block, 0, stream>>>(dst, src, height,
                                                         width);
      next_stage();
    }
    if (reverse) {
      MatrixRowReverse<T><<<run_grid, 1024, 0, stream>>>(
          src, dst, scan_size, outer_size, inner_size);
      next_stage();
    }
    BlockScanKernel<T, 128, 4><<<run_grid, 128, 0, stream>>>(
        dst, src, outer_size, inner_size, scan_size, exclusive);
    next_stage();
    if (reverse) {
      MatrixRowReverse<T><<<run_grid, 1024, 0, stream>>>(
          src, dst, scan_size, outer_size, inner_size);
      next_stage();
    }
    if (transpose) {
      // width x height -> back to the original height x width layout.
      MatrixTranspose<T, kTileDim, kBlockRows>
          <<<backward_grid, transpose_block, 0, stream>>>(dst, src, width,
                                                          height);
      next_stage();
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register CumCUDAKernel as the GPU implementation of the `cumsum` operator,
// once per supported element type.
REGISTER_OP_CUDA_KERNEL(
    cumsum,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, double>,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int16_t>,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);