From 5571c98f8b57121e12dd07a11ff043f9adbc4928 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Fri, 30 Jul 2021 02:37:40 -0500 Subject: [PATCH] Fix roll_op by avoiding DivisionByZeroError, test=develop (#34499) --- paddle/fluid/operators/roll_op.cu | 17 ++++++++++------- paddle/fluid/operators/roll_op.h | 3 +++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index 34d4d67e39d..136c5c0aca8 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -80,9 +80,11 @@ class RollKernel int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); int64_t size = input_dim[dim]; - shifts[i] = (shifts[i] % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; + if (size != 0) { + shifts[i] = (shifts[i] % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } } } @@ -151,10 +153,11 @@ class RollGradKernel for (size_t i = 0; i < nums; i++) { int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); int64_t size = input_dim[dim]; - - shifts[i] = ((-shifts[i]) % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; + if (size != 0) { + shifts[i] = ((-shifts[i]) % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } } } diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h index da4f335ca7f..e58ff521d8d 100644 --- a/paddle/fluid/operators/roll_op.h +++ b/paddle/fluid/operators/roll_op.h @@ -30,6 +30,9 @@ inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim, if (dim < 0) { dim += input_dim.size(); } + if (input_dim[dim] == 0) { + return; + } shift = shift % input_dim[dim]; if (shift < 0) { shift += input_dim[dim]; -- GitLab