未验证 提交 b027652b 编写于 作者: G Guoxia Wang 提交者: GitHub

remove tmp fp32 var for gaussian_random (#46285)

上级 3e8b3220
......@@ -374,19 +374,6 @@ class NormalInitializer(Initializer):
["uint16", "float16", "float32", "float64"],
"guassian_random")
# to be compatible of fp16 initalizers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(name=unique_name.generate(".".join(
['normal_init', var.name, 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
if self._seed == 0:
self._seed = block.program.random_seed
......@@ -394,48 +381,29 @@ class NormalInitializer(Initializer):
place = _current_expected_place()
out_var = _C_ops.gaussian_random(var.shape, self._mean,
self._std_dev, self._seed,
out_dtype, place)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var)
var.dtype, place)
out_var._share_underline_tensor_to(var)
return None
if _in_legacy_dygraph():
out_var = _legacy_C_ops.gaussian_random(
'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean,
'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean,
'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var)
out_var._share_underline_tensor_to(var)
return None
else:
op = block.append_op(type="gaussian_random",
outputs={"Out": out_var},
outputs={"Out": var},
attrs={
"shape": var.shape,
"dtype": out_dtype,
"dtype": var.dtype,
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed,
"use_mkldnn": False
},
stop_gradient=True)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
block.append_op(type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={
"in_dtype": out_var.dtype,
"out_dtype": var.dtype
})
var.op = op
return op
......@@ -695,7 +663,7 @@ class XavierInitializer(Initializer):
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
"dtype": out_dtype,
"dtype": out_var.dtype,
"mean": 0.0,
"std": std,
"seed": self._seed
......
......@@ -245,7 +245,7 @@ class TestNormalInitializer(unittest.TestCase):
name="param",
initializer=initializer.NormalInitializer(
2.3, 1.9, 123))
num_ops = 2 if (dtype == "float16" or dtype == "uint16") else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
......@@ -390,7 +390,6 @@ class TestXavierInitializer(unittest.TestCase):
"""Test the Xavier initializer with float16
"""
block = self.test_xavier_initializer_supplied_arguments("float16")
self.assertTrue(check_cast_op(block.ops[1]))
def test_xavier_initializer_bf16(self):
"""Test the Xavier initializer with bfloat16
......@@ -400,7 +399,6 @@ class TestXavierInitializer(unittest.TestCase):
self.assertEqual(len(block_uniform.ops), 1)
block_gaussian = self.test_xavier_initializer_supplied_arguments(
"uint16", False)
self.assertTrue(check_cast_op(block_gaussian.ops[1]))
class TestMSRAInitializer(unittest.TestCase):
......
......@@ -398,7 +398,7 @@ class TestNormal(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.Normal(2.3, 1.9))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册