Unverified commit 463c72c2, authored by W wangchaochaohu, committed by GitHub

refine gpu kernel config for Paddle (#28085)

Parent 2cb1ecb9
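This commit replaces the execution-context helper platform::getGpuLaunchConfig with the device-context helper platform::GetGpuLaunchConfig1D, and renames the launch fields from blocks/threads (int) to block_per_grid/thread_per_block (dim3). A minimal before/after sketch of the call-site migration (SomeKernel, x_data, and n are illustrative placeholders, not names from this commit):

    // before this commit
    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(n, ctx);
    SomeKernel<T><<<config.blocks, config.threads, 0,
                    ctx.cuda_device_context().stream()>>>(x_data, n);

    // after this commit
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), n);
    SomeKernel<T><<<config.block_per_grid, config.thread_per_block, 0,
                    ctx.cuda_device_context().stream()>>>(x_data, n);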
......@@ -70,7 +70,7 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
auto x_numel = x->numel();
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(x_numel, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), x_numel);
Tensor x_cpu;
framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu);
......@@ -89,9 +89,9 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.cuda_device_context();
GPUBCELossForward<
T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
x_data, labels->data<T>(), out_data, x_numel);
GPUBCELossForward<T><<<config.block_per_grid, config.thread_per_block, 0,
dev_ctx.stream()>>>(x_data, labels->data<T>(),
out_data, x_numel);
}
};
......@@ -106,12 +106,12 @@ class BCELossGradCUDAKernel : public framework::OpKernel<T> {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
int x_numel = x->numel();
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(x_numel, ctx);
auto& dev_ctx = ctx.cuda_device_context();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);
GPUBCELossBackward<
T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
GPUBCELossBackward<T><<<config.block_per_grid, config.thread_per_block, 0,
dev_ctx.stream()>>>(
x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
}
};
......
......@@ -165,10 +165,11 @@ class BilateralSliceOpCUDAKernel : public framework::OpKernel<T> {
int total_count = batch_size * h * w * output_dims[1];
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(total_count, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), total_count);
BilateralSliceCudaForwardKernel<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
BilateralSliceCudaForwardKernel<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
output_data, grid_data, guide_data, input_data, grid_sizes, has_offset,
total_count, output_dims[1]);
}
......@@ -472,24 +473,29 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
grid_sizes.input_chans = input_chans;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(grid_count, ctx, 512);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), grid_count);
BilateralSliceCudaGridGradKernel<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
BilateralSliceCudaGridGradKernel<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes,
has_offset, grid_count, output_chans);
config = platform::getGpuLaunchConfig(guide_count, ctx, 512);
config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), guide_count);
BilateralSliceCudaGuideGradKernel<T><<<
config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
BilateralSliceCudaGuideGradKernel<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
guide_grad_data, output_grad_data, grid_data, guide_data, input_data,
grid_sizes, has_offset, guide_count, output_chans);
config = platform::getGpuLaunchConfig(input_count, ctx, 512);
config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_count);
BilateralSliceCudaInputGradKernel<T><<<
config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
BilateralSliceCudaInputGradKernel<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, output_grad_data, grid_data, guide_data, grid_sizes,
has_offset, input_count, output_chans);
}
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
using Tensor = paddle::framework::Tensor;
using LoDTensor = paddle::framework::LoDTensor;
......
......@@ -887,10 +887,10 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("linear" == interp_method) {
KeLinearInterpFw<T><<<config.blocks, config.threads, 0,
KeLinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
align_corners, align_mode, data_layout);
......@@ -981,21 +981,22 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
KeNearestNeighborInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<T><<<config.blocks, config.threads, 0,
KeBilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpFw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
KeBicubicInterpFw<T><<<config.block_per_grid, 512, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
......@@ -1097,10 +1098,10 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<T><<<config.blocks, config.threads, 0,
KeTrilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
......@@ -1176,10 +1177,10 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("linear" == interp_method) {
KeLinearInterpBw<T><<<config.blocks, config.threads, 0,
KeLinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
ratio_w, align_corners, align_mode, data_layout);
......@@ -1267,22 +1268,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
KeNearestNeighborInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<T><<<config.blocks, config.threads, 0,
KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
KeBicubicInterpBw<T><<<config.block_per_grid, 512, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
......@@ -1378,10 +1380,10 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<T><<<config.blocks, config.threads, 0,
KeTrilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
......
......@@ -899,10 +899,10 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("linear" == interp_method) {
KeLinearInterpFw<T><<<config.blocks, config.threads, 0,
KeLinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
align_corners, align_mode, data_layout);
......@@ -1018,21 +1018,22 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
KeNearestNeighborInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<T><<<config.blocks, config.threads, 0,
KeBilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpFw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
KeBicubicInterpFw<T><<<config.block_per_grid, 512, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
......@@ -1167,10 +1168,10 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<T><<<config.blocks, config.threads, 0,
KeTrilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
......@@ -1259,10 +1260,10 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("linear" == interp_method) {
KeLinearInterpBw<T><<<config.blocks, config.threads, 0,
KeLinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
ratio_w, align_corners, align_mode, data_layout);
......@@ -1376,22 +1377,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
KeNearestNeighborInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<T><<<config.blocks, config.threads, 0,
KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
KeBicubicInterpBw<T><<<config.block_per_grid, 512, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
......@@ -1520,10 +1522,10 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
int pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<T><<<config.blocks, config.threads, 0,
KeTrilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
......
......@@ -87,8 +87,9 @@ class MishCUDAKernel : public framework::OpKernel<T> {
const int numel = x->numel();
platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
KeMishFw<T><<<config.blocks, config.threads, 0,
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(x_data, out_data, numel,
threshold);
}
......@@ -108,8 +109,9 @@ class MishFP32CUDAKernel : public framework::OpKernel<float> {
const int numel = x->numel();
platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
KeMishFwFP32<<<config.blocks, config.threads, 0,
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishFwFP32<<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(x_data, out_data,
numel, threshold);
}
......@@ -131,8 +133,9 @@ class MishGradCUDAKernel : public framework::OpKernel<T> {
const int numel = x->numel();
platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
KeMishBw<T><<<config.blocks, config.threads, 0,
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
x_data, dout_data, dx_data, numel, threshold);
}
......@@ -154,8 +157,9 @@ class MishGradFP32CUDAKernel : public framework::OpKernel<float> {
const int numel = x->numel();
platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
KeMishBwFP32<<<config.blocks, config.threads, 0,
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishBwFP32<<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
x_data, dout_data, dx_data, numel, threshold);
}
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mv_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/segment_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......
......@@ -16,7 +16,7 @@
#include <limits>
#include <vector>
#include "paddle/fluid/operators/stack_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace platform = paddle::platform;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Used to compute GPU launch parameters
#pragma once
#include <algorithm>
#include "paddle/fluid/platform/cuda_primitives.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#include <stddef.h>
#include <algorithm>
#include <string>
#include <vector>
namespace paddle {
namespace platform {
struct GpuLaunchConfig {
  // Number of threads per block.
  int threads;
  // Number of blocks for GPU kernel launch.
  int blocks;
  GpuLaunchConfig(int threads, int blocks) : threads(threads), blocks(blocks) {}
};
inline int DivUp(int a, int b) { return (a + b - 1) / b; }
struct GpuLaunchConfig {
  dim3 theory_thread_count = dim3(1, 1, 1);
  dim3 thread_per_block = dim3(1, 1, 1);
  dim3 block_per_grid = dim3(1, 1, 1);
};
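// Illustrative aside (not part of this file): DivUp is ceiling division,
// e.g. DivUp(1000, 256) == 4, so four 256-thread blocks cover 1000 elements.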
inline GpuLaunchConfig getGpuLaunchConfig(
    const int N, const framework::ExecutionContext& ctx,
    int max_threads = 1024) {
  int threads =
      std::min(max_threads, ctx.cuda_device_context().GetMaxThreadsPerBlock());
  int physical_thread_count =
      std::min(ctx.cuda_device_context().GetMaxPhysicalThreadCount(), N);
  int blocks = std::min((physical_thread_count + threads - 1) / threads,
                        ctx.cuda_device_context().GetSMCount());
  GpuLaunchConfig config(threads, blocks);
  return config;
}
inline GpuLaunchConfig GetGpuLaunchConfig1D(
    const platform::CUDADeviceContext& context, int element_count) {
  PADDLE_ENFORCE_GT(element_count, 0,
                    platform::errors::InvalidArgument(
                        "element count should be greater than 0,"
                        " but received value is %d.",
                        element_count));
  const int theory_thread_count = element_count;
  // Max threads across all SMs on this device
  int max_physical_threads = context.GetMaxPhysicalThreadCount();
  int sm = context.GetSMCount();
  // Compute the physical threads we need; it should not exceed the device's
  // max physical thread count.
  const int physical_thread_count =
      std::min(max_physical_threads, theory_thread_count);
  // Queried from the device
  const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
  // Cap the block count at factor * sm; factor is an empirical value.
  int factor = 4;
  const int block_count =
      std::min(DivUp(physical_thread_count, thread_per_block), factor * sm);
  GpuLaunchConfig config;
  config.theory_thread_count.x = theory_thread_count;
  config.thread_per_block.x = thread_per_block;
  config.block_per_grid.x = block_count;
  return config;
}
inline GpuLaunchConfig GetGpuLaunchConfig2D(
const platform::CUDADeviceContext& context, int xdim, int ydim) {
  PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument(
                                 "x dim number should be greater than 0,"
                                 " but received value is:%d",
                                 xdim));
  PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument(
                                 "y dim number should be greater than 0,"
                                 " but received value is:%d",
                                 ydim));
const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
int max_physical_threads = context.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
GpuLaunchConfig config;
  // Note: the block size is not aligned to 32; align it yourself if needed.
config.theory_thread_count = dim3(xdim, ydim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1));
config.block_per_grid = dim3(grid_x, grid_y, 1);
return config;
}
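// Worked example (illustrative; assumes a device with 64 SMs and 2048 threads
// per SM, so max_physical_threads = 131072 and max_blocks = 512):
//   GetGpuLaunchConfig2D(ctx, /*xdim=*/1000, /*ydim=*/8) yields
//   block_cols = min(1000, 256) = 256, block_rows = max(256 / 256, 1) = 1,
//   grid_x = min(DivUp(1000, 256), 512) = 4,
//   grid_y = min(512 / 4, max(8 / 1, 1)) = 8,
//   i.e. thread_per_block = (256, 1, 1) and block_per_grid = (4, 8, 1).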
// TODO(wangchaochaohu): 3D support will be added later
} // namespace platform
} // namespace paddle
#endif
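For reference, a minimal usage sketch of the new 1D helper (ScaleKernel and LaunchScale are illustrative names, not part of this commit; the launch pattern mirrors the call sites updated above):

    __global__ void ScaleKernel(float* data, int n, float alpha) {
      // GetGpuLaunchConfig1D may launch fewer than n threads (the block count
      // is capped at factor * SM count), so walk the range with a grid-stride
      // loop instead of assuming one thread per element.
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += blockDim.x * gridDim.x) {
        data[i] *= alpha;
      }
    }

    void LaunchScale(const platform::CUDADeviceContext& dev_ctx, float* data,
                     int n, float alpha) {
      platform::GpuLaunchConfig config =
          platform::GetGpuLaunchConfig1D(dev_ctx, n);
      ScaleKernel<<<config.block_per_grid, config.thread_per_block, 0,
                    dev_ctx.stream()>>>(data, n, alpha);
    }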
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Used to compute GPU launch parameters
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#include <stddef.h>
#include <algorithm>
#include <string>
#include <vector>
namespace paddle {
namespace platform {
inline int DivUp(int a, int b) { return (a + b - 1) / b; }
struct GpuLaunchParamConfig {
dim3 theory_thread_count = dim3(0, 0, 0);
dim3 thread_per_block = dim3(0, 0, 0);
dim3 block_per_grid = dim3(0, 0, 0);
};
inline GpuLaunchParamConfig GetGpuLaunchConfig1D(
const platform::CUDADeviceContext& context, int element_count) {
  PADDLE_ENFORCE_GT(element_count, 0,
                    platform::errors::InvalidArgument(
                        "element count should be greater than 0,"
                        " but received value is %d.",
                        element_count));
const int theory_thread_count = element_count;
  // Max threads across all SMs on this device
  int max_physical_threads = context.GetMaxPhysicalThreadCount();
  int sm = context.GetSMCount();
  // Compute the physical threads we need; it should not exceed the device's
  // max physical thread count.
  const int physical_thread_count =
      std::min(max_physical_threads, theory_thread_count);
  // Queried from the device
  const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
  // Cap the block count at factor * sm; factor is an empirical value.
int factor = 4;
const int block_count =
std::min(DivUp(physical_thread_count, thread_per_block), factor * sm);
GpuLaunchParamConfig config;
config.theory_thread_count.x = theory_thread_count;
config.thread_per_block.x = thread_per_block;
config.block_per_grid.x = block_count;
return config;
}
inline GpuLaunchParamConfig GetGpuLaunchConfig2D(
const platform::CUDADeviceContext& context, int xdim, int ydim) {
  PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument(
                                 "x dim number should be greater than 0,"
                                 " but received value is:%d",
                                 xdim));
  PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument(
                                 "y dim number should be greater than 0,"
                                 " but received value is:%d",
                                 ydim));
const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
int max_physical_threads = context.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
GpuLaunchParamConfig config;
  // Note: the block size is not aligned to 32; align it yourself if needed.
config.theory_thread_count = dim3(xdim, ydim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1));
config.block_per_grid = dim3(grid_x, grid_y, 1);
return config;
}
// 3D support will be added later
} // namespace platform
} // namespace paddle
#endif