diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.h b/paddle/fluid/operators/elementwise/elementwise_mod_op.h index 47bd6af0b95ace2b9b753e38cfc5f191bc1bb942..87e940e2ed6319c4f2957cd846735adb210cd23d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.h @@ -31,6 +31,15 @@ struct ModFunctor { } }; +template +struct InverseModFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { + T res = b % a; + if ((res != 0) && ((res < 0) != (a < 0))) res += a; + return res; + } +}; + template struct ModFunctorFP { inline HOSTDEVICE T operator()(T a, T b) const { @@ -40,13 +49,29 @@ struct ModFunctorFP { } }; +template +struct InverseModFunctorFP { + inline HOSTDEVICE T operator()(T a, T b) const { + T res = fmod(b, a); + if ((res != 0) && ((a < 0) != (res < 0))) res += a; + return res; + } +}; + template void elementwise_mod(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - ModFunctor(), z); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + if (x_dims.size() >= y_dims.size()) { + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + ModFunctor(), z); + } else { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, InverseModFunctor(), z); + } } template @@ -54,8 +79,15 @@ void elementwise_mod_fp(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - ModFunctorFP(), z); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + if (x_dims.size() >= y_dims.size()) { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, ModFunctorFP(), z); + } else { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, InverseModFunctorFP(), z); + } } template diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py index f5d8b4f704da8acd97475444346522f63d3724fd..cab6160d761004877896deea8d44ca02c9de2e1e 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py @@ -220,6 +220,14 @@ class TestRemainderAPI(unittest.TestCase): z_expected = np.array([0, 1, 1, -1]) self.assertEqual(np.allclose(z_expected, z.numpy()), True) + np_x = np.array([-3, 3]) + np_y = np.array([[2, 3], [-2, -1]]) + x = paddle.to_tensor(np_x, dtype="int64") + y = paddle.to_tensor(np_y, dtype="int64") + z = x % y + z_expected = np.array([[1, 0], [-1, 0]]) + self.assertEqual(np.allclose(z_expected, z.numpy()), True) + if __name__ == '__main__': unittest.main()