/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yi Wang 已提交
16
#include "paddle/fluid/platform/device_context.h"
Y
Yang Yu 已提交
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42

namespace paddle {
namespace platform {

// Primary template for an index-range runner bound to a device context.
//
// Only the member declarations are given here; no definitions are provided
// for the generic case. The specializations in this file call the supplied
// callable with each index in [0, limit).
template <typename DeviceContext>
struct ForRange {
  // dev_ctx: device context the range is associated with.
  // limit:   exclusive upper bound of the index range.
  ForRange(const DeviceContext& dev_ctx, size_t limit);

  // Applies `func` over the index range. `func` is taken by value.
  template <typename Functor>
  void operator()(Functor func) const;
};

// CPU specialization: calls the callable serially, in increasing index
// order, once for every index in [0, limit_).
template <>
struct ForRange<CPUDeviceContext> {
  // dev_ctx is accepted to match the generic constructor signature; it is
  // not stored or used by this specialization.
  ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}

  // Runs func(0), func(1), ..., func(limit_ - 1). Does nothing when
  // limit_ is zero.
  template <typename Function>
  void operator()(Function func) const {
    size_t index = 0;
    while (index < limit_) {
      func(index);
      ++index;
    }
  }

  // Exclusive upper bound of the index range.
  size_t limit_;
};

43
#if defined(__NVCC__) || defined(__HIPCC__)
Y
Yang Yu 已提交
44 45 46 47 48 49 50
// Kernel used when the whole range fits in a single block: every thread
// calls func with its own threadIdx.x as the index. There is no bounds
// check here, so each launched thread invokes func exactly once.
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
  size_t thread_index = static_cast<size_t>(threadIdx.x);
  func(thread_index);
}

template <typename Function>
51
__global__ static void ForRangeElemwiseOp(Function func, int limit) {
Y
Yang Yu 已提交
52 53 54 55 56 57 58 59 60
  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < limit) {
    func(idx);
  }
}

template <>
struct ForRange<CUDADeviceContext> {
  ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
61
      : dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
Y
Yang Yu 已提交
62 63 64

  template <typename Function>
  inline void operator()(Function func) const {
L
Luo Tao 已提交
65
    constexpr int num_threads = 1024;
Y
Yang Yu 已提交
66
    int block_size = limit_ <= num_threads ? limit_ : num_threads;
67 68 69 70 71
    int grid_size = (limit_ + num_threads - 1) / num_threads;

    if (grid_size == 1) {
      ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
          func);
Y
Yang Yu 已提交
72
    } else {
73 74
      ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
          func, limit_);
Y
Yang Yu 已提交
75 76 77 78
    }
  }

  const CUDADeviceContext& dev_ctx_;
79
  int limit_;
Y
Yang Yu 已提交
80 81 82 83 84 85
};

#endif

}  // namespace platform
}  // namespace paddle