未验证 提交 f61e6ee0 编写于 作者: W whs 提交者: GitHub

Fix cuda kernel launch of grid sampler (#33100)

上级 0a9937d2
...@@ -187,7 +187,6 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c, ...@@ -187,7 +187,6 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
int out_sC = out_h * out_w; int out_sC = out_h * out_w;
int out_sH = out_w; int out_sH = out_w;
int out_sW = 1; int out_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) { CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w; const int w = index % out_w;
const int h = (index / out_w) % out_h; const int h = (index / out_w) % out_h;
...@@ -199,7 +198,6 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c, ...@@ -199,7 +198,6 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
ix = compute_positions(ix, in_w, padding_mode, align_corners); ix = compute_positions(ix, in_w, padding_mode, align_corners);
iy = compute_positions(iy, in_h, padding_mode, align_corners); iy = compute_positions(iy, in_h, padding_mode, align_corners);
if (mode == Mode::bilinear) { if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix)); int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy)); int iy_nw = static_cast<int>(floor(iy));
...@@ -216,6 +214,7 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c, ...@@ -216,6 +214,7 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
T se = (ix - ix_nw) * (iy - iy_nw); T se = (ix - ix_nw) * (iy - iy_nw);
auto inp_offset_NC = n * inp_sN; auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c; for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
...@@ -291,17 +290,17 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> { ...@@ -291,17 +290,17 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
<< "; out_w: " << out_w; << "; out_w: " << out_w;
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
auto* output_data = output->mutable_data<T>(ctx.GetPlace()); auto* output_data = output->mutable_data<T>(ctx.GetPlace());
VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1]
VLOG(3) << "set constant"; << "; " << output->dims()[2] << "; " << output->dims()[3];
math::SetConstant<paddle::platform::CUDADeviceContext, T>()( math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
dev_ctx, output, static_cast<T>(0)); dev_ctx, output, static_cast<T>(0));
int count = static_cast<int>(n * out_h * out_w); int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream(); auto cu_stream = dev_ctx.stream();
int block_size = 512;
int block = 512; int grid_size = (count + block_size - 1) / block_size;
int grid_size = (count + block - 1) / block; VLOG(3) << "cuda launch - grid dims: " << grid_size << "; block dims"
grid_sample_cuda_kernel<T><<<block, grid_size, 0, cu_stream>>>( << block_size;
grid_sample_cuda_kernel<T><<<grid_size, block_size, 0, cu_stream>>>(
count, n, c, out_h, out_w, in_h, in_w, input->data<T>(), count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
grid->data<T>(), output_data, mode, padding_mode, align_corners); grid->data<T>(), output_data, mode, padding_mode, align_corners);
} }
...@@ -475,9 +474,12 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -475,9 +474,12 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
int count = static_cast<int>(n * out_h * out_w); int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream(); auto cu_stream = dev_ctx.stream();
int block = 512; int block_size = 512;
int grid_size = (count + block - 1) / block; int grid_size = (count + block_size - 1) / block_size;
grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>( VLOG(3) << "cuda launch grad kernel - grid dims: " << grid_size
<< "; block dims" << block_size << "; count: " << count;
grid_sampler_cuda_backward_kernel<
T><<<grid_size, block_size, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c, count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode, out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners); padding_mode, align_corners);
......
...@@ -19,6 +19,8 @@ import numpy as np ...@@ -19,6 +19,8 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
paddle.enable_static()
def bilinear_interp_np(input, def bilinear_interp_np(input,
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest, skip_check_grad_ci
paddle.enable_static()
def AffineGrid(theta, grid_shape): def AffineGrid(theta, grid_shape):
...@@ -160,7 +162,6 @@ class TestGridSamplerOp(OpTest): ...@@ -160,7 +162,6 @@ class TestGridSamplerOp(OpTest):
"padding_mode": self.padding_mode, "padding_mode": self.padding_mode,
"mode": self.mode "mode": self.mode
} }
# print("X: {}".format(x))
self.outputs = { self.outputs = {
'Output': GridSampler(x, grid, self.align_corners, self.mode, 'Output': GridSampler(x, grid, self.align_corners, self.mode,
self.padding_mode) self.padding_mode)
...@@ -237,5 +238,41 @@ class Case4(TestGridSamplerOp): ...@@ -237,5 +238,41 @@ class Case4(TestGridSamplerOp):
self.numeric_grad_delta = 0.0001 self.numeric_grad_delta = 0.0001
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass")
class LargeInputCase(TestGridSamplerOp):
def get_places(self):
places = []
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def initTestCase(self):
self.no_need_check_grad = True
self.x_shape = (2, 3, 128, 128)
self.grid_shape = (2, 130, 130, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "reflection"
self.mode = "bilinear"
def test_check_grad_normal(self):
pass
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass")
class Case5(LargeInputCase):
def initTestCase(self):
self.no_need_check_grad = True
self.x_shape = (2, 3, 128, 128)
self.grid_shape = (2, 130, 130, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = True
self.padding_mode = "zeros"
self.mode = "bilinear"
self.use_cudnn = False if core.is_compiled_with_rocm() else True
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册