未验证 提交 4b61918d 编写于 作者: 0 0x45f 提交者: GitHub

Fix test_jit_save_load (#41114)

上级 66cf8b08
...@@ -382,7 +382,9 @@ class GradNodeRunProgram : public egr::GradNodeBase { ...@@ -382,7 +382,9 @@ class GradNodeRunProgram : public egr::GradNodeBase {
x_grad_ptr.emplace_back(&i); x_grad_ptr.emplace_back(&i);
} }
for (auto &i : params_grad) { for (auto &i : params_grad) {
params_grad_ptr.emplace_back(&i); if (i.defined()) {
params_grad_ptr.emplace_back(&i);
}
} }
PADDLE_ENFORCE_EQ(hooked_grads[0].size(), fwd_out_names_.size(), PADDLE_ENFORCE_EQ(hooked_grads[0].size(), fwd_out_names_.size(),
......
...@@ -883,7 +883,7 @@ def _run_dygraph(instance, input, program_holder): ...@@ -883,7 +883,7 @@ def _run_dygraph(instance, input, program_holder):
# transform SelectedRows to LoDTensor forcibly, it may not # transform SelectedRows to LoDTensor forcibly, it may not
# be user wanted result. # be user wanted result.
for persistable_var in persistable_vars: for persistable_var in persistable_vars:
grad_var_name = var.name + core.grad_var_suffix() grad_var_name = persistable_var.name + core.grad_var_suffix()
grad_var = trace_program.block(0).find_var(cpt.to_bytes(grad_var_name)) grad_var = trace_program.block(0).find_var(cpt.to_bytes(grad_var_name))
# NOTE: cannot find var desc maybe not problem, # NOTE: cannot find var desc maybe not problem,
# such as in batch_norm # such as in batch_norm
......
...@@ -37,7 +37,7 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTra ...@@ -37,7 +37,7 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTra
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Block, ParamBase, Program, Variable, Parameter from paddle.fluid.framework import Block, ParamBase, Program, Variable, Parameter, EagerParamBase
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, _non_static_mode from paddle.fluid.framework import dygraph_only, _non_static_mode
from paddle.fluid.wrapped_decorator import wrap_decorator from paddle.fluid.wrapped_decorator import wrap_decorator
...@@ -921,7 +921,8 @@ def save(layer, path, input_spec=None, **configs): ...@@ -921,7 +921,8 @@ def save(layer, path, input_spec=None, **configs):
param_or_buffer.name] param_or_buffer.name]
extra_info_dict[ extra_info_dict[
'stop_gradient'] = param_or_buffer.stop_gradient 'stop_gradient'] = param_or_buffer.stop_gradient
if isinstance(param_or_buffer, ParamBase): if isinstance(param_or_buffer,
(ParamBase, EagerParamBase)):
extra_info_dict[ extra_info_dict[
'trainable'] = param_or_buffer.trainable 'trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict extra_var_info[param_or_buffer.name] = extra_info_dict
......
...@@ -147,6 +147,13 @@ class Tracer(core.Tracer): ...@@ -147,6 +147,13 @@ class Tracer(core.Tracer):
attrs_list.append(v) attrs_list.append(v)
returns = function_ptr(*arg_list, *attrs_list) returns = function_ptr(*arg_list, *attrs_list)
if type == 'load_combine':
assert len(outputs.keys()) == 1
key = list(outputs.keys())[0]
for j in range(len(returns)):
returns[j]._share_underline_tensor_to(outputs[key][j])
return
if isinstance(returns, tuple): if isinstance(returns, tuple):
for i in range(len(op_returns)): for i in range(len(op_returns)):
retname = op_returns[i] retname = op_returns[i]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册