/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/reverse_iterator.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"

// File-scope shorthands for the framework tensor types.
typedef paddle::framework::Tensor Tensor;
typedef paddle::framework::LoDTensor LoDTensor;

namespace paddle {
namespace operators {

// Scan along a non-innermost axis.
//
// The tensor is viewed (row-major) as
// [outer_dim_size, scan_dim_size, inner_dim_size]. Each thread walks one
// (outer_index, inner_index) column sequentially along the scan axis, so
// neighbouring threads of a block touch neighbouring inner positions.
// Both loops stride by the grid size (gridDim.x over outer rows,
// gridDim.y * blockDim.x over inner positions), so any grid shape covers the
// whole tensor.
//
// exclusive: out[k] = in[0] + ... + in[k-1] (out[0] = 0); otherwise
//            out[k] = in[0] + ... + in[k].
// reverse:   the same, with k counted from the end of the scan axis.
template <typename T>
__global__ void OuterScan(const T* in, T* out, int inner_dim_size,
                          int outer_dim_size, int scan_dim_size, bool exclusive,
                          bool reverse) {
  // An empty scan axis has no elements; the exclusive path below would
  // otherwise write out[dst_index] out of bounds.
  if (scan_dim_size <= 0) return;

  // Distance between consecutive elements along the scan axis, signed by
  // the walking direction. 64-bit, like the offsets below.
  const int64_t step =
      reverse ? -static_cast<int64_t>(inner_dim_size) : inner_dim_size;

  for (int outer_index = blockIdx.x; outer_index < outer_dim_size;
       outer_index += gridDim.x) {
    for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x;
         inner_index < inner_dim_size; inner_index += gridDim.y * blockDim.x) {
      // 64-bit offsets: outer_index * scan_dim_size * inner_dim_size can
      // exceed INT_MAX for large tensors.
      int64_t src_index = static_cast<int64_t>(outer_index) * scan_dim_size *
                              inner_dim_size +
                          inner_index;
      if (reverse) {
        // Start at the last element of the column.
        src_index += static_cast<int64_t>(scan_dim_size - 1) * inner_dim_size;
      }
      int64_t dst_index = src_index;

      int scan_index_init = 0;
      if (exclusive) {
        // The first output is the identity; every input is shifted one
        // position further along the walking direction.
        scan_index_init = 1;
        out[dst_index] = 0;
        dst_index += step;
      }

      T acc = 0;
      for (int scan_index = scan_index_init; scan_index < scan_dim_size;
           ++scan_index) {
        acc = in[src_index] + acc;
        out[dst_index] = acc;
        src_index += step;
        dst_index += step;
      }
    }
  }
}

// Inclusive scan along the innermost (contiguous) axis.
//
// The input is treated as outer_dim_size rows of scan_dim_size contiguous
// elements. blockDim must be (num_threads_x, num_threads_y), the extents of
// the static shared buffer. Thread row threadIdx.y handles row
// `block_row + threadIdx.y`, and blocks stride over rows in steps of
// blockDim.y * gridDim.x. Each row is consumed in chunks of
// 2 * num_threads_x elements (two per thread), which are scanned in shared
// memory with an up-sweep/down-sweep. The running total of the previous
// chunks (`acc`) is added to the first element of each chunk before the
// sweep. With `reverse`, chunks are walked from the end of the row towards
// its start. inner_dim_size is not used by this kernel.
template <typename T, int num_threads_x, int num_threads_y>
__global__ void InnerMostDimInclusiveScan(const T* in, T* out,
                                          int inner_dim_size,
                                          int outer_dim_size, int scan_dim_size,
                                          bool reverse) {
  __shared__ T share_data[num_threads_y][num_threads_x * 2];
  T* share_row = share_data[threadIdx.y];
  int forward_direction = 1;
  if (reverse) forward_direction = -1;

  for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size;
       block_row += blockDim.y * gridDim.x) {
    int row = block_row + threadIdx.y;
    T acc = 0;
    // 64-bit row offset: row * scan_dim_size can exceed INT_MAX for large
    // tensors. The row pointers are only dereferenced when
    // row < outer_dim_size.
    const int64_t row_offset = static_cast<int64_t>(row) * scan_dim_size;
    const T* row_src = in + row_offset;
    T* row_dst = out + row_offset;
    int block_col = 0;
    bool loop_condition = (block_col < scan_dim_size);
    if (reverse) {
      loop_condition = (block_col >= 0);
      block_col = scan_dim_size - 1;
    }
    // loop_condition only depends on block-uniform values, so every thread
    // of the block reaches the same __syncthreads() calls.
    while (loop_condition) {
      // Load data into shared memory (two values per thread); positions
      // outside the row are padded with 0.
      int col1 = block_col + threadIdx.x * forward_direction;
      int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction;
      if (row < outer_dim_size) {
        if (col1 < scan_dim_size && col1 >= 0) {
          share_row[threadIdx.x] = row_src[col1];
        } else {
          share_row[threadIdx.x] = 0;
        }

        if (col2 < scan_dim_size && col2 >= 0) {
          share_row[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          share_row[num_threads_x + threadIdx.x] = 0;
        }

        // Add the previous chunks' total to the result
        if (threadIdx.x == 0) {
          share_row[0] = share_row[0] + acc;
        }
      }
      __syncthreads();

      // Up-Sweep: after this loop the chunk total is in
      // share_row[2 * num_threads_x - 1].
      for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < outer_dim_size && threadIdx.x < s) {
          unsigned offset = (2 * threadIdx.x + 1) * d - 1;
          share_row[offset + d] = share_row[offset] + share_row[offset + d];
        }
        __syncthreads();
      }
      // Down-Sweep: propagate partial sums to complete the inclusive scan.
      for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < outer_dim_size && threadIdx.x < s - 1) {
          unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
          share_row[offset + d] = share_row[offset] + share_row[offset + d];
        }
        __syncthreads();
      }

      // Write to the output
      if (row < outer_dim_size) {
        if (col1 < scan_dim_size && col1 >= 0)
          row_dst[col1] = share_row[threadIdx.x];
        if (col2 < scan_dim_size && col2 >= 0)
          row_dst[col2] = share_row[num_threads_x + threadIdx.x];
      }
      // Carry the running total into the next chunk.
      acc = share_row[2 * num_threads_x - 1];
      __syncthreads();
      block_col += 2 * num_threads_x * forward_direction;
      if (reverse)
        loop_condition = (block_col >= 0);
      else
        loop_condition = (block_col < scan_dim_size);
    }
  }
}

// Exclusive scan of one tile of the innermost (contiguous) axis; also
// stores the tile total so that large rows can be combined afterwards.
//
// The input is treated as rows of scan_dim_size contiguous elements.
//
// Launch layout:
//   * grid = (tiles per row, rows), block = (two_power / 2) threads.
//   * Dynamic shared memory: (two_power + (two_power >> 5)) * sizeof(T)
//     bytes.
//   * two_power must equal 2 * blockDim.x and be a power of two.
//
// Block (x, y) scans tile x of row y, i.e. up to 2 * blockDim.x elements;
// only the last tile of a row may be partial. Each thread owns tile
// positions `thread_id` and `thread_id + blockDim.x`.
//
// Writes the tile-local exclusive prefix sums to `out`, and the tile's
// total to sum_data[blockIdx.y * gridDim.x + blockIdx.x]; adding the
// preceding tiles' totals is left to AddBlockScan. With `reverse`, tile
// positions are mapped from the end of the row. inner_dim_size and
// outer_dim_size are not used by this kernel.
template <typename T>
__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data,
                                          int inner_dim_size,
                                          int outer_dim_size, int scan_dim_size,
                                          int two_power, bool reverse) {
  // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
  extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[];
  T* share_tmp = reinterpret_cast<T*>(raw_tmp);
  int thread_id = threadIdx.x;
  int block_id = blockIdx.x;
  int half_tile = blockDim.x;

  // Number of valid elements in this tile; only the last tile of a row can
  // be partial.
  int block_scan_size = half_tile * 2;
  int remain = scan_dim_size % (2 * half_tile);
  if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain;

  // The two halves of the tile are disjoint even for a partial tile.
  // Deriving col2 from block_scan_size / 2 made them overlap and read past
  // the end of the row.
  int col1 = thread_id;
  int col2 = thread_id + half_tile;

  // 64-bit global indices: blockIdx.y * scan_dim_size can exceed INT_MAX.
  int64_t row_offset = static_cast<int64_t>(blockIdx.y) * scan_dim_size;
  int64_t tile_start = static_cast<int64_t>(block_id) * half_tile * 2;
  int64_t pos1 = tile_start + col1;
  int64_t pos2 = tile_start + col2;
  int64_t index1 = reverse ? row_offset + scan_dim_size - 1 - pos1
                           : row_offset + pos1;
  int64_t index2 = reverse ? row_offset + scan_dim_size - 1 - pos2
                           : row_offset + pos2;
  int64_t sum_index = static_cast<int64_t>(blockIdx.y) * gridDim.x + block_id;

  // Shared-memory slots are padded by (i >> 5) to reduce bank conflicts.
  int slot1 = col1 + (col1 >> 5);
  int slot2 = col2 + (col2 >> 5);
  // Positions beyond the tile contribute 0 and are never read from global
  // memory.
  share_tmp[slot1] = (col1 < block_scan_size) ? in[index1] : static_cast<T>(0);
  share_tmp[slot2] = (col2 < block_scan_size) ? in[index2] : static_cast<T>(0);

  // Up-Sweep: build partial sums in place; the tile total ends up at
  // position two_power - 1.
  int offset = 1;
  for (int d = (two_power / 2); d > 0; d >>= 1) {
    __syncthreads();
    if (thread_id < d) {
      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);

      share_tmp[tmp_index2] += share_tmp[tmp_index1];
    }
    offset *= 2;
  }
  __syncthreads();

  // Publish the tile total, then seed the down-sweep with the identity.
  if (thread_id == 0) {
    int tmp_index = (two_power - 1) + ((two_power - 1) >> 5);
    sum_data[sum_index] = share_tmp[tmp_index];
    share_tmp[tmp_index] = 0;
  }

  // Down Sweep: turns the partial-sum tree into exclusive prefix sums.
  for (int d = 1; d < two_power; d *= 2) {
    offset >>= 1;
    __syncthreads();
    if (thread_id < d) {
      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);

      T tmp = share_tmp[tmp_index1];
      share_tmp[tmp_index1] = share_tmp[tmp_index2];
      share_tmp[tmp_index2] += tmp;
    }
  }

  __syncthreads();

  if (col1 < block_scan_size) out[index1] = share_tmp[slot1];
  if (col2 < block_scan_size) out[index2] = share_tmp[slot2];
}

// Second pass of the multi-tile exclusive scan along the innermost axis.
//
// InnerMostDimExclusiveScan scans each tile of `size` elements on its own
// and records the tile totals in `sum` (sum_size totals per row). This
// kernel adds the totals of tiles 0..t-1 of the same row to every element
// of tile t >= 1.
//
// Launch layout: blockDim.x must equal `size`. Block x of the grid handles
// tile x + 1, and grid.y indexes the row. Threads past the end of the row
// return immediately. With `reverse`, positions are mapped from the end of
// the row, matching InnerMostDimExclusiveScan.
//
// NOTE(review): the per-element loop is O(number of tiles); very long rows
// would benefit from scanning the tile totals first.
template <typename T>
__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size,
                             int sum_size, bool reverse) {
  const int block_id = blockIdx.x;
  const int col = block_id * static_cast<int>(blockDim.x) +
                  static_cast<int>(threadIdx.x) + size;
  if (col >= scan_dim_size || col < 0) return;

  // 64-bit offsets: blockIdx.y * scan_dim_size can exceed INT_MAX.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.y) * scan_dim_size;
  const int64_t index =
      reverse ? row_offset + scan_dim_size - 1 - col : row_offset + col;
  const T* row_sum = sum + static_cast<int64_t>(blockIdx.y) * sum_size;

  // Accumulate in a register (same summation order as adding one total at
  // a time) and write the result once.
  T acc = result[index];
  for (int i = 0; i <= block_id; ++i) {
    acc += row_sum[i];
  }
  result[index] = acc;
}

// CUDA kernel of the `cumsum` operator.
//
// Reads input "X" and writes "Out". Attributes:
//   * `axis`: negative values count from the last dimension.
//   * `exclusive`: out[k] does not include in[k].
//   * `reverse`: accumulate from the end of the axis.
//
// Dispatch:
//   * Every element lies on `axis` (size == dims[axis]): Thrust scan on the
//     device context's stream.
//   * `axis` is the last dimension: shared-memory tile scans
//     (InnerMostDimExclusiveScan + AddBlockScan, or
//     InnerMostDimInclusiveScan).
//   * Otherwise: OuterScan, one sequential scan per (outer, inner) column.
template <typename DeviceContext, typename T>
class CumCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");

    int axis = context.Attr<int>("axis");
    bool exclusive = context.Attr<bool>("exclusive");
    bool reverse = context.Attr<bool>("reverse");
    auto out_dims = out->dims();
    auto size = in->numel();

    PADDLE_ENFORCE_EQ(
        axis < out_dims.size() && axis >= (0 - out_dims.size()), true,
        platform::errors::OutOfRange(
            "Attr(axis) is out of range, It's expected "
            "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
            out_dims.size(), out_dims.size() - 1, axis));
    if (axis < 0) {
      axis += out_dims.size();
    }

    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* in_data = in->data<T>();

    // An empty tensor has nothing to scan. Without this early return the
    // launches below would use a zero-sized grid or block, which is an
    // invalid launch configuration.
    if (size == 0) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();

    // Use thrust for parallel acceleration when the input size is equal to
    // the length of the 'axis' dimension.
    //   * The scan is issued on the operator's stream (dev_ctx.stream()),
    //     not on the default stream used by thrust::device.
    //   * `reverse` is handled with reverse iterators instead of a
    //     temporary copy plus an extra reverse pass.
    if (size == out_dims[axis]) {
#ifdef __HIPCC__
      const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
      const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
      if (reverse) {
        thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
            thrust::device_pointer_cast(in_data) + size);
        thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
            thrust::device_pointer_cast(out_data) + size);
        if (exclusive) {
          thrust::exclusive_scan(policy, reversed_in, reversed_in + size,
                                 reversed_out);
        } else {
          thrust::inclusive_scan(policy, reversed_in, reversed_in + size,
                                 reversed_out);
        }
      } else {
        if (exclusive) {
          thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
        } else {
          thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
        }
      }
      return;
    }

    // View the tensor as [outer_dim_size, scan_dim_size, inner_dim_size].
    const int scan_dim_size = static_cast<int>(out_dims[axis]);
    const bool optimize_condition = (axis == out_dims.size() - 1);
    int outer_dim_size = 1;
    int inner_dim_size = 1;
    // treat all dim index < axis as outer_dim_size
    for (int i = 0; i < axis; i++) {
      outer_dim_size *= out_dims[i];
    }
    // treat all dim index > axis as inner_dim_size
    for (int i = axis + 1; i < out_dims.size(); i++) {
      inner_dim_size *= out_dims[i];
    }

    // CUDA limits gridDim.y to 65535 blocks.
    constexpr int kMaxGridDimY = 65535;

    if (optimize_condition) {
      auto nextPowerOfTwo = [](int x) -> int {
        int ret = 1;
        while (ret < x) ret = ret * 2;
        return ret;
      };
      if (exclusive) {
        int element_per_block = nextPowerOfTwo(scan_dim_size) / 2;
        if (element_per_block > 512 || element_per_block < 32) {
          element_per_block = 64;
        }
        // Each block scans one tile of two_power elements (two per thread).
        const int two_power = element_per_block * 2;
        const int tiles_per_row = (scan_dim_size + two_power - 1) / two_power;
        const int offset_size = two_power >> 5;
        const size_t share_mem_size = (two_power + offset_size) * sizeof(T);

        // Per-tile totals, consumed by AddBlockScan.
        Tensor scan_sum;
        paddle::framework::DDim dims{static_cast<int64_t>(tiles_per_row),
                                     static_cast<int64_t>(outer_dim_size)};
        scan_sum.Resize(dims);
        T* sum_data = scan_sum.mutable_data<T>(context.GetPlace());

        // Rows map to gridDim.y, so launch in chunks of at most
        // kMaxGridDimY rows, offsetting the data and tile-sum pointers.
        for (int row_start = 0; row_start < outer_dim_size;
             row_start += kMaxGridDimY) {
          const int rows = std::min(kMaxGridDimY, outer_dim_size - row_start);
          const int64_t data_offset =
              static_cast<int64_t>(row_start) * scan_dim_size;
          const int64_t sum_offset =
              static_cast<int64_t>(row_start) * tiles_per_row;

          dim3 block(element_per_block);
          dim3 grid(tiles_per_row, rows);
          InnerMostDimExclusiveScan<
              T><<<grid, block, share_mem_size, dev_ctx.stream()>>>(
              in_data + data_offset, out_data + data_offset,
              sum_data + sum_offset, inner_dim_size, rows, scan_dim_size,
              two_power, reverse);

          // For rows spanning several tiles, add the preceding tile totals.
          // Block x of AddBlockScan handles tile x + 1.
          if (scan_dim_size > two_power) {
            dim3 sum_block(two_power);
            dim3 sum_grid(tiles_per_row - 1, rows);
            AddBlockScan<T><<<sum_grid, sum_block, 0, dev_ctx.stream()>>>(
                out_data + data_offset, sum_data + sum_offset, two_power,
                scan_dim_size, tiles_per_row, reverse);
          }
        }
      } else {
        dim3 block(32, 16);
        dim3 grid((outer_dim_size + block.y - 1) / block.y);
        InnerMostDimInclusiveScan<T, 32,
                                  16><<<grid, block, 0, dev_ctx.stream()>>>(
            in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
            reverse);
      }
    } else {
      // OuterScan strides over both grid dimensions, so clamping grid.y to
      // the hardware limit still covers every inner position.
      const int threads = std::min(512, inner_dim_size);
      const int inner_blocks =
          std::min((inner_dim_size + threads - 1) / threads, kMaxGridDimY);
      dim3 block(threads);
      dim3 grid(outer_dim_size, inner_blocks);
      OuterScan<T><<<grid, block, 0, dev_ctx.stream()>>>(
          in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
          exclusive, reverse);
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the cumsum CUDA kernel for float, double, int and int64_t inputs.
using CumsumCUDAContext = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(cumsum,
                        ops::CumCUDAKernel<CumsumCUDAContext, float>,
                        ops::CumCUDAKernel<CumsumCUDAContext, double>,
                        ops::CumCUDAKernel<CumsumCUDAContext, int>,
                        ops::CumCUDAKernel<CumsumCUDAContext, int64_t>);