未验证 提交 08b7f17d 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] Fix benchmark Performance (#38610)

上级 ba411960
......@@ -160,9 +160,9 @@ class ConstantInitializer(Initializer):
if var.dtype == VarDesc.VarType.FP16:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
# fill constant should set the "str_value" to preserve precision
......@@ -279,9 +279,9 @@ class UniformInitializer(Initializer):
if var.dtype == VarDesc.VarType.FP16:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
op = block.append_op(
......@@ -382,9 +382,9 @@ class NormalInitializer(Initializer):
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
op = block.append_op(
......@@ -477,9 +477,9 @@ class TruncatedNormalInitializer(Initializer):
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
op = block.append_op(
......@@ -617,9 +617,9 @@ class XavierInitializer(Initializer):
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
if self._uniform:
......@@ -770,9 +770,9 @@ class MSRAInitializer(Initializer):
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
if self._uniform:
......@@ -938,9 +938,9 @@ class BilinearInitializer(Initializer):
]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
op = block.append_op(
......@@ -1044,9 +1044,9 @@ class NumpyArrayInitializer(Initializer):
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, True)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, True)
var.copy_(out_var, False)
return None
else:
op = block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册