Unverified commit b6d0dac9, authored by Yuang Liu, committed by GitHub

Fix roll kernel gpu bug. (#52012)

Parent 5031b443
@@ -39,7 +39,7 @@ __global__ void RollCudaKernel(const T* input,
 #pragma unroll
   for (size_t i = 0; i < Rank; i++) {
-    new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
+    new_dim_idx = (output_idx / strides[i]) % sizes[i] + shifts[i];
     if (new_dim_idx >= sizes[i]) {
       output_idx += (shifts[i] - sizes[i]) * strides[i];
     } else {
......
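The one-line fix matters when the same axis appears more than once in `axis` (e.g. `axis=[-1, 1]` on a 2-D tensor, which the new Case3 tests exercise): the wrap-around check for a later shift must read the coordinate from the partially shifted `output_idx`, not from the untouched thread index `idx`, otherwise the second shift is tested against the pre-shift coordinate and `output_idx` can land out of range. Below is a minimal host-side sketch of the fixed index arithmetic, checked against `np.roll`; the helper name `roll_flat_index` is ours for illustration, not part of Paddle.

```python
import numpy as np

def roll_flat_index(idx, shifts, strides, sizes):
    """Mirror of the fixed kernel loop: every wrap test reads the
    coordinate from output_idx, so a repeated axis sees the effect
    of the shifts already applied to it."""
    output_idx = idx
    for shift, stride, size in zip(shifts, strides, sizes):
        new_dim_idx = (output_idx // stride) % size + shift
        if new_dim_idx >= size:
            output_idx += (shift - size) * stride  # wrap around
        else:
            output_idx += shift * stride
    return output_idx

x = np.arange(121).reshape(11, 11)
shifts, axes = [1, 1], [-1, 1]  # the same axis rolled twice
strides = [x.strides[a] // x.itemsize for a in axes]
sizes = [x.shape[a] for a in axes]

flat = x.reshape(-1)
out = np.empty_like(flat)
for idx in range(flat.size):
    out[roll_flat_index(idx, shifts, strides, sizes)] = flat[idx]

# NumPy accumulates shifts on a repeated axis, so this is roll-by-2 on axis 1.
assert (out.reshape(x.shape) == np.roll(x, shifts, axis=axes)).all()
```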
@@ -61,6 +61,14 @@ class TestRollOpCase2(TestRollOp):
         self.axis = [-1, -2]
 
 
+class TestRollOpCase3(TestRollOp):
+    def init_dtype_type(self):
+        self.dtype = np.float32
+        self.x_shape = (11, 11)
+        self.shifts = [1, 1]
+        self.axis = [-1, 1]
+
+
 class TestRollFP16OP(TestRollOp):
     def init_dtype_type(self):
         self.dtype = np.float16
@@ -77,6 +85,14 @@ class TestRollFP16OpCase2(TestRollOp):
         self.axis = [-1, -2]
 
 
+class TestRollFP16OpCase3(TestRollOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+        self.x_shape = (11, 11)
+        self.shifts = [1, 1]
+        self.axis = [-1, 1]
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
@@ -117,6 +133,26 @@ class TestRollBF16OpCase2(TestRollOp):
         self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=True)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestRollBF16OpCase3(TestRollOp):
+    def init_dtype_type(self):
+        self.dtype = np.uint16
+        self.x_shape = (11, 11)
+        self.shifts = [1, 1]
+        self.axis = [-1, 1]
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=True)
+
+
 class TestRollAPI(unittest.TestCase):
     def input_data(self):
         self.data_x = np.array(
......
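All three new Case3 tests (FP32, FP16, BF16) pin down the duplicated-axis path on an (11, 11) input. A quick end-to-end check against NumPy, assuming a CUDA build of Paddle that includes this fix:

```python
import numpy as np
import paddle

x_np = np.arange(121, dtype=np.float32).reshape(11, 11)
x = paddle.to_tensor(x_np)

# axis=[-1, 1] names the same dimension twice; before this fix the
# GPU kernel mis-wrapped the second shift on the repeated axis.
out = paddle.roll(x, shifts=[1, 1], axis=[-1, 1])
ref = np.roll(x_np, [1, 1], axis=[-1, 1])
assert np.array_equal(out.numpy(), ref)
```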