未验证 提交 5571c98f 编写于 作者: H Haohongxiang 提交者: GitHub

Fix roll_op by avoiding DivisionByZeroError, test=develop (#34499)

上级 ba19398e
......@@ -80,9 +80,11 @@ class RollKernel<platform::CUDADeviceContext, T>
int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
int64_t size = input_dim[dim];
shifts[i] = (shifts[i] % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
if (size != 0) {
shifts[i] = (shifts[i] % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}
......@@ -151,10 +153,11 @@ class RollGradKernel<platform::CUDADeviceContext, T>
for (size_t i = 0; i < nums; i++) {
int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
int64_t size = input_dim[dim];
shifts[i] = ((-shifts[i]) % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
if (size != 0) {
shifts[i] = ((-shifts[i]) % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}
......
......@@ -30,6 +30,9 @@ inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
if (dim < 0) {
dim += input_dim.size();
}
if (input_dim[dim] == 0) {
return;
}
shift = shift % input_dim[dim];
if (shift < 0) {
shift += input_dim[dim];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册