未验证 提交 92e256a8 编写于 作者: L Leo Chen 提交者: GitHub

Fix bug of different dtype in dygraph math_op_patch, test=develop (#24740)

* Fix bug of different dtype in dygraph math_op_patch, test=develop

* support np.dtype and str, test=develop

* add unit test, test=develop
上级 181b1f5a
...@@ -37,9 +37,6 @@ def monkey_patch_math_varbase(): ...@@ -37,9 +37,6 @@ def monkey_patch_math_varbase():
The difference is, in dygraph mode, use auto-generated op functions for better performance. The difference is, in dygraph mode, use auto-generated op functions for better performance.
""" """
def safe_get_dtype(var):
return var.dtype
@no_grad @no_grad
def create_tensor(value, dtype, shape): def create_tensor(value, dtype, shape):
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
...@@ -96,8 +93,9 @@ def monkey_patch_math_varbase(): ...@@ -96,8 +93,9 @@ def monkey_patch_math_varbase():
print("new var's dtype is: {}, numpy dtype is {}".format(new_variable.dtype, new_variable.numpy().dtype)) print("new var's dtype is: {}, numpy dtype is {}".format(new_variable.dtype, new_variable.numpy().dtype))
""" """
return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', if not isinstance(dtype, core.VarDesc.VarType):
convert_np_dtype_to_dtype_(dtype)) dtype = convert_np_dtype_to_dtype_(dtype)
return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', dtype)
def _scalar_elementwise_op_(var, scale, bias): def _scalar_elementwise_op_(var, scale, bias):
return core.ops.scale(var, 'scale', scale, 'bias', bias) return core.ops.scale(var, 'scale', scale, 'bias', bias)
...@@ -175,7 +173,7 @@ def monkey_patch_math_varbase(): ...@@ -175,7 +173,7 @@ def monkey_patch_math_varbase():
elif isinstance(other_var, int): elif isinstance(other_var, int):
return scalar_method(self, float(other_var)) return scalar_method(self, float(other_var))
lhs_dtype = safe_get_dtype(self) lhs_dtype = self.dtype
if not isinstance(other_var, core.VarBase): if not isinstance(other_var, core.VarBase):
if reverse: if reverse:
...@@ -185,7 +183,7 @@ def monkey_patch_math_varbase(): ...@@ -185,7 +183,7 @@ def monkey_patch_math_varbase():
# add fill_op # add fill_op
other_var = create_scalar(value=other_var, dtype=lhs_dtype) other_var = create_scalar(value=other_var, dtype=lhs_dtype)
rhs_dtype = safe_get_dtype(other_var) rhs_dtype = other_var.dtype
if lhs_dtype != rhs_dtype: if lhs_dtype != rhs_dtype:
other_var = astype(other_var, lhs_dtype) other_var = astype(other_var, lhs_dtype)
if reverse: if reverse:
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from decorator_helper import prog_scope
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import six import six
...@@ -23,7 +22,7 @@ import six ...@@ -23,7 +22,7 @@ import six
class TestMathOpPatchesVarBase(unittest.TestCase): class TestMathOpPatchesVarBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.shape = [10, 10] self.shape = [10, 1024]
self.dtype = np.float32 self.dtype = np.float32
def test_add(self): def test_add(self):
...@@ -251,6 +250,29 @@ class TestMathOpPatchesVarBase(unittest.TestCase): ...@@ -251,6 +250,29 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
rtol=1e-05, rtol=1e-05,
atol=0.0)) atol=0.0))
def test_add_different_dtype(self):
a_np = np.random.random(self.shape).astype(np.float32)
b_np = np.random.random(self.shape).astype(np.float16)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
b = fluid.dygraph.to_variable(b_np)
res = a + b
self.assertTrue(np.array_equal(res.numpy(), a_np + b_np))
def test_astype(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
res1 = a.astype(np.float16)
res2 = a.astype('float16')
res3 = a.astype(fluid.core.VarDesc.VarType.FP16)
self.assertEqual(res1.dtype, res2.dtype)
self.assertEqual(res1.dtype, res3.dtype)
self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
self.assertTrue(np.array_equal(res1.numpy(), res3.numpy()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册