From 1017180695720dc6cf223d408dd912a04eb6c19b Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Wed, 23 Jun 2021 14:01:12 +0800 Subject: [PATCH] Support Mod in elementwise system (#33052) --- .../elementwise/elementwise_mod_op.cu | 51 ++++++++++++++++++- .../elementwise/elementwise_mod_op.h | 1 - 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu index 92991ab3a0a..bb49fdbf12d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu @@ -12,13 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mod_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; namespace plat = paddle::platform; +namespace paddle { +namespace operators { + +template +struct CudaModFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + T res = args[0] % args[1]; + + // Accoding to #PR26732: in dividen % divsor + // remainder shall have the same sign as divsor. + if ((res != 0) && ((args[1] ^ res) < 0)) res += args[1]; + return res; + } +}; + +template +struct CudaModFunctor< + T, typename std::enable_if_t::value>> { + inline HOSTDEVICE T operator()(const T* args) const { + T res = fmod(args[0], args[1]); + + // Accoding to #PR26732: in dividen % divsor + // remainder shall have the same sign as divsor. + if ((res != 0) && ((res < 0) != (args[1] < 0))) res += args[1]; + return res; + } +}; + +template +class ElementwiseModKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaModFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_mod, ops::ElementwiseModKernel, ops::ElementwiseModKernel, - ops::ElementwiseModFPKernel, - ops::ElementwiseModFPKernel); + ops::ElementwiseModKernel, + ops::ElementwiseModKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.h b/paddle/fluid/operators/elementwise/elementwise_mod_op.h index 87e940e2ed6..03884f2a458 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.h @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { -- GitLab