/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace platform {

// Primary template: invokes `func(i)` for every index `i` stored in the given
// index list, on the device described by `DeviceContext`. It is only declared
// here. It is usable for device contexts that have an explicit specialization
// in this header: CPUDeviceContext, and CUDADeviceContext when compiled by
// NVCC.
template <class DeviceContext>
struct ForRangeIn {
  ForRangeIn(const DeviceContext& context, std::vector<int64_t> indices);

  template <class Function>
  void operator()(Function func) const;
};

// CPU specialization of ForRangeIn: calls `func(i)` serially, in stored
// order, for each element `i` of the index list captured at construction.
// The device context is not used on the CPU path.
template <>
struct ForRangeIn<CPUDeviceContext> {
  ForRangeIn(const CPUDeviceContext& /*dev_ctx*/, std::vector<int64_t> range)
      : range_(range) {}

  template <typename Function>
  void operator()(Function func) const {
    const size_t count = range_.size();
    for (size_t k = 0; k < count; ++k) {
      // Hand `func` a mutable copy of the index, as a range-for by value would.
      int64_t index = range_[k];
      func(index);
    }
  }

  // Indices to visit; a private copy of the constructor argument.
  std::vector<int64_t> range_;
};

// Primary template: invokes `func(i)` for every index `i` in [0, limit), on
// the device described by `DeviceContext`. It is only declared here. It is
// usable for device contexts that have an explicit specialization in this
// header: CPUDeviceContext, and CUDADeviceContext when compiled by NVCC.
template <class DeviceContext>
struct ForRange {
  ForRange(const DeviceContext& context, size_t count);

  template <class Function>
  void operator()(Function func) const;
};

// CPU specialization of ForRange: calls `func(i)` serially for
// i = 0, 1, ..., limit - 1. The device context is not used on the CPU path.
template <>
struct ForRange<CPUDeviceContext> {
  ForRange(const CPUDeviceContext& /*dev_ctx*/, size_t limit)
      : limit_(limit) {}

  template <typename Function>
  void operator()(Function func) const {
    size_t i = 0;
    while (i < limit_) {
      func(i);
      ++i;
    }
  }

  // Exclusive upper bound of the visited indices.
  size_t limit_;
};

#ifdef __NVCC__
// Single-block kernel: each thread calls `func` with its own threadIdx.x.
// There is no bounds check, so it is only correct when launched as exactly one
// block whose blockDim.x equals the number of indices to visit. That is how
// ForRange<CUDADeviceContext> launches it when one block suffices.
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
  size_t lane_index = threadIdx.x;
  func(lane_index);
}

// Multi-block kernel: each thread computes the flat index
// blockIdx.x * blockDim.x + threadIdx.x and calls `func` on it only when the
// index is below `limit`. The check is needed because the last block may be
// only partially used. A single launch covers gridDim.x * blockDim.x indices;
// there is no stride loop, so the caller must size the grid to cover `limit`.
//
// func:  device-callable functor taking a size_t index.
// limit: number of valid indices. A non-positive value means no work.
template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, int limit) {
  // Widen before multiplying so the product is computed in size_t.
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Compare in size_t explicitly. Otherwise a negative `limit` would convert
  // to a huge unsigned bound and let every thread run `func`.
  if (limit > 0 && idx < static_cast<size_t>(limit)) {
    func(idx);
  }
}

// CUDA specialization of ForRange: calls `func(i)` for i in [0, limit), one
// GPU thread per index and at most 1024 threads per block. The kernel is
// enqueued on `dev_ctx.stream()`; this call does not synchronize. The device
// context is held by reference, so it must outlive this object.
// NOTE(review): `limit` is narrowed to int. Values above INT_MAX are not
// supported here; confirm that callers never pass them.
template <>
struct ForRange<CUDADeviceContext> {
  ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
      : dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}

  template <typename Function>
  inline void operator()(Function func) const {
    // Nothing to visit. A zero-sized grid/block would also be an invalid
    // launch configuration.
    if (limit_ == 0) {
      return;
    }
    constexpr int num_threads = 1024;
    int block_size = limit_ <= num_threads ? limit_ : num_threads;
    // Ceil-divide in 64-bit so `limit_ + num_threads - 1` cannot overflow int
    // when limit_ is close to INT_MAX.
    int grid_size = static_cast<int>(
        (static_cast<int64_t>(limit_) + num_threads - 1) / num_threads);

    if (grid_size == 1) {
      // Every index fits in one block: use the kernel without a bounds check.
      ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
          func);
    } else {
      ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
          func, limit_);
    }
  }

  const CUDADeviceContext& dev_ctx_;
  int limit_;
};

// Kernel used by ForRangeIn on CUDA: thread t = blockIdx.x * blockDim.x +
// threadIdx.x calls `func(vector[t])` when t < vector_size. Threads past the
// end do nothing. `vector` is dereferenced inside the kernel, so it must point
// to memory that is accessible from the device.
template <typename T, typename Function>
__global__ static void ForRangeInElemwiseOp(Function func, T* vector,
                                            int vector_size) {
  size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (!(tid < vector_size)) {
    return;
  }
  func(vector[tid]);
}

// CUDA specialization of ForRangeIn: calls `func(range[k])` for every stored
// index, one GPU thread per element and at most 1024 threads per block. The
// kernel is enqueued on `dev_ctx.stream()`; this call does not synchronize.
// The device context is held by reference, so it must outlive this object.
template <>
struct ForRangeIn<CUDADeviceContext> {
  ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
      : dev_ctx_(dev_ctx), range_(range) {}

  template <typename Function>
  inline void operator()(Function func) const {
    constexpr int num_threads = 1024;
    int range_size = static_cast<int>(range_.size());
    // An empty range needs no work, and a zero-sized launch would be an
    // invalid launch configuration.
    if (range_size == 0) {
      return;
    }
    int block_size = range_size <= num_threads ? range_size : num_threads;
    int grid_size =
        static_cast<int>((range_.size() + num_threads - 1) / num_threads);

    // The kernel dereferences this pointer on the GPU, so it must be the
    // device-side copy of the indices for this context's place.
    // NOTE(review): assumes framework::Vector::data() exposes the host buffer
    // while CUDAData(place) returns (and syncs) the device copy. Verify
    // against mixed_vector.h.
    ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
        func, range_.CUDAData(dev_ctx_.GetPlace()), range_size);
  }

  const CUDADeviceContext& dev_ctx_;
  framework::Vector<int64_t> range_;
};

#endif

}  // namespace platform
}  // namespace paddle