diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu index 92991ab3a0a24c0969a403c2e2e2d1b1cb950d2f..bb49fdbf12dfa36ae2127eccc1c189939bda9a2e 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu @@ -12,13 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mod_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; namespace plat = paddle::platform; +namespace paddle { +namespace operators { + +template +struct CudaModFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + T res = args[0] % args[1]; + + // Accoding to #PR26732: in dividen % divsor + // remainder shall have the same sign as divsor. + if ((res != 0) && ((args[1] ^ res) < 0)) res += args[1]; + return res; + } +}; + +template +struct CudaModFunctor< + T, typename std::enable_if_t::value>> { + inline HOSTDEVICE T operator()(const T* args) const { + T res = fmod(args[0], args[1]); + + // Accoding to #PR26732: in dividen % divsor + // remainder shall have the same sign as divsor. + if ((res != 0) && ((res < 0) != (args[1] < 0))) res += args[1]; + return res; + } +}; + +template +class ElementwiseModKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaModFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_mod, ops::ElementwiseModKernel, ops::ElementwiseModKernel, - ops::ElementwiseModFPKernel, - ops::ElementwiseModFPKernel); + ops::ElementwiseModKernel, + ops::ElementwiseModKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.h b/paddle/fluid/operators/elementwise/elementwise_mod_op.h index 87e940e2ed6319c4f2957cd846735adb210cd23d..03884f2a45883bbb55bf2b2655636bb003084147 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.h @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle {