提交 fef2faa7 编写于 作者: D dengkaipeng

Limit the maximum number of parallel CUDA kernel threads to 4096. test=develop

上级 34bfae24
...@@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw( ...@@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw(
const size_t num_channels, const float ratio_h, const float ratio_w) { const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) { int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w; int out_id_h = tid / output_w;
int out_id_w = tid % output_w; int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels; int in_img_size = input_w / num_channels;
...@@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw( ...@@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw(
const size_t num_channels, const float ratio_h, const float ratio_w) { const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) { int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w; int out_id_h = tid / output_w;
int out_id_w = tid % output_w; int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels; int in_img_size = input_w / num_channels;
...@@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw( ...@@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw(
const size_t num_channels, const float ratio_h, const float ratio_w) { const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) { int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w; int out_id_h = tid / output_w;
int out_id_w = tid % output_w; int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels; int in_img_size = input_w / num_channels;
...@@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw( ...@@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw(
const size_t num_channels, const T ratio_h, const T ratio_w) { const size_t num_channels, const T ratio_h, const T ratio_w) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) { int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w; int out_id_h = tid / output_w;
int out_id_w = tid % output_w; int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels; int in_img_size = input_w / num_channels;
...@@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> { ...@@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
return; return;
} }
int threadNum = n * out_chw; int pixelNum = n * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024; int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) { if ("nearest" == interp_method) {
KeNearestNeighborInterpFw< KeNearestNeighborInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w); out_chw, c, ratio_h, ratio_w);
} else if ("bilinear" == interp_method) { } else if ("bilinear" == interp_method) {
KeBilinearInterpFw< KeBilinearInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w); out_chw, c, ratio_h, ratio_w);
} }
...@@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
return; return;
} }
int threadNum = n * out_chw; int pixelNum = n * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024; int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) { if ("nearest" == interp_method) {
KeNearestNeighborInterpBw< KeNearestNeighborInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w); out_w, n, out_chw, c, ratio_h, ratio_w);
} else if ("bilinear" == interp_method) { } else if ("bilinear" == interp_method) {
KeBilinearInterpBw< KeBilinearInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w); out_w, n, out_chw, c, ratio_h, ratio_w);
} }
......
...@@ -167,13 +167,13 @@ class TestBilinearInterpCase6(TestInterpolateOp): ...@@ -167,13 +167,13 @@ class TestBilinearInterpCase6(TestInterpolateOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
# class TestBilinearInterpBigScale(TestInterpolateOp): class TestBilinearInterpBigScale(TestInterpolateOp):
# def init_test_case(self): def init_test_case(self):
# self.interp_method = 'bilinear' self.interp_method = 'bilinear'
# self.input_shape = [32, 16, 128, 64] self.input_shape = [4, 4, 64, 32]
# self.out_h = 200 self.out_h = 100
# self.out_w = 100 self.out_w = 50
# self.out_size = np.array([201, 101]).astype('int32') self.out_size = np.array([101, 51]).astype('int32')
class TestInterpolateOpUint8(OpTest): class TestInterpolateOpUint8(OpTest):
...@@ -273,6 +273,15 @@ class TestNearestNeighborInterpCase6(TestInterpolateOp): ...@@ -273,6 +273,15 @@ class TestNearestNeighborInterpCase6(TestInterpolateOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
class TestNearestNeighborInterpBigScale(TestInterpolateOp):
    """Nearest-neighbor interpolation test case with a [4, 4, 64, 32] input.

    NOTE(review): the "BigScale" name suggests this case is meant to produce
    more output elements than a single kernel launch has threads -- confirm
    against the operator's launch configuration.
    """

    def init_test_case(self):
        """Set the interpolation method, input shape and requested output sizes."""
        # Requested output size, stored as an int32 array.
        requested_size = np.array([101, 51]).astype('int32')

        self.interp_method = 'nearest'
        self.input_shape = [4, 4, 64, 32]
        self.out_h, self.out_w = 100, 50
        self.out_size = requested_size
class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8): class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册