diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc index 5a398fa50febe2efffd588ce8f3612f1f9cec0b6..457d9e79d7da171ef526d5cab0e59b021cb64f98 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc @@ -49,6 +49,8 @@ REGISTER_OP_WITHOUT_GRADIENT(elementwise_floordiv, ops::ElementwiseOp, REGISTER_OP_CPU_KERNEL( elementwise_floordiv, + ops::ElementwiseFloorDivKernel, + ops::ElementwiseFloorDivKernel, ops::ElementwiseFloorDivKernel, ops::ElementwiseFloorDivKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu index 60846d1e8fee1c7f68ac101f18355750c2c15a4d..f63d6f037632c1a6a05726b933b2258adc113ee3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu @@ -19,5 +19,7 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_floordiv, + ops::ElementwiseFloorDivKernel, + ops::ElementwiseFloorDivKernel, ops::ElementwiseFloorDivKernel, ops::ElementwiseFloorDivKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h index 2d24e394d5c823dbd22c837210e46cefeceba1be..a5909aad99a82529a0739cd28b1b72a146524f76 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" @@ -24,7 +25,16 @@ namespace operators { template struct FloorDivFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a / b; } + inline HOSTDEVICE T operator()(T a, T b) const { + return static_cast(floor(a / b)); + } +}; + +template +struct InverseFloorDivFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { + return static_cast(floor(b / a)); + } }; template @@ -32,8 +42,15 @@ void elementwise_floor_div(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, FloorDivFunctor(), z); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + if (x_dims.size() >= y_dims.size()) { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, FloorDivFunctor(), z); + } else { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, InverseFloorDivFunctor(), z); + } } template diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 4bc43b0d5fcc4aadeb679c4fc40b31e1589bd83b..bb55c6725e6a62f2cef393fd34b249c217be0c54 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -19,6 +19,7 @@ from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator from ..layers.layer_function_generator import OpProtoHolder from ..layers import common_methods from . import to_variable, no_grad +import paddle import numpy as np import six @@ -162,6 +163,26 @@ def monkey_patch_math_varbase(): def _scalar_div_(var, value): return _scalar_elementwise_op_(var, 1.0 / value, 0.0) + # TODO(shenliang03): currently, it supports divide, floor_divide, remainder + # for binary operator by using the api to achieve the type promotion + def _binary_method_creator_(op_type, reverse=False): + import paddle + + def __impl__(self, other_var): + import paddle + op = getattr(paddle, op_type) + if reverse: + return op(other_var, self) + else: + return op(self, other_var) + + __impl__.__doc__ = """ + + See paddle.{}""".format(op_type) + __impl__.__name__ = op_type + + return __impl__ + # for binary operator such as elementwise, compare def _binary_creator_(method_name, op_type, @@ -260,22 +281,20 @@ def monkey_patch_math_varbase(): ## a*b == b*a. Do not need to reverse explicitly ('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)), - ('__div__', _binary_creator_('__div__', 'elementwise_div', False, - _scalar_div_)), - ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', - False, _scalar_div_)), - ('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, - None)), ('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, None)), ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)), ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)), - ('__floordiv__', _binary_creator_('__floordiv__', - 'elementwise_floordiv', False, None)), - ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, - None)), + # These binary use paddle.optype + ('__div__', _binary_method_creator_('divide', False)), + ('__truediv__', _binary_method_creator_('divide', False)), + ('__rtruediv__', _binary_method_creator_('divide', True)), + ('__rdiv__', _binary_method_creator_('divide', True)), + ('__floordiv__', _binary_method_creator_('floor_divide', False)), + ('__rfloordiv__', _binary_method_creator_('floor_divide', True)), + ('__mod__', _binary_method_creator_('remainder', False)), ## for logical compare ('__eq__', _binary_creator_('__eq__', 'equal', False, None)), ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)), diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 4595f0cf93916d71a3d0ec582af1917500d68f12..38fc34472c8bc64338e2468bdf3f4b0bab1370ce 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -16,6 +16,7 @@ from __future__ import print_function import warnings import inspect +import paddle from .. import core from ..framework import Variable, unique_name @@ -45,6 +46,7 @@ EXPRESSION_MAP = { "__pow__": "A ** B", "__rpow__": "A **= B", "__floordiv__": "A //B", + "__rfloordiv__": "A //= B", "__mod__": "A % B", "__eq__": "A == B", "__ne__": "A != B", @@ -233,6 +235,25 @@ def monkey_patch_variable(): def _scalar_div_(var, value): return _scalar_op_(var, 1.0 / value, 0.0) + # TODO(shenliang03): currently, it supports divide, floor_divide, remainder + # for binary operator by using the api to achieve the type promotion + def _binary_method_creator_(op_type, reverse=False): + import paddle + + def __impl__(self, other_var): + op = getattr(paddle, op_type) + if reverse: + return op(other_var, self) + else: + return op(self, other_var) + + __impl__.__doc__ = """ + + See paddle.{}""".format(op_type) + __impl__.__name__ = op_type + + return __impl__ + def _binary_creator_(method_name, op_type, reverse=False, @@ -339,22 +360,18 @@ def monkey_patch_variable(): # a*b == b*a. Do not need to reverse explicitly ('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)), - ('__div__', _binary_creator_('__div__', 'elementwise_div', False, - _scalar_div_)), - ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', - False, _scalar_div_)), - ('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, - None)), - ('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div', - True, None)), ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)), ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)), - ('__floordiv__', _binary_creator_('__floordiv__', - 'elementwise_floordiv', False, None)), - ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, - None)), + # These binary use paddle.optype + ('__div__', _binary_method_creator_('divide', False)), + ('__rdiv__', _binary_method_creator_('divide', True)), + ('__truediv__', _binary_method_creator_('divide', False)), + ('__rtruediv__', _binary_method_creator_('divide', True)), + ('__floordiv__', _binary_method_creator_('floor_divide', False)), + ('__rfloordiv__', _binary_method_creator_('floor_divide', True)), + ('__mod__', _binary_method_creator_('remainder', False)), # for logical compare ('__eq__', _binary_creator_('__eq__', 'equal', False, None)), ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)), diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py index 761d57408b9a8f9e52419331bfb0bca5b0135c30..1062123948481a4164a12a4bed818b964923006f 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py @@ -113,8 +113,8 @@ class TranspilerAsyncLRDecayTest(unittest.TestCase): ["listen_and_serv"]) # block1: sum,cast,scale,floor,fill_constant,elementwise_pow,scale self.assertEqual([op.type for op in pserver.blocks[1].ops], [ - "sum", "cast", "scale", "floor", "fill_constant", "elementwise_pow", - "scale" + "sum", "cast", "fill_constant", "elementwise_div", "floor", + "fill_constant", "elementwise_pow", "scale" ]) # block1~2: optimize pass diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index 3cfbac8b613c125956861f73b1bab24c34e05572..9ebaf8ff9438be8c8a57815be0798b861d05caaf 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -240,25 +240,124 @@ class TestElementwiseDivBroadcast(unittest.TestCase): self.assertEqual((out_result == (2 / x)).all(), True) -class TestDivideOp(unittest.TestCase): - def test_name(self): - with fluid.program_guard(fluid.Program()): - x = fluid.data(name="x", shape=[2, 3], dtype="float32") - y = fluid.data(name='y', shape=[2, 3], dtype='float32') +class TestDivideAPI(unittest.TestCase): + def setUp(self): + paddle.set_default_dtype("float64") + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + # rule 1 + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = np.array([1, 2, 3]) + self.assertRaises(TypeError, paddle.divide, x=x, y=y) + + # rule 2: both the inputs are not Tensor + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = 2 + y = 4 + res = paddle.divide(x, y) + exe = fluid.Executor(place) + np_z = exe.run(fluid.default_main_program(), + feed={}, + fetch_list=[res]) + self.assertEqual(np_z[0] == 0.5, True) + + # rule 3: + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = fluid.data(name="y", shape=[3], dtype="float32") + self.assertRaises(TypeError, paddle.divide, x=x, y=y) + + # rule 4: x is Tensor, y is scalar + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = 2 + exe = fluid.Executor(place) + res = x / y + np_z = exe.run(fluid.default_main_program(), + feed={"x": np.array([2, 3, 4]).astype('float64')}, + fetch_list=[res]) + z_expected = np.array([1., 1.5, 2.]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + # rule 5: y is Tensor, x is scalar + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = 2 + exe = fluid.Executor(place) + res = y / x + np_z = exe.run(fluid.default_main_program(), + feed={"x": np.array([2, 8, 4]).astype('float64')}, + fetch_list=[res]) + z_expected = np.array([1., 0.25, 0.5]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + # rule 6: y is Tensor, x is Tensor + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = fluid.data(name="y", shape=[3], dtype="float64") + exe = fluid.Executor(place) + res = x / y + np_z = exe.run(fluid.default_main_program(), + feed={ + "x": np.array([2, 3, 4]).astype('float64'), + "y": np.array([1, 5, 2]).astype('float64') + }, + fetch_list=[res]) + z_expected = np.array([2., 0.6, 2.]) + self.assertEqual((np_z[0] == z_expected).all(), True) - y_1 = paddle.divide(x, y, name='div_res') - self.assertEqual(('div_res' in y_1.name), True) + def test_static(self): + for place in self.places: + self.check_static_result(place=place) def test_dygraph(self): - with fluid.dygraph.guard(): - np_x = np.array([2, 3, 4]).astype('float64') - np_y = np.array([1, 5, 2]).astype('float64') - x = paddle.to_tensor(np_x) - y = paddle.to_tensor(np_y) - z = paddle.divide(x, y) - np_z = z.numpy() - z_expected = np.array([2., 0.6, 2.]) - self.assertEqual((np_z == z_expected).all(), True) + for place in self.places: + with fluid.dygraph.guard(place): + # rule 1 : avoid numpy.ndarray + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x) + self.assertRaises(TypeError, paddle.divide, x=x, y=np_y) + + # rule 2: both the inputs are not Tensor + z = paddle.divide(3, 2) + self.assertEqual(z.numpy()[0] == 1.5, True) + + # rule 3: both the inputs are Tensor + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x, dtype="float32") + y = paddle.to_tensor(np_y, dtype="float64") + self.assertRaises(TypeError, paddle.divide, x=x, y=y) + + # rule 4: x is Tensor, y is scalar + np_x = np.array([2, 3, 4]) + x = paddle.to_tensor(np_x, dtype="int32") + y = 2 + z = x / y + z_expected = np.array([1., 1.5, 2.]) + self.assertEqual((z_expected == z.numpy()).all(), True) + + # rule 5: y is Tensor, x is scalar + np_x = np.array([2, 1, 4]) + x = paddle.to_tensor(np_x, dtype="int32") + y = 2 + z = y / x + z_expected = np.array([1., 2., 0.5]) + self.assertEqual((z_expected == z.numpy()).all(), True) + + # rule 6: y is Tensor, x is Tensor + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x) + y = paddle.to_tensor(np_y) + z = x / y + z_expected = np.array([2., 0.6, 2.]) + self.assertEqual((z_expected == z.numpy()).all(), True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py index 0c4e19c9816eb03812af252a78c3f2f87cfc153b..4fe085ce854726676bc1b1bef650419b3ebbfc86 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py @@ -58,6 +58,13 @@ class TestElementwiseModOp(OpTest): pass +class TestElementwiseModOpInverse(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, [10]).astype(self.dtype) + self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype) + self.out = np.floor_divide(self.x, self.y) + + class TestElementwiseModOp_scalar(TestElementwiseModOp): def init_input_output(self): scale_x = random.randint(0, 100000000) @@ -67,25 +74,124 @@ class TestElementwiseModOp_scalar(TestElementwiseModOp): self.out = np.floor_divide(self.x, self.y) -class TestFloorDivideOp(unittest.TestCase): - def test_name(self): - with fluid.program_guard(fluid.Program()): - x = fluid.data(name="x", shape=[2, 3], dtype="int64") - y = fluid.data(name='y', shape=[2, 3], dtype='int64') - - y_1 = paddle.floor_divide(x, y, name='div_res') - self.assertEqual(('div_res' in y_1.name), True) +class TestFloorDivideAPI(unittest.TestCase): + def setUp(self): + paddle.set_default_dtype("float64") + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + # rule 1 + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = np.array([1, 2, 3]) + self.assertRaises(TypeError, paddle.floor_divide, x=x, y=y) + + # rule 2: both the inputs are not Tensor + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = 2 + y = 4 + res = paddle.floor_divide(x, y) + exe = fluid.Executor(place) + np_z = exe.run(fluid.default_main_program(), + feed={}, + fetch_list=[res]) + self.assertEqual(np_z[0] == 0., True) + + # rule 3: + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = fluid.data(name="y", shape=[3], dtype="float32") + self.assertRaises(TypeError, paddle.floor_divide, x=x, y=y) + + # rule 4: x is Tensor, y is scalar + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = 2 + exe = fluid.Executor(place) + res = x // y + np_z = exe.run(fluid.default_main_program(), + feed={"x": np.array([2, 3, 4]).astype('float64')}, + fetch_list=[res]) + z_expected = np.array([1., 1., 2.]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + # rule 5: y is Tensor, x is scalar + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = 2 + exe = fluid.Executor(place) + res = y // x + np_z = exe.run(fluid.default_main_program(), + feed={"x": np.array([2, 8, 4]).astype('float64')}, + fetch_list=[res]) + z_expected = np.array([1., 0., 0.]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + # rule 6: y is Tensor, x is Tensor + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = fluid.data(name="y", shape=[3], dtype="float64") + exe = fluid.Executor(place) + res = x // y + np_z = exe.run(fluid.default_main_program(), + feed={ + "x": np.array([2, 3, 4]).astype('float64'), + "y": np.array([1, 5, 2]).astype('float64') + }, + fetch_list=[res]) + z_expected = np.array([2., 0., 2.]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) def test_dygraph(self): - with fluid.dygraph.guard(): - np_x = np.array([2, 3, 8, 7]).astype('int64') - np_y = np.array([1, 5, 3, 3]).astype('int64') - x = paddle.to_tensor(np_x) - y = paddle.to_tensor(np_y) - z = paddle.floor_divide(x, y) - np_z = z.numpy() - z_expected = np.array([2, 0, 2, 2]) - self.assertEqual((np_z == z_expected).all(), True) + for place in self.places: + with fluid.dygraph.guard(place): + # rule 1 : avoid numpy.ndarray + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x) + self.assertRaises(TypeError, paddle.floor_divide, x=x, y=np_y) + + # rule 2: both the inputs are not Tensor + z = paddle.floor_divide(3, 2) + self.assertEqual(z.numpy()[0] == 1., True) + + # rule 3: both the inputs are Tensor + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x, dtype="float32") + y = paddle.to_tensor(np_y, dtype="float64") + self.assertRaises(TypeError, paddle.floor_divide, x=x, y=y) + + # rule 4: x is Tensor, y is scalar + np_x = np.array([2, 3, 4]) + x = paddle.to_tensor(np_x, dtype="int32") + y = 2 + z = x // y + z_expected = np.array([1, 1, 2]) + self.assertEqual((z_expected == z.numpy()).all(), True) + + # rule 5: y is Tensor, x is scalar + np_x = np.array([2, 1, 4]) + x = paddle.to_tensor(np_x, dtype="int32") + y = 2 + z = y // x + z_expected = np.array([1, 2, 0]) + self.assertEqual((z_expected == z.numpy()).all(), True) + + # rule 6: y is Tensor, x is Tensor + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x) + y = paddle.to_tensor(np_y) + z = x // y + z_expected = np.array([2., 0., 2.]) + self.assertEqual((z_expected == z.numpy()).all(), True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py index 601076a1f49ab5603757496d858428847fedd75f..25769a42aa261c0b5ae9fe2795a337c668580a99 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py @@ -84,25 +84,125 @@ class TestElementwiseModOpDouble(TestElementwiseModOpFloat): self.dtype = np.float64 -class TestRemainderOp(unittest.TestCase): - def test_name(self): - with fluid.program_guard(fluid.Program()): - x = fluid.data(name="x", shape=[2, 3], dtype="int64") - y = fluid.data(name='y', shape=[2, 3], dtype='int64') - - y_1 = paddle.remainder(x, y, name='div_res') - self.assertEqual(('div_res' in y_1.name), True) +class TestRemainderAPI(unittest.TestCase): + def setUp(self): + paddle.set_default_dtype("float64") + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + # rule 1 + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = np.array([1, 2, 3]) + self.assertRaises(TypeError, paddle.remainder, x=x, y=y) + + # rule 3: + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = fluid.data(name="y", shape=[3], dtype="float32") + self.assertRaises(TypeError, paddle.remainder, x=x, y=y) + + # rule 4: x is Tensor, y is scalar + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = 2 + exe = fluid.Executor(place) + res = x % y + np_z = exe.run(fluid.default_main_program(), + feed={"x": np.array([2, 3, 4]).astype('float64')}, + fetch_list=[res]) + z_expected = np.array([0., 1., 0.]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + # rule 5: y is Tensor, x is scalar + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = 3 + y = fluid.data(name="y", shape=[3], dtype="float32") + self.assertRaises(TypeError, paddle.remainder, x=x, y=y) + + # rule 6: y is Tensor, x is Tensor + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3], dtype="float64") + y = fluid.data(name="y", shape=[1], dtype="float64") + exe = fluid.Executor(place) + res = x % y + np_z = exe.run(fluid.default_main_program(), + feed={ + "x": np.array([1., 2., 4]).astype('float64'), + "y": np.array([1.5]).astype('float64') + }, + fetch_list=[res]) + z_expected = np.array([1., 0.5, 1.0]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + # rule 6: y is Tensor, x is Tensor + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[6], dtype="float64") + y = fluid.data(name="y", shape=[1], dtype="float64") + exe = fluid.Executor(place) + res = x % y + np_z = exe.run( + fluid.default_main_program(), + feed={ + "x": np.array([-3., -2, -1, 1, 2, 3]).astype('float64'), + "y": np.array([2]).astype('float64') + }, + fetch_list=[res]) + z_expected = np.array([1., 0., 1., 1., 0., 1.]) + self.assertEqual((np_z[0] == z_expected).all(), True) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) def test_dygraph(self): - with fluid.dygraph.guard(): - np_x = np.array([2, 3, 8, 7]).astype('int64') - np_y = np.array([1, 5, 3, 3]).astype('int64') - x = paddle.to_tensor(np_x) - y = paddle.to_tensor(np_y) - z = paddle.remainder(x, y) - np_z = z.numpy() - z_expected = np.array([0, 3, 2, 1]) - self.assertEqual((np_z == z_expected).all(), True) + for place in self.places: + with fluid.dygraph.guard(place): + # rule 1 : avoid numpy.ndarray + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x) + self.assertRaises(TypeError, paddle.remainder, x=x, y=np_y) + + # rule 3: both the inputs are Tensor + np_x = np.array([2, 3, 4]) + np_y = np.array([1, 5, 2]) + x = paddle.to_tensor(np_x, dtype="float32") + y = paddle.to_tensor(np_y, dtype="float64") + self.assertRaises(TypeError, paddle.remainder, x=x, y=y) + + # rule 4: x is Tensor, y is scalar + np_x = np.array([2, 3, 4]) + x = paddle.to_tensor(np_x, dtype="int32") + y = 2 + z = x % y + z_expected = np.array([0, 1, 0]) + self.assertEqual((z_expected == z.numpy()).all(), True) + + # rule 5: y is Tensor, x is scalar + np_x = np.array([2, 3, 4]) + x = paddle.to_tensor(np_x) + self.assertRaises(TypeError, paddle.remainder, x=3, y=x) + + # rule 6: y is Tensor, x is Tensor + np_x = np.array([1., 2., 4]) + np_y = np.array([1.5]) + x = paddle.to_tensor(np_x) + y = paddle.to_tensor(np_y) + z = x % y + z_expected = np.array([1., 0.5, 1.0]) + self.assertEqual((z_expected == z.numpy()).all(), True) + + # rule 6: y is Tensor, x is Tensor + np_x = np.array([-3., -2, -1, 1, 2, 3]) + np_y = np.array([2.]) + x = paddle.to_tensor(np_x) + y = paddle.to_tensor(np_y) + z = x % y + z_expected = np.array([1., 0., 1., 1., 0., 1.]) + self.assertEqual((z_expected == z.numpy()).all(), True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index f6eff22d6ce5f06d8853d6244f79b4b07b3fa4f5..00137f63e244a0e166047e89f9ef436da158ed16 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -189,15 +189,15 @@ class TestMathOpPatches(unittest.TestCase): @prog_scope() def test_integer_div(self): a = fluid.layers.data(name="a", shape=[1], dtype='int64') - b = a / 7 + b = a / 2 place = fluid.CPUPlace() exe = fluid.Executor(place) - a_np = numpy.array([3, 4, 10, 14, 9, 18]).astype('int64') + a_np = numpy.array([3, 4, 10, 14, 9, 18]) b_np, = exe.run(fluid.default_main_program(), feed={"a": a_np}, fetch_list=[b]) - - b_np_actual = (a_np / 7).astype('int64') + # for paddle2.0, use true_divide + b_np_actual = (a_np / 2.0) self.assertTrue(numpy.array_equal(b_np, b_np_actual)) @prog_scope() diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index 6ca194b2694b6c7537ceb94e11eb1a1a0aeb8d8d..7e2ef36c1a7fda5c31049ec9c752c5226bfb89dc 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -248,7 +248,8 @@ class PolicyGradient(object): func=reward_func, x=[action, length], out=reward) neg_log_prob = layers.cross_entropy(act_prob, action) cost = neg_log_prob * reward - cost = (layers.reduce_sum(cost) / layers.reduce_sum(length) + cost = (layers.reduce_sum(cost) / + layers.cast(layers.reduce_sum(length), "float32") ) if length is not None else layers.reduce_mean(cost) optimizer = fluid.optimizer.Adam(self.lr) optimizer.minimize(cost) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 68d0f393455f70bf1a7bf897becc97e51617681c..9069630d8689ad504afac370548de2d50310d281 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1034,7 +1034,8 @@ def ctc_loss(log_probs, loss_out = fluid.layers.squeeze(loss_out, [-1]) assert reduction in ['mean', 'sum', 'none'] if reduction == 'mean': - loss_out = paddle.mean(loss_out / label_lengths) + loss_out = paddle.mean(loss_out / paddle.cast(label_lengths, + loss_out.dtype)) elif reduction == 'sum': loss_out = paddle.sum(loss_out) return loss_out diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 77639e8da466bcfb88f81d2d905a66d374a6d6c1..3b038536a10a77f8ac16c69f63f423b0dd61f35a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -62,6 +62,7 @@ from ..fluid.layers import increment #DEFINE_ALIAS from ..fluid.layers import multiplex #DEFINE_ALIAS from ..fluid.layers import sums #DEFINE_ALIAS from ..fluid import layers +import paddle __all__ = [ 'abs', @@ -133,6 +134,19 @@ __all__ = [ ] # yapf: enable. +_supported_int_dtype_ = [ + VarDesc.VarType.UINT8, + VarDesc.VarType.INT8, + VarDesc.VarType.INT16, + VarDesc.VarType.INT32, + VarDesc.VarType.INT64, +] + +_supported_float_dtype_ = [ + VarDesc.VarType.FP32, + VarDesc.VarType.FP64, +] + @templatedoc() def pow(input, exponent, name=None): """ @@ -308,9 +322,69 @@ def divide(x, y, name=None): axis = -1 act = None if in_dygraph_mode(): + # rule 1 : avoid numpy.ndarray + if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): + raise TypeError("divide(): arguments must be Tensor or scalar, not numpy.ndarray.") + + # rule 2: both the inputs are not Tensor + elif not isinstance(x, paddle.Tensor) and not isinstance(y, paddle.Tensor): + x = paddle.full(shape=[1], dtype=paddle.get_default_dtype(), fill_value=x) + y = paddle.full(shape=[1], dtype=paddle.get_default_dtype(), fill_value=y) + + # rule 3: both the inputs are Tensor + elif isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + if y.dtype != x.dtype: + raise TypeError("divide(): argument position 1 and argument position 2 must have the same dtype." + "But x is {}, y is {}".format(x.dtype, y.dtype)) + elif x.dtype in _supported_int_dtype_: + x = x.astype(paddle.get_default_dtype()) + y = y.astype(paddle.get_default_dtype()) + + # rule 4: x is Tensor, y is scalar + elif isinstance(x, paddle.Tensor) and not isinstance(y, paddle.Tensor): + if x.dtype in _supported_int_dtype_: + x = x.astype(paddle.get_default_dtype()) + y = paddle.full(shape=[1], dtype=x.dtype, fill_value=y) + + # rule 5: x is scalar, y is Tensor + elif not isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + if y.dtype in _supported_int_dtype_: + y = y.astype(paddle.get_default_dtype()) + x = paddle.full(shape=[1], dtype=y.dtype, fill_value=x) + return _elementwise_op_in_dygraph( x, y, axis=axis, act=act, op_name=op_type) + # rule 1 : avoid numpy.ndarray + if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): + raise TypeError("divide(): arguments must be Tensor or scalar, not numpy.ndarray.") + + # rule 2: both the inputs are not Tensor + elif not isinstance(x, Variable) and not isinstance(y, Variable): + x = paddle.fill_constant(shape=[1], dtype=paddle.get_default_dtype(), value=x) + y = paddle.fill_constant(shape=[1], dtype=paddle.get_default_dtype(), value=y) + + # rule 3: both the inputs are Tensor + elif isinstance(x, Variable) and isinstance(y, Variable): + if y.dtype != x.dtype: + raise TypeError("divide(): argument position 1 and argument position 2 must have the same dtype." + "But x is {}, y is {}".format(x.dtype, y.dtype)) + elif x.dtype in _supported_int_dtype_: + x = paddle.cast(x, paddle.get_default_dtype()) + y = paddle.cast(y, paddle.get_default_dtype()) + + # rule 4: x is Tensor, y is scalar + elif isinstance(x, Variable) and not isinstance(y, Variable): + if x.dtype in _supported_int_dtype_: + x = paddle.cast(x, paddle.get_default_dtype()) + y = paddle.fill_constant(shape=[1], dtype=x.dtype, value=y) + + # rule 5: x is scalar, y is Tensor + elif not isinstance(x, Variable) and isinstance(y, Variable): + if y.dtype in _supported_int_dtype_: + y = paddle.cast(y, paddle.get_default_dtype()) + x = paddle.fill_constant(shape=[1], dtype=y.dtype, value=x) + return _elementwise_op(LayerHelper(op_type, **locals())) @@ -352,9 +426,55 @@ def floor_divide(x, y, name=None): op_type = 'elementwise_floordiv' axis = -1 if in_dygraph_mode(): + # rule 1 : avoid numpy.ndarray + if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): + raise TypeError("floor_divide(): arguments must be Tensor or scalar, not numpy.ndarray.") + + # rule 2: both the inputs are not Tensor + elif not isinstance(x, paddle.Tensor) and not isinstance(y, paddle.Tensor): + x = paddle.full(shape=[1], dtype=paddle.get_default_dtype(), fill_value=x) + y = paddle.full(shape=[1], dtype=paddle.get_default_dtype(), fill_value=y) + + # rule 3: both the inputs are Tensor + elif isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + if y.dtype != x.dtype: + raise TypeError("floor_divide(): argument position 1 and argument position 2 must have the same dtype." + "But x is {}, y is {}".format(x.dtype, y.dtype)) + + # rule 4: x is Tensor, y is scalar + elif isinstance(x, paddle.Tensor) and not isinstance(y, paddle.Tensor): + y = paddle.full(shape=[1], dtype=x.dtype, fill_value=y) + + # rule 5: x is scalar, y is Tensor + elif not isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + x = paddle.full(shape=[1], dtype=y.dtype, fill_value=x) + return _elementwise_op_in_dygraph( x, y, axis=axis, op_name=op_type) + # rule 1 : avoid numpy.ndarray + if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): + raise TypeError("divide(): arguments must be Tensor or scalar, not numpy.ndarray.") + + # rule 2: both the inputs are not Tensor + elif not isinstance(x, Variable) and not isinstance(y, Variable): + x = paddle.fill_constant(shape=[1], dtype=paddle.get_default_dtype(), value=x) + y = paddle.fill_constant(shape=[1], dtype=paddle.get_default_dtype(), value=y) + + # rule 3: both the inputs are Tensor + elif isinstance(x, Variable) and isinstance(y, Variable): + if y.dtype != x.dtype: + raise TypeError("divide(): argument position 1 and argument position 2 must have the same dtype." + "But x is {}, y is {}".format(x.dtype, y.dtype)) + + # rule 4: x is Tensor, y is scalar + elif isinstance(x, Variable) and not isinstance(y, Variable): + y = paddle.fill_constant(shape=[1], dtype=x.dtype, value=y) + + # rule 5: x is scalar, y is Tensor + elif not isinstance(x, Variable) and isinstance(y, Variable): + x = paddle.fill_constant(shape=[1], dtype=y.dtype, value=x) + return _elementwise_op(LayerHelper(op_type, **locals())) @@ -396,9 +516,43 @@ def remainder(x, y, name=None): op_type = 'elementwise_mod' axis = -1 if in_dygraph_mode(): + # rule 1 : avoid numpy.ndarray + if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): + raise TypeError("remainder(): arguments must be Tensor or scalar, not numpy.ndarray.") + + elif not isinstance(x, paddle.Tensor): + raise TypeError("remainder(): arguments position 1 must be Tensor, not {}".format(type(x))) + + # rule 3: both the inputs are Tensor + elif isinstance(y, paddle.Tensor): + if y.dtype != x.dtype: + raise TypeError("remainder(): argument position 1 and argument position 2 must have the same dtype." + "But x is {}, y is {}".format(x.dtype, y.dtype)) + + # rule 4: x is Tensor, y is scalar + elif not isinstance(y, paddle.Tensor): + y = paddle.full(shape=[1], dtype=x.dtype, fill_value=y) + return _elementwise_op_in_dygraph( x, y, axis=axis, op_name=op_type) + # rule 1 : avoid numpy.ndarray + if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): + raise TypeError("remainder(): arguments must be Tensor or scalar, not numpy.ndarray.") + + elif not isinstance(x, Variable): + raise TypeError("remainder(): arguments position 1 must be Tensor, not {}".format(type(x))) + + # rule 3: both the inputs are Tensor + elif isinstance(y, Variable): + if y.dtype != x.dtype: + raise TypeError("remainder(): argument position 1 and argument position 2 must have the same dtype." + "But x is {}, y is {}".format(x.dtype, y.dtype)) + + # rule 4: x is Tensor, y is scalar + elif not isinstance(y, paddle.Tensor): + y = paddle.fill_constant(shape=[1], dtype=x.dtype, value=y) + return _elementwise_op(LayerHelper(op_type, **locals()))