未验证 提交 08b7f17d 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] Fix benchmark Performance (#38610)

上级 ba411960
...@@ -160,9 +160,9 @@ class ConstantInitializer(Initializer): ...@@ -160,9 +160,9 @@ class ConstantInitializer(Initializer):
if var.dtype == VarDesc.VarType.FP16: if var.dtype == VarDesc.VarType.FP16:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
# fill constant should set the "str_value" to preserve precision # fill constant should set the "str_value" to preserve precision
...@@ -279,9 +279,9 @@ class UniformInitializer(Initializer): ...@@ -279,9 +279,9 @@ class UniformInitializer(Initializer):
if var.dtype == VarDesc.VarType.FP16: if var.dtype == VarDesc.VarType.FP16:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
op = block.append_op( op = block.append_op(
...@@ -382,9 +382,9 @@ class NormalInitializer(Initializer): ...@@ -382,9 +382,9 @@ class NormalInitializer(Initializer):
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
op = block.append_op( op = block.append_op(
...@@ -477,9 +477,9 @@ class TruncatedNormalInitializer(Initializer): ...@@ -477,9 +477,9 @@ class TruncatedNormalInitializer(Initializer):
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
op = block.append_op( op = block.append_op(
...@@ -617,9 +617,9 @@ class XavierInitializer(Initializer): ...@@ -617,9 +617,9 @@ class XavierInitializer(Initializer):
var.dtype == VarDesc.VarType.BF16 and not self._uniform): var.dtype == VarDesc.VarType.BF16 and not self._uniform):
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
if self._uniform: if self._uniform:
...@@ -770,9 +770,9 @@ class MSRAInitializer(Initializer): ...@@ -770,9 +770,9 @@ class MSRAInitializer(Initializer):
var.dtype == VarDesc.VarType.BF16 and not self._uniform): var.dtype == VarDesc.VarType.BF16 and not self._uniform):
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
if self._uniform: if self._uniform:
...@@ -938,9 +938,9 @@ class BilinearInitializer(Initializer): ...@@ -938,9 +938,9 @@ class BilinearInitializer(Initializer):
]: ]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
op = block.append_op( op = block.append_op(
...@@ -1044,9 +1044,9 @@ class NumpyArrayInitializer(Initializer): ...@@ -1044,9 +1044,9 @@ class NumpyArrayInitializer(Initializer):
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype) 'out_dtype', var.dtype)
var.copy_(var_tmp, True) var.copy_(var_tmp, False)
else: else:
var.copy_(out_var, True) var.copy_(out_var, False)
return None return None
else: else:
op = block.append_op( op = block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册