未验证 提交 b027652b 编写于 作者: G Guoxia Wang 提交者: GitHub

Remove the temporary FP32 variable for gaussian_random (#46285)

上级 3e8b3220
...@@ -374,19 +374,6 @@ class NormalInitializer(Initializer): ...@@ -374,19 +374,6 @@ class NormalInitializer(Initializer):
["uint16", "float16", "float32", "float64"], ["uint16", "float16", "float32", "float64"],
"guassian_random") "guassian_random")
# to be compatible with fp16 initializers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(name=unique_name.generate(".".join(
['normal_init', var.name, 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
...@@ -394,48 +381,29 @@ class NormalInitializer(Initializer): ...@@ -394,48 +381,29 @@ class NormalInitializer(Initializer):
place = _current_expected_place() place = _current_expected_place()
out_var = _C_ops.gaussian_random(var.shape, self._mean, out_var = _C_ops.gaussian_random(var.shape, self._mean,
self._std_dev, self._seed, self._std_dev, self._seed,
out_dtype, place) var.dtype, place)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var) out_var._share_underline_tensor_to(var)
return None return None
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out_var = _legacy_C_ops.gaussian_random( out_var = _legacy_C_ops.gaussian_random(
'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean, 'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean,
'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False) 'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var) out_var._share_underline_tensor_to(var)
return None return None
else: else:
op = block.append_op(type="gaussian_random", op = block.append_op(type="gaussian_random",
outputs={"Out": out_var}, outputs={"Out": var},
attrs={ attrs={
"shape": var.shape, "shape": var.shape,
"dtype": out_dtype, "dtype": var.dtype,
"mean": self._mean, "mean": self._mean,
"std": self._std_dev, "std": self._std_dev,
"seed": self._seed, "seed": self._seed,
"use_mkldnn": False "use_mkldnn": False
}, },
stop_gradient=True) stop_gradient=True)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
block.append_op(type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={
"in_dtype": out_var.dtype,
"out_dtype": var.dtype
})
var.op = op var.op = op
return op return op
...@@ -695,7 +663,7 @@ class XavierInitializer(Initializer): ...@@ -695,7 +663,7 @@ class XavierInitializer(Initializer):
outputs={"Out": out_var}, outputs={"Out": out_var},
attrs={ attrs={
"shape": out_var.shape, "shape": out_var.shape,
"dtype": out_dtype, "dtype": out_var.dtype,
"mean": 0.0, "mean": 0.0,
"std": std, "std": std,
"seed": self._seed "seed": self._seed
......
...@@ -245,7 +245,7 @@ class TestNormalInitializer(unittest.TestCase): ...@@ -245,7 +245,7 @@ class TestNormalInitializer(unittest.TestCase):
name="param", name="param",
initializer=initializer.NormalInitializer( initializer=initializer.NormalInitializer(
2.3, 1.9, 123)) 2.3, 1.9, 123))
num_ops = 2 if (dtype == "float16" or dtype == "uint16") else 1 num_ops = 1
self.assertEqual(len(block.ops), num_ops) self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0] init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random') self.assertEqual(init_op.type, 'gaussian_random')
...@@ -390,7 +390,6 @@ class TestXavierInitializer(unittest.TestCase): ...@@ -390,7 +390,6 @@ class TestXavierInitializer(unittest.TestCase):
"""Test the Xavier initializer with float16 """Test the Xavier initializer with float16
""" """
block = self.test_xavier_initializer_supplied_arguments("float16") block = self.test_xavier_initializer_supplied_arguments("float16")
self.assertTrue(check_cast_op(block.ops[1]))
def test_xavier_initializer_bf16(self): def test_xavier_initializer_bf16(self):
"""Test the Xavier initializer with bfloat16 """Test the Xavier initializer with bfloat16
...@@ -400,7 +399,6 @@ class TestXavierInitializer(unittest.TestCase): ...@@ -400,7 +399,6 @@ class TestXavierInitializer(unittest.TestCase):
self.assertEqual(len(block_uniform.ops), 1) self.assertEqual(len(block_uniform.ops), 1)
block_gaussian = self.test_xavier_initializer_supplied_arguments( block_gaussian = self.test_xavier_initializer_supplied_arguments(
"uint16", False) "uint16", False)
self.assertTrue(check_cast_op(block_gaussian.ops[1]))
class TestMSRAInitializer(unittest.TestCase): class TestMSRAInitializer(unittest.TestCase):
......
...@@ -398,7 +398,7 @@ class TestNormal(unittest.TestCase): ...@@ -398,7 +398,7 @@ class TestNormal(unittest.TestCase):
lod_level=0, lod_level=0,
name="param", name="param",
initializer=initializer.Normal(2.3, 1.9)) initializer=initializer.Normal(2.3, 1.9))
num_ops = 2 if dtype in ["float16", "uint16"] else 1 num_ops = 1
self.assertEqual(len(block.ops), num_ops) self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0] init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random') self.assertEqual(init_op.type, 'gaussian_random')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册