diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.cc b/paddle/fluid/operators/elementwise/elementwise_mod_op.cc
index fadebc00cf451736f3dbc3a6c4d9d63397582f6f..451c7816b9af1832b8504a05aeb1e0f51c5001c8 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mod_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.cc
@@ -33,4 +33,6 @@ REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod, ops::ElementwiseOp,
 REGISTER_OP_CPU_KERNEL(
     elementwise_mod,
     ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu
index da3304a83952d448ffcad61f1878b06d354168b9..92991ab3a0a24c0969a403c2e2e2d1b1cb950d2f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu
@@ -19,4 +19,6 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mod,
     ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseModFPKernel<plat::CUDADeviceContext, float>,
+    ops::ElementwiseModFPKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.h b/paddle/fluid/operators/elementwise/elementwise_mod_op.h
index 5b139fd4b33152b4a340c6c5a0f094338bbdffc8..e568a5dc72c08a5673147af0fb9b38bdca3c9921 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mod_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.h
@@ -27,6 +27,11 @@ struct ModFunctor {
   inline HOSTDEVICE T operator()(T a, T b) const { return a % b; }
 };
 
+template <typename T>
+struct ModFunctorFP {
+  inline HOSTDEVICE T operator()(T a, T b) const { return std::fmod(a, b); }
+};
+
 template <typename DeviceContext, typename T>
 void elementwise_mod(const framework::ExecutionContext &ctx,
                      const framework::Tensor *x, const framework::Tensor *y,
@@ -36,6 +41,15 @@ void elementwise_mod(const framework::ExecutionContext &ctx,
                                                         ModFunctor<T>(), z);
 }
 
+template <typename DeviceContext, typename T>
+void elementwise_mod_fp(const framework::ExecutionContext &ctx,
+                        const framework::Tensor *x, const framework::Tensor *y,
+                        framework::Tensor *z) {
+  int axis = ctx.Attr<int>("axis");
+  ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          ModFunctorFP<T>(), z);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseModKernel : public framework::OpKernel<T> {
  public:
@@ -51,5 +65,20 @@ class ElementwiseModKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class ElementwiseModFPKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<framework::LoDTensor>("X");
+    auto *y = ctx.Input<framework::LoDTensor>("Y");
+    auto *z = ctx.Output<framework::LoDTensor>("Out");
+
+    z->mutable_data<T>(ctx.GetPlace());
+
+    // dtype of x and y is float or double
+    elementwise_mod_fp<DeviceContext, T>(ctx, x, y, z);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
index a354ba0177ae70ba4f3a1565360f96a55edd33b6..fcda179a093cf2306b9d79264c9118fe3b68b35c 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
@@ -27,7 +27,6 @@ class TestElementwiseModOp(OpTest):
 
     def setUp(self):
         self.op_type = "elementwise_mod"
-        self.dtype = np.int32
         self.axis = -1
         self.init_dtype()
         self.init_input_output()
@@ -50,7 +49,7 @@ class TestElementwiseModOp(OpTest):
         self.out = np.mod(self.x, self.y)
 
     def init_dtype(self):
-        pass
+        self.dtype = np.int32
 
     def init_axis(self):
         pass
@@ -65,5 +64,23 @@ class TestElementwiseModOp_scalar(TestElementwiseModOp):
         self.out = np.mod(self.x, self.y)
 
 
+class TestElementwiseModOpFloat(TestElementwiseModOp):
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_input_output(self):
+        self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype)
+        self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype)
+        self.out = np.fmod(self.x, self.y)
+
+    def test_check_output(self):
+        self.check_output(atol=2e-5)
+
+
+class TestElementwiseModOpDouble(TestElementwiseModOpFloat):
+    def init_dtype(self):
+        self.dtype = np.float64
+
+
 if __name__ == '__main__':
     unittest.main()