未验证 提交 92e256a8 编写于 作者: L Leo Chen 提交者: GitHub

Fix bug of different dtype in dygraph math_op_patch, test=develop (#24740)

* Fix bug of different dtype in dygraph math_op_patch, test=develop

* support np.dtype and str, test=develop

* add unit test, test=develop
上级 181b1f5a
......@@ -37,9 +37,6 @@ def monkey_patch_math_varbase():
The difference is, in dygraph mode, use auto-generated op functions for better performance.
"""
def safe_get_dtype(var):
return var.dtype
@no_grad
def create_tensor(value, dtype, shape):
out = _varbase_creator(dtype=dtype)
......@@ -96,8 +93,9 @@ def monkey_patch_math_varbase():
print("new var's dtype is: {}, numpy dtype is {}".format(new_variable.dtype, new_variable.numpy().dtype))
"""
return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype',
convert_np_dtype_to_dtype_(dtype))
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', dtype)
def _scalar_elementwise_op_(var, scale, bias):
return core.ops.scale(var, 'scale', scale, 'bias', bias)
......@@ -175,7 +173,7 @@ def monkey_patch_math_varbase():
elif isinstance(other_var, int):
return scalar_method(self, float(other_var))
lhs_dtype = safe_get_dtype(self)
lhs_dtype = self.dtype
if not isinstance(other_var, core.VarBase):
if reverse:
......@@ -185,7 +183,7 @@ def monkey_patch_math_varbase():
# add fill_op
other_var = create_scalar(value=other_var, dtype=lhs_dtype)
rhs_dtype = safe_get_dtype(other_var)
rhs_dtype = other_var.dtype
if lhs_dtype != rhs_dtype:
other_var = astype(other_var, lhs_dtype)
if reverse:
......
......@@ -15,7 +15,6 @@
from __future__ import print_function
import unittest
from decorator_helper import prog_scope
import paddle.fluid as fluid
import numpy as np
import six
......@@ -23,7 +22,7 @@ import six
class TestMathOpPatchesVarBase(unittest.TestCase):
def setUp(self):
self.shape = [10, 10]
self.shape = [10, 1024]
self.dtype = np.float32
def test_add(self):
......@@ -251,6 +250,29 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
rtol=1e-05,
atol=0.0))
def test_add_different_dtype(self):
a_np = np.random.random(self.shape).astype(np.float32)
b_np = np.random.random(self.shape).astype(np.float16)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
b = fluid.dygraph.to_variable(b_np)
res = a + b
self.assertTrue(np.array_equal(res.numpy(), a_np + b_np))
def test_astype(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
res1 = a.astype(np.float16)
res2 = a.astype('float16')
res3 = a.astype(fluid.core.VarDesc.VarType.FP16)
self.assertEqual(res1.dtype, res2.dtype)
self.assertEqual(res1.dtype, res3.dtype)
self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
self.assertTrue(np.array_equal(res1.numpy(), res3.numpy()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册