diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index 34d4d67e39d53442a7a8d177292427a933e518b7..136c5c0aca8b320a2f0e7d2a896b8e982c1086b4 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -80,9 +80,11 @@ class RollKernel int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); int64_t size = input_dim[dim]; - shifts[i] = (shifts[i] % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; + if (size != 0) { + shifts[i] = (shifts[i] % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } } } @@ -151,10 +153,11 @@ class RollGradKernel for (size_t i = 0; i < nums; i++) { int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); int64_t size = input_dim[dim]; - - shifts[i] = ((-shifts[i]) % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; + if (size != 0) { + shifts[i] = ((-shifts[i]) % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } } } diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h index da4f335ca7faa62504b6426bce37c63c4e0f17e3..e58ff521d8df77ec12fc28a839d7bdfe4e699595 100644 --- a/paddle/fluid/operators/roll_op.h +++ b/paddle/fluid/operators/roll_op.h @@ -30,6 +30,9 @@ inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim, if (dim < 0) { dim += input_dim.size(); } + if (input_dim[dim] == 0) { + return; + } shift = shift % input_dim[dim]; if (shift < 0) { shift += input_dim[dim];