diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 3b9ece483000e44ebd035c02d3ce12346c3943fc..190afbdac431f863c32e2a4a4b3ad83848e550fc 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw( const size_t num_channels, const float ratio_h, const float ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw( const size_t num_channels, const float ratio_h, const float ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw( const size_t num_channels, const float ratio_h, const float ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw( const size_t num_channels, const T ratio_h, const T ratio_w) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; @@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { return; } - int threadNum = n * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; if ("nearest" == interp_method) { KeNearestNeighborInterpFw< - T><<>>( + T><<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } else if ("bilinear" == interp_method) { KeBilinearInterpFw< - T><<>>( + T><<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } @@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { return; } - int threadNum = n * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; if ("nearest" == interp_method) { KeNearestNeighborInterpBw< - T><<>>( + T><<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } else if ("bilinear" == interp_method) { KeBilinearInterpBw< - T><<>>( + T><<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w); } diff --git a/python/paddle/fluid/tests/unittests/test_interpolate_op.py b/python/paddle/fluid/tests/unittests/test_interpolate_op.py index a90f4aace2a8a501c833ecb2f3d2c0d97a2a9ecb..dd3bf5fd5c91ea2c291a651b7f68794483f79540 100644 --- a/python/paddle/fluid/tests/unittests/test_interpolate_op.py +++ b/python/paddle/fluid/tests/unittests/test_interpolate_op.py @@ -167,13 +167,13 @@ class TestBilinearInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") -# class TestBilinearInterpBigScale(TestInterpolateOp): -# def init_test_case(self): -# self.interp_method = 'bilinear' -# self.input_shape = [32, 16, 128, 64] -# self.out_h = 200 -# self.out_w = 100 -# self.out_size = np.array([201, 101]).astype('int32') +class TestBilinearInterpBigScale(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 4, 64, 32] + self.out_h = 100 + self.out_w = 50 + self.out_size = np.array([101, 51]).astype('int32') class TestInterpolateOpUint8(OpTest): @@ -273,6 +273,15 @@ class TestNearestNeighborInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") +class TestNearestNeighborInterpBigScale(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 4, 64, 32] + self.out_h = 100 + self.out_w = 50 + self.out_size = np.array([101, 51]).astype('int32') + + class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8): def init_test_case(self): self.interp_method = 'nearest'