Unverified commit d128a695 authored by WangZhen, committed by GitHub

Fix rename error in fp16 case (#56590)

Parent 0d47f387
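This change reworks how input tensors are temporarily renamed around `run_program`. Previously, `_prepare` renamed each input tensor to the name in the static program's input desc and stashed the original names in `resume_name_record`, which `forward` replayed in a loop after `run_program` returned. The commit moves that bookkeeping into a new `tensor_name_guard` context manager wrapped directly around the `run_program` call, so the temporary names are applied only after `_cast_fp16_if_pure_fp16` has processed the inputs and the originals are restored in one place.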
@@ -33,6 +33,7 @@ from .utils import (
     RETURN_NO_VALUE_MAGIC_NUM,
     backend_guard,
     construct_grad_names,
+    tensor_name_guard,
 )
 
 __all__ = []
@@ -223,30 +224,25 @@ class PartialProgramLayer:
         Execute static graph by Interpreter and Return dynamic Tensors.
         """
         with UniqueNameGuard(self._name_generator):
-            in_vars, out_vars, in_var_names, resume_name_record = self._prepare(
-                inputs
-            )
+            in_vars, out_vars, in_var_names = self._prepare(inputs)
             self._cast_fp16_if_pure_fp16(in_vars)
             attrs = self._prepare_attributes()
             attrs.extend(["x_names", in_var_names])
             self._sync_lr_value_with_scheduler()
-            _legacy_C_ops.run_program(
-                self._valid_vars(in_vars),
-                self._valid_vars(self._params),
-                self._valid_vars(out_vars),
-                self._create_scope_vec(
-                    program_id=self.program_id, use_scope_cache=True
-                ),
-                self._double_grads,
-                self._cuda_graph_vec,
-                *attrs
-            )
-            for var in in_vars:
-                if var.name in resume_name_record:
-                    var.name = resume_name_record[var.name]
+            with tensor_name_guard(in_vars, in_var_names):
+                _legacy_C_ops.run_program(
+                    self._valid_vars(in_vars),
+                    self._valid_vars(self._params),
+                    self._valid_vars(out_vars),
+                    self._create_scope_vec(
+                        program_id=self.program_id, use_scope_cache=True
+                    ),
+                    self._double_grads,
+                    self._cuda_graph_vec,
+                    *attrs
+                )
             self._update_stop_gradient(out_vars)
             restored_nest_out = self._restore_out(out_vars)
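Scoping the rename with a `with` block also makes the restore unconditional: `tensor_name_guard`'s `finally` clause runs even when `run_program` raises, whereas the old restore loop executed only on the success path.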
@@ -905,7 +901,6 @@ class PartialProgramLayer:
         # Convert variable into Tensor and feed in training data.
         input_vars = []
         input_var_names = []
-        resume_name_record = {}
         expected_place = framework._current_expected_place()
         for i, value in enumerate(flatten_inputs):
             if isinstance(value, np.ndarray):
@@ -928,8 +923,6 @@ class PartialProgramLayer:
                     var.stop_gradient = True
                 else:
                     var = value
-                    resume_name_record[self._inputs[i].desc.name()] = var.name
-                    var.name = self._inputs[i].desc.name()
             else:
                 continue
             input_var_names.append(self._inputs[i].desc.name())
@@ -960,7 +953,7 @@ class PartialProgramLayer:
         # Create Tensor to receive output data.
         out_vars = list(map(create_out, self._outputs.var_ids))
 
-        return input_vars, out_vars, input_var_names, resume_name_record
+        return input_vars, out_vars, input_var_names
 
     def _create_scope_vec(self, program_id=None, use_scope_cache=False):
         # Hold forward variables
@@ -1527,3 +1527,16 @@ def construct_grad_names(grad_info_map, x_vars, param_vars, out_vars):
     out_grad_vars = backward._get_grad_vars(grad_info_map, out_vars)
     grad_var_names['out'] = list(map(fn, out_grad_vars))
     return grad_var_names
+
+
+@signature_safe_contextmanager
+def tensor_name_guard(tensors, names):
+    try:
+        assert len(tensors) == len(names)
+        origin_names = [t.name for t in tensors]
+        for t, name in zip(tensors, names):
+            t.name = name
+        yield
+    finally:
+        for t, name in zip(tensors, origin_names):
+            t.name = name
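For reference, below is a minimal, self-contained sketch of the guard's behavior. `contextlib.contextmanager` stands in for Paddle's `signature_safe_contextmanager`, and `FakeTensor` is a hypothetical stand-in for a tensor with a mutable `.name`; only the guard logic mirrors the helper added above.

from contextlib import contextmanager  # stand-in for signature_safe_contextmanager


class FakeTensor:
    """Hypothetical stand-in for a tensor with a mutable .name attribute."""

    def __init__(self, name):
        self.name = name


@contextmanager
def tensor_name_guard(tensors, names):
    # Mirrors the helper above: rename on entry, restore on exit.
    assert len(tensors) == len(names)
    origin_names = [t.name for t in tensors]
    try:
        for t, name in zip(tensors, names):
            t.name = name
        yield
    finally:
        # Runs on both the success and exception paths.
        for t, name in zip(tensors, origin_names):
            t.name = name


t = FakeTensor("input_0_tmp")
try:
    with tensor_name_guard([t], ["input_0"]):
        assert t.name == "input_0"  # renamed inside the guard
        raise RuntimeError("simulated run_program failure")
except RuntimeError:
    pass
assert t.name == "input_0_tmp"  # restored despite the exception

One small difference: the sketch captures `origin_names` before entering the `try`, so the `finally` block never references an unbound name if the length assert fails; the committed helper does both inside the `try`.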