diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index 50838bb561890d0380218ddc26c42838d42e499b..44478604781df51faf5a3932f3a455b37ae170dc 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -32,6 +32,7 @@ from paddle.fluid.layers.utils import _hash_with_id, flatten, pack_sequence_as
 
 from . import logging_utils
 from .return_transformer import RETURN_NO_VALUE_MAGIC_NUM
+from .utils import _out_grad_names, _param_grad_names
 
 __all__ = []
 
@@ -375,46 +376,15 @@ class PartialProgramLayer:
 
     @LazyInitialized
     def _param_grad_names(self):
-        names = []
-        # NOTE: `names` and `self._params` must be in the same order so that
-        # the param grad name can be set correctly in the run_program.
-        for param in self._params:
-            candidate = [
-                var_name
-                for var_name in self._train_program.block(0).vars.keys()
-                if var_name.endswith(param.name + '@GRAD')
-            ]
-            if candidate:
-                names.append(
-                    max(candidate, key=lambda name: name.count('grad/'))
-                )
-            else:
-                names.append(param.name + '@GRAD')
-        return names
+        return _param_grad_names(self._train_program.desc, self._params)
 
     @LazyInitialized
     def _out_grad_names(self):
-        """
-        Parse Out@GARD name from original train and infer program.
-        """
-        names = []
-        origin_infer_program = self._create_program(is_infer_mode=True)
-        origin_train_program = self._train_program
-        fwd_end_op_index = len(origin_infer_program.block(0).ops)
-        for i in range(
-            fwd_end_op_index + 1,
-            min(
-                fwd_end_op_index + 2 * len(self._outputs.var_ids),
-                len(origin_train_program.block(0).ops),
-            ),
-            2,
-        ):
-            op = origin_train_program.block(0).ops[i]
-            if op.type == 'fill_constant':
-                var_name = op.output('Out')[0]
-                names.append(var_name)
-
-        return names
+        return _out_grad_names(
+            self._train_program.desc,
+            self._create_program(is_infer_mode=True).desc.block(0).op_size(),
+            len(self._outputs.var_ids),
+        )
 
     @property
     def program(self):
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 4d74c629a392683bd63e71ef2a5806f186354daf..4397728576ba755a618f706954cda16b45d3f6aa 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1483,3 +1483,41 @@ def create_name_str(name_ids):
     names_str = ["'%s'" % (name.replace("'", "\\'")) for name in name_ids]
 
     return "(%s, )" % ','.join(names_str)
+
+
+def _param_grad_names(program_desc, params):
+    """
+    Parse PARAM@GRAD names from the original train program.
+    """
+    names = []
+    # NOTE: `names` and `params` must be in the same order so that
+    # the param grad name can be set correctly in the run_program.
+    for param in params:
+        candidate = [
+            var.name()
+            for var in program_desc.block(0).all_vars()
+            if var.name().endswith(param.name + '@GRAD')
+        ]
+        if candidate:
+            names.append(max(candidate, key=lambda name: name.count('grad/')))
+        else:
+            names.append(param.name + '@GRAD')
+
+    return names
+
+
+def _out_grad_names(program_desc, fwd_end_op_index, out_size):
+    """
+    Parse Out@GRAD names from the original train and infer programs.
+    """
+    names = []
+    for i in range(
+        fwd_end_op_index + 1,
+        min(fwd_end_op_index + 2 * out_size, program_desc.block(0).op_size()),
+        2,
+    ):
+        op = program_desc.block(0).op(i)
+        if op.type() == 'fill_constant':
+            var_name = op.output('Out')[0]
+            names.append(var_name)
+    return names
diff --git a/python/paddle/jit/translated_layer.py b/python/paddle/jit/translated_layer.py
index 9cd30545af8450343981aeb020f451b6c2f235bb..c488c758f4a262fdb4537f8d9a2c26d6248e4bbe 100644
--- a/python/paddle/jit/translated_layer.py
+++ b/python/paddle/jit/translated_layer.py
@@ -33,6 +33,8 @@ from paddle.jit.dy2static.partial_program import (
     add_build_strategy_for,
 )
 
+from .dy2static.utils import _out_grad_names, _param_grad_names
+
 __all__ = []
 
 INFER_MODEL_SUFFIX = ".pdmodel"
@@ -887,28 +889,7 @@ def _construct_params_and_buffers(
 
 
 def _valid_vars(vars):
-    if vars:
-        return vars
-    if framework._in_eager_without_dygraph_check():
-        return [
-            core.eager.Tensor(
-                core.VarDesc.VarType.FP32,
-                [],
-                "Fake_var",
-                core.VarDesc.VarType.RAW,
-                False,
-            )
-        ]
-    else:
-        return [
-            core.VarBase(
-                core.VarDesc.VarType.FP32,
-                [],
-                "Fake_var",
-                core.VarDesc.VarType.RAW,
-                False,
-            )
-        ]
+    return vars if vars else None
 
 
 def _run_dygraph(instance, input, program_holder):
@@ -1041,6 +1022,15 @@ def _run_dygraph(instance, input, program_holder):
         'program_id',
         _hash_with_id(trace_program, instance),
     ]
+    if not instance._is_test:
+        attrs.extend(
+            (
+                'param_grad_names',
+                _param_grad_names(trace_program, persistable_vars),
+                'out_grad_names',
+                _out_grad_names(trace_program, end_op_index, len(output_vars)),
+            )
+        )
 
     use_interpretorcore = (
         _is_enable_standalone_executor()