“6b4a51bae3e8bdb1266573d67adde9fb55cf86b6”上不存在“paddle/fluid/operators/ascend_trigger_op.cc”
gpu_launch_config.h 6.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14

15
// Utilities for computing GPU kernel launch configurations (block/grid sizes).
16 17 18

#pragma once

19
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
20

21
#ifdef PADDLE_WITH_CUDA
22
#include <cuda_runtime.h>
23 24 25
#else
#include <hip/hip_runtime.h>
#endif
F
feng_shuai 已提交
26

27 28 29 30
#include <stddef.h>
#include <algorithm>
#include <string>
#include <vector>
31
#include "paddle/fluid/platform/device_context.h"
32

33 34 35 36 37 38 39 40 41
// Default upper bound on threads per block used by the launch helpers.
#if defined(__HIPCC__)
// HIP produces errors or NaN results when a block has more than 256 threads.
#define PREDEFINED_BLOCK_SIZE 256
#else
// CUDA performs best with between 64 and 512 threads per block.
#define PREDEFINED_BLOCK_SIZE 512
#endif

42 43 44
namespace paddle {
namespace platform {

45
// Integer ceiling division: returns ceil(a / b) for b > 0.
// Computed as quotient + (positive remainder ? 1 : 0) instead of
// (a + b - 1) / b, so it cannot overflow when `a` is close to INT_MAX.
// For negative `a`, truncating division already rounds toward +inf.
inline int DivUp(int a, int b) { return a / b + (a % b > 0 ? 1 : 0); }
46

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
/* Rounds `n` up to the next power of two, clamped to [32, kMaxThreads]
   (kMaxThreads is 256 under HIP, 1024 otherwise). Values <= 32, including
   zero and negatives, yield 32.
   Equivalent to the bit trick at
   https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
   followed by the clamp, but clamping first avoids the signed overflow the
   bit trick hits for n == INT_MIN (n--) and n > 2^30 (n + 1). */
static inline int RoundToPowerOfTwo(int n) {
#ifdef __HIPCC__
  constexpr int kMaxThreads = 256;
#else
  constexpr int kMaxThreads = 1024;
#endif
  constexpr int kMinThreads = 32;
  if (n <= kMinThreads) return kMinThreads;
  if (n >= kMaxThreads) return kMaxThreads;
  // kMinThreads < n < kMaxThreads here, so this loop runs at most a few
  // times and never exceeds kMaxThreads (itself a power of two).
  int power = kMinThreads;
  while (power < n) power <<= 1;
  return power;
}

F
feng_shuai 已提交
63 64 65
#ifdef WITH_NV_JETSON
// On some devices (Jetson Nano, TX2) a block cannot always use 1024 threads,
// so a smaller thread count must be substituted.
W
Wilber 已提交
66 67 68
// Overwrites *num_thread with alternative_num_thread when the device's
// compute capability is 53 or 62 (the constrained Nano / TX2 class of
// devices); otherwise leaves *num_thread unchanged.
template <typename CUDADeviceContext>
inline void ChangeThreadNum(const CUDADeviceContext& context, int* num_thread,
                            int alternative_num_thread = 512) {
  const int cc = context.GetComputeCapability();
  const bool needs_fewer_threads = (cc == 53) || (cc == 62);
  if (!needs_fewer_threads) {
    return;
  }
  *num_thread = alternative_num_thread;
}
#endif

76
// Kernel launch configuration: block shape, grid shape and the compute
// capability of the device it was computed for.
struct GpuLaunchConfig {
 public:
  GpuLaunchConfig() {}

  // Total number of threads in the launch (threads per block * blocks).
  size_t GetThreadNum() const { return GetBlockSize() * GetGridSize(); }

  // Number of blocks in the grid (x * y * z). Each dim3 component is a
  // 32-bit unsigned int, so widen to size_t BEFORE multiplying; otherwise a
  // large grid (x up to 2^31 - 1 times y/z) wraps around.
  size_t GetGridSize() const {
    return static_cast<size_t>(block_per_grid.x) *
           static_cast<size_t>(block_per_grid.y) *
           static_cast<size_t>(block_per_grid.z);
  }

  // Number of threads in one block (x * y * z), widened before multiplying.
  size_t GetBlockSize() const {
    return static_cast<size_t>(thread_per_block.x) *
           static_cast<size_t>(thread_per_block.y) *
           static_cast<size_t>(thread_per_block.z);
  }

  // Device compute capability; stays 0 unless a builder fills it in
  // (GetGpuLaunchConfig1D does).
  int compute_capability = 0;

  dim3 thread_per_block = dim3(1, 1, 1);
  dim3 block_per_grid = dim3(1, 1, 1);
};

95 96 97 98
/* According to NVIDIA, if number of threads per block is 64/128/256/512,
  * cuda performs better. And number of blocks should be greater (at least
  * 2x~4x) than number of SMs. Hence, SM count is took into account within
  * this function to determine the right number of threads per block. */
99
inline GpuLaunchConfig GetGpuLaunchConfig1D(
100 101 102 103 104 105
    const platform::CUDADeviceContext& context, int64_t numel,
    int vec_size = 1) {
  PADDLE_ENFORCE_GT(numel, 0, platform::errors::InvalidArgument(
                                  "element quantity should be greater than 0,"
                                  " but received value is: %d.",
                                  numel));
F
feng_shuai 已提交
106 107
  // Get compute_capability
  const int capability = context.GetComputeCapability();
108 109 110
  /* If thread number per block is 64/128/256/512, cuda performs better.*/
  int limit_threads =
      std::min(PREDEFINED_BLOCK_SIZE, context.GetMaxThreadsPerBlock());
F
feng_shuai 已提交
111 112
#ifdef WITH_NV_JETSON
  if (capability == 53 || capability == 62) {
113
    limit_threads = 512;
F
feng_shuai 已提交
114 115
  }
#endif
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
  int threads = limit_threads;
  int sm_count = context.GetSMCount();
  int active_threads_num = numel / vec_size;
  if (active_threads_num / (sm_count << 1) < limit_threads) {
    // Round up threads number into an exponential multiple of 2, while number
    // of acitve blocks is about twice of SM, to acquire better performance.
    threads = RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
  } else if (active_threads_num / (sm_count << 2) < limit_threads) {
    // Round up threads number into an exponential multiple of 2, while number
    // of acitve blocks is about 4 times of SM, to acquire better performance.
    threads = RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
  }
  // Number of threads per block shall be larger than 64.
  threads = std::max(64, threads);
  int blocks = DivUp(DivUp(numel, vec_size), threads);
131 132 133 134
  int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
  if (blocks > limit_blocks) {
    blocks = limit_blocks;
  }
135

136
  GpuLaunchConfig config;
137 138
  config.thread_per_block.x = threads;
  config.block_per_grid.x = blocks;
139
  config.compute_capability = capability;
140 141 142 143
  return config;
}

/* Builds a 2D launch config for an x_dim-by-y_dim problem using 256-thread
 * blocks: block columns cover x first (up to 256), and any remaining block
 * capacity is spent on rows. The grid is capped so that its total block count
 * stays within max_physical_threads / 256 (at least 1).
 * Both x_dim and y_dim must be > 0. */
inline GpuLaunchConfig GetGpuLaunchConfig2D(
    const platform::CUDADeviceContext& context, int x_dim, int y_dim) {
  PADDLE_ENFORCE_GT(x_dim, 0, platform::errors::InvalidArgument(
                                  "x dim number should greater than 0,"
                                  " but received value is: %d",
                                  x_dim));
  PADDLE_ENFORCE_GT(y_dim, 0, platform::errors::InvalidArgument(
                                  "y dim number should greater than 0,"
                                  " but received value is: %d",
                                  y_dim));

  constexpr int kThreadsPerBlock = 256;
  // (std::min)/(std::max) are parenthesized to dodge min/max macros.
  const int cols_per_block = (std::min)(x_dim, kThreadsPerBlock);
  const int rows_per_block = (std::max)(kThreadsPerBlock / cols_per_block, 1);

  const int block_budget =
      (std::max)(context.GetMaxPhysicalThreadCount() / kThreadsPerBlock, 1);

  const int blocks_x = (std::min)(DivUp(x_dim, cols_per_block), block_budget);
  const int rows_within_budget = block_budget / blocks_x;
  const int rows_from_data = (std::max)(y_dim / rows_per_block, 1);
  const int blocks_y = (std::min)(rows_within_budget, rows_from_data);

  GpuLaunchConfig config;
  // The block size is not rounded to a multiple of 32; callers that need
  // warp alignment must do it themselves.
  config.thread_per_block = dim3(cols_per_block, rows_per_block, 1);
  config.block_per_grid = dim3(blocks_x, blocks_y, 1);
  return config;
}

}  // namespace platform
}  // namespace paddle
175 176

#endif