diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py
index b15e6388884c7b194654f3c3b71b8c111f76e565..d2c779a85497917179736777dac25efa7cfba228 100644
--- a/python/paddle/fluid/dygraph/math_op_patch.py
+++ b/python/paddle/fluid/dygraph/math_op_patch.py
@@ -37,9 +37,6 @@ def monkey_patch_math_varbase():
     The difference is, in dygraph mode, use auto-generated op functions for better
     performance.
     """
-    def safe_get_dtype(var):
-        return var.dtype
-
     @no_grad
     def create_tensor(value, dtype, shape):
         out = _varbase_creator(dtype=dtype)
@@ -96,8 +93,9 @@ def monkey_patch_math_varbase():
               print("new var's dtype is: {}, numpy dtype is {}".format(new_variable.dtype, new_variable.numpy().dtype))

        """
-        return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype',
-                             convert_np_dtype_to_dtype_(dtype))
+        if not isinstance(dtype, core.VarDesc.VarType):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+        return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', dtype)

     def _scalar_elementwise_op_(var, scale, bias):
         return core.ops.scale(var, 'scale', scale, 'bias', bias)
@@ -175,7 +173,7 @@ def monkey_patch_math_varbase():
         elif isinstance(other_var, int):
             return scalar_method(self, float(other_var))

-        lhs_dtype = safe_get_dtype(self)
+        lhs_dtype = self.dtype

         if not isinstance(other_var, core.VarBase):
             if reverse:
@@ -185,7 +183,7 @@ def monkey_patch_math_varbase():
             # add fill_op
             other_var = create_scalar(value=other_var, dtype=lhs_dtype)

-        rhs_dtype = safe_get_dtype(other_var)
+        rhs_dtype = other_var.dtype
         if lhs_dtype != rhs_dtype:
             other_var = astype(other_var, lhs_dtype)
         if reverse:
diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
index 34f14b759526026f66930a4ba5322b5363fbb50f..14c3fb7c8bf06908bd8eabffbe46887a6546f6d2 100644
--- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
+++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
@@ -15,7 +15,6 @@
 from __future__ import print_function

 import unittest
-from decorator_helper import prog_scope
 import paddle.fluid as fluid
 import numpy as np
 import six
@@ -23,7 +22,7 @@ import six

 class TestMathOpPatchesVarBase(unittest.TestCase):
     def setUp(self):
-        self.shape = [10, 10]
+        self.shape = [10, 1024]
         self.dtype = np.float32

     def test_add(self):
@@ -251,6 +250,29 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
                 rtol=1e-05,
                 atol=0.0))

+    def test_add_different_dtype(self):
+        a_np = np.random.random(self.shape).astype(np.float32)
+        b_np = np.random.random(self.shape).astype(np.float16)
+        with fluid.dygraph.guard():
+            a = fluid.dygraph.to_variable(a_np)
+            b = fluid.dygraph.to_variable(b_np)
+            res = a + b
+            self.assertTrue(np.array_equal(res.numpy(), a_np + b_np))
+
+    def test_astype(self):
+        a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        with fluid.dygraph.guard():
+            a = fluid.dygraph.to_variable(a_np)
+            res1 = a.astype(np.float16)
+            res2 = a.astype('float16')
+            res3 = a.astype(fluid.core.VarDesc.VarType.FP16)
+
+            self.assertEqual(res1.dtype, res2.dtype)
+            self.assertEqual(res1.dtype, res3.dtype)
+
+            self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
+            self.assertTrue(np.array_equal(res1.numpy(), res3.numpy()))

 if __name__ == '__main__':
     unittest.main()