Unverified commit b6d0dac9, authored by Yuang Liu, committed by GitHub

Fix roll kernel gpu bug. (#52012)

Parent: 5031b443
@@ -39,7 +39,7 @@ __global__ void RollCudaKernel(const T* input,
 #pragma unroll
   for (size_t i = 0; i < Rank; i++) {
-    new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
+    new_dim_idx = (output_idx / strides[i]) % sizes[i] + shifts[i];
     if (new_dim_idx >= sizes[i]) {
       output_idx += (shifts[i] - sizes[i]) * strides[i];
     } else {
...
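The one-line kernel change above is the entire fix: output_idx accumulates the offset contributed by each (axis, shift) pair already processed, so when the same axis appears more than once, the wrap test for a later pair must read the coordinate back from the partially shifted output_idx. Reading from the stale flat index idx made a repeated axis wrap against its unshifted coordinate, which can even push output_idx out of bounds. A minimal NumPy sketch of the reference semantics the new Case3 tests pin down (the np.roll comparison is an illustration, not code from this commit):

# Sketch only: np.roll shows the expected semantics for rolling
# the same axis twice -- the two shifts must compound.
import numpy as np

x = np.arange(121, dtype=np.float32).reshape(11, 11)

# axis=(-1, 1) names dimension 1 twice on a 2-D array, so the two
# unit shifts accumulate into a single roll of 2 along that axis.
twice = np.roll(x, shift=(1, 1), axis=(-1, 1))
once = np.roll(x, shift=2, axis=1)
assert np.array_equal(twice, once)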
@@ -61,6 +61,14 @@ class TestRollOpCase2(TestRollOp):
         self.axis = [-1, -2]


+class TestRollOpCase3(TestRollOp):
+    def init_dtype_type(self):
+        self.dtype = np.float32
+        self.x_shape = (11, 11)
+        self.shifts = [1, 1]
+        self.axis = [-1, 1]
+
+
 class TestRollFP16OP(TestRollOp):
     def init_dtype_type(self):
         self.dtype = np.float16
@@ -77,6 +85,14 @@ class TestRollFP16OpCase2(TestRollOp):
         self.axis = [-1, -2]


+class TestRollFP16OpCase3(TestRollOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+        self.x_shape = (11, 11)
+        self.shifts = [1, 1]
+        self.axis = [-1, 1]
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
@@ -117,6 +133,26 @@ class TestRollBF16OpCase2(TestRollOp):
         self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=True)


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestRollBF16OpCase3(TestRollOp):
+    def init_dtype_type(self):
+        self.dtype = np.uint16
+        self.x_shape = (11, 11)
+        self.shifts = [1, 1]
+        self.axis = [-1, 1]
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=True)
+
+
 class TestRollAPI(unittest.TestCase):
     def input_data(self):
         self.data_x = np.array(
...
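For completeness, a hedged sketch of exercising the fixed path through the public paddle.roll API; the comparison against a single combined shift mirrors the new tests but is an addition here, not code from this commit:

# Assumes a CUDA build of Paddle so the GPU kernel above is exercised.
import paddle

x = paddle.arange(121, dtype='float32').reshape([11, 11])

# Repeated axis, as in the new Case3 tests: both unit shifts land on
# dimension 1, so the result must equal one roll of 2 along that axis.
y = paddle.roll(x, shifts=[1, 1], axis=[-1, 1])
z = paddle.roll(x, shifts=2, axis=1)
assert bool(paddle.all(y == z))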