From 419188d77f3645ed23b77d7c028bf03fbb518f69 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Sun, 20 Oct 2019 21:24:55 +0800 Subject: [PATCH] [cherry-pick]add assertions on whether elementwise_div divison is zero (#20713) --- .../elementwise/elementwise_div_op.cc | 9 ++--- .../elementwise/elementwise_op_function.cu.h | 36 +++++++++++++++++-- .../unittests/test_elementwise_div_op.py | 20 +++++++++++ 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 000055a4b17..507b5a4ed7a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -32,6 +32,8 @@ struct SameDimsElemwiseDiv< } }; +// use default div function for int32/int64 type because of divison zero +// checking. template struct SameDimsElemwiseDiv< platform::CPUDeviceContext, T, @@ -39,12 +41,7 @@ struct SameDimsElemwiseDiv< void operator()(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_y = framework::EigenVector::Flatten(*y); - auto eigen_z = framework::EigenVector::Flatten(*z); - auto &place = *ctx.template device_context() - .eigen_device(); - eigen_z.device(place) = eigen_x / eigen_y; + default_elementwise_div(ctx, x, y, z); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h index 263f6225548..d4b618860e6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h @@ -1,8 +1,11 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -12,9 +15,9 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" - #define PADDLE_CUDA_THREAD_SIZE 512 #ifdef PADDLE_WITH_CUDA @@ -29,11 +32,14 @@ limitations under the License. */ #define __h2div h2div #endif +#define DIV_ERROR_INFO \ + "InvalidArgumentError: Integer division by zero encountered in " \ + "divide.Please check.\n" namespace paddle { namespace operators { #define DEFINE_SIMPLE_BINARY_FUNCTOR(Func, expr) \ - template \ + template \ struct Func##Functor { \ inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ return a expr b; \ @@ -46,8 +52,18 @@ DEFINE_SIMPLE_BINARY_FUNCTOR(Mul, *) DEFINE_SIMPLE_BINARY_FUNCTOR(Div, /) #undef DEFINE_SIMPLE_BINARY_FUNCTOR +// special div functor for int32/int64. check divison has a zero +template +struct DivFunctor::value>::type> { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return a / b; + } +}; + #define DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Func, expr) \ - template \ + template \ struct Func##RangeFunctor { \ Func##RangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} \ inline HOSTDEVICE void operator()(size_t id) const { \ @@ -63,6 +79,20 @@ DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Mul, *) DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Div, /) #undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR +// special div functor for int32/int64. check divison has a zero +template +struct DivRangeFunctor< + T, typename std::enable_if::value>::type> { + DivRangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} + inline HOSTDEVICE void operator()(size_t id) const { + PADDLE_ENFORCE(y_[id] != 0, DIV_ERROR_INFO); + z_[id] = x_[id] / y_[id]; + } + const T* x_; + const T* y_; + T* z_; +}; + #ifdef PADDLE_CUDA_FP16 inline DEVICE half2 half2_add(const half2& a, const half2& b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index 4e679607d13..4046e31b416 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -151,6 +151,26 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp): self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])} +class TestElementwiseDivOp_INT(OpTest): + def setUp(self): + self.op_type = "elementwise_div" + self.dtype = np.int32 + self.init_dtype() + self.inputs = { + 'X': np.random.randint( + 1, 5, size=[2, 3]).astype(self.dtype), + 'Y': np.random.randint( + 1, 5, size=[2, 3]).astype(self.dtype) + } + self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} + + def test_check_output(self): + self.check_output() + + def init_dtype(self): + pass + + class TestElementwiseDivOpFp16(ElementwiseDivOp): def init_dtype(self): self.dtype = np.float16 -- GitLab