From 97f86d84afde2286164dd6ca757d6e4b55a7d225 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Wed, 30 Jun 2021 14:49:37 +0800
Subject: [PATCH] [Dy2Stat] Refine PartialProgramLayer logic (#33796)

* refine temp_scope_vec logic

* polish partial_program

* fix fake var

* add stop_gradient in spec

* fix fake_var

* fix unittest
---
 .../dygraph_to_static/function_spec.py        | 18 ++++-
 .../dygraph_to_static/partial_program.py      | 77 +++++++++++--------
 2 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
index 031351ca11..c25574c39d 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
@@ -103,8 +103,11 @@ class FunctionSpec(object):
         for idx, input_var in enumerate(flatten(args)):
             if isinstance(input_var, np.ndarray):
                 input_var = paddle.static.InputSpec.from_numpy(input_var)
+                _set_spec_stop_gradient(input_var, True)
             elif isinstance(input_var, core.VarBase):
+                stop_gradient = input_var.stop_gradient
                 input_var = paddle.static.InputSpec.from_tensor(input_var)
+                _set_spec_stop_gradient(input_var, stop_gradient)
 
             args_with_spec.append(input_var)
 
@@ -172,13 +175,15 @@ class FunctionSpec(object):
         block = main_program.global_block()
         for i, var_spec in enumerate(flat_input_spec):
             if isinstance(var_spec, paddle.static.InputSpec):
+                stop_gradient = getattr(var_spec, 'stop_gradient', False)
                 feed_layer = block.create_var(
                     # TODO(Aurelius84): consider a more elegant way to name this
                     name=var_spec.name or "feed_%s" % i,
                     shape=var_spec.shape,
                     dtype=var_spec.dtype,
                     is_data=True,
-                    need_check_feed=False)
+                    need_check_feed=False,
+                    stop_gradient=stop_gradient)
             else:
                 feed_layer = var_spec
             inputs.append(feed_layer)
@@ -302,7 +307,7 @@ def convert_to_input_spec(inputs, input_spec):
         if isinstance(rest_input, (core.VarBase, np.ndarray)):
             logging_utils.warn(
                 "The inputs contain `{}` without specifying InputSpec, so their shape and dtype will be treated as immutable. "
-                "Please specify InputSpec information in `@declarative` if you expect them as mutable inputs.".
+                "Please specify InputSpec information in `@to_static` if you expect them as mutable inputs.".
                 format(type_name(rest_input)))
     input_with_spec.extend(inputs[len(input_spec):])
 
@@ -380,3 +385,12 @@ def _replace_spec_name(name, input_spec):
         return processed_specs
     else:
         return input_spec
+
+
+def _set_spec_stop_gradient(spec, stop_gradient):
+    """
+    Set a new ``stop_gradient`` attribute on the InputSpec to avoid generating
+    redundant grad ops during append_backward.
+    """
+    assert isinstance(spec, paddle.static.InputSpec)
+    spec.stop_gradient = stop_gradient
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 84bac98013..4d12c3c2b9 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -35,6 +35,7 @@ class NestSequence(object):
 
     def __init__(self, raw_input, need_check=False):
         self.__raw_input = raw_input
+        self.__input_list = self.tolist()
         self.__var_ids = self._get_var_ids()
         self._check_non_variable(need_check)
 
@@ -48,12 +49,12 @@ class NestSequence(object):
         """
         Restores the nested sequence from value list.
""" - assert len(self.tolist()) == len(value_list) + assert len(self.__input_list) == len(value_list) return pack_sequence_as(self.__raw_input, value_list) def _get_var_ids(self): var_ids = [] - for idx, var in enumerate(self.tolist()): + for idx, var in enumerate(self.__input_list): if isinstance(var, (framework.Variable, core.VarBase)): var_ids.append(idx) @@ -65,7 +66,7 @@ class NestSequence(object): """ if need_check: warning_types = set() - for var in self.tolist(): + for var in self.__input_list: if not isinstance(var, (framework.Variable, core.VarBase)): warning_types.add(type(var)) if warning_types: @@ -80,7 +81,7 @@ class NestSequence(object): return self.__var_ids def __getitem__(self, item): - return self.tolist()[item] + return self.__input_list[item] class LazyInitialized(object): @@ -106,7 +107,7 @@ def _change_is_test_status(program, is_test): return program -class PartialProgramLayer(layers.Layer): +class PartialProgramLayer: """ PartialProgramLayer wraps all the ops from layers decorated by `@declarative` and execute them as a static subgraph. @@ -134,7 +135,9 @@ class PartialProgramLayer(layers.Layer): self._params = parameters if parameters is not None else [] self._origin_main_program = self._verify_program(main_program) - self._inner_scope = core.Scope() + self._tmp_scope_vec = self._create_scope_vec() + # A fake_var to handle empty input or output + self.__fake_vars = _create_fake_var() # Set default mode to train self._double_grads = self._get_double_grads(self._origin_main_program) self.training = True @@ -217,19 +220,19 @@ class PartialProgramLayer(layers.Layer): var_desc.name(), var_desc.type(), False) double_grads.append(var_base) - return double_grads + return self._valid_vars(double_grads) - def forward(self, inputs): - in_vars, out_vars, tmp_scope_vec = self._prepare(inputs) + def __call__(self, inputs): + in_vars, out_vars = self._prepare(inputs) attrs = ('global_block', self.program.desc.block(0), 'start_op_index', 0, 'end_op_index', self._infer_program.desc.block(0).op_size(), 'is_test', not self.training) core.ops.run_program( - valid_vars(in_vars), - valid_vars(self._params), - valid_vars(out_vars), tmp_scope_vec, - valid_vars(self._double_grads), *attrs) + self._valid_vars(in_vars), + self._valid_vars(self._params), + self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads, + *attrs) restored_nest_out = self._restore_out(out_vars) return self._remove_no_value(restored_nest_out) @@ -264,7 +267,6 @@ class PartialProgramLayer(layers.Layer): expected_place): var = value._copy_to(expected_place, False) var.stop_gradient = True - var.name = value.name else: var = value var.name = self._inputs[i].desc.name() @@ -272,25 +274,29 @@ class PartialProgramLayer(layers.Layer): continue input_vars.append(var) - # Create VarBase to receive output data. - out_vars = [] - for idx in self._outputs.var_ids: - var = self._outputs[idx] + def create_out(var_id): + var = self._outputs[var_id] assert isinstance(var, framework.Variable) var_desc = var.desc var_base = core.VarBase(var_desc.dtype(), var_desc.shape(), var_desc.name(), var_desc.type(), False) - out_vars.append(var_base) + return var_base + + # Create VarBase to receive output data. 
+        out_vars = list(map(create_out, self._outputs.var_ids))
+
+        return input_vars, out_vars
 
+    def _create_scope_vec(self):
         # Hold forward variables
         tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                      "program_out_scope",
                                      core.VarDesc.VarType.STEP_SCOPES, True)
-        tmp_scope_vec.value().set_scope(self._inner_scope)
-
-        return input_vars, out_vars, tmp_scope_vec
+        inner_scope = core.Scope()
+        tmp_scope_vec.value().set_scope(inner_scope)
+        return tmp_scope_vec
 
     def _restore_out(self, out_vars):
         """
@@ -311,8 +317,9 @@
         return main_program.clone(for_test=True)
 
     def _is_no_value(self, var):
-        if isinstance(var, core.VarBase):
-            if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
+        if isinstance(var, core.VarBase) and var.shape == [1]:
+            # NOTE: .numpy() inserts a MemcpySync operation, which hurts performance.
+            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                 return True
         return False
@@ -405,20 +412,22 @@
                 "Please define the layer with parameters in `__init__` function."
                 % name)
 
+    def _valid_vars(self, vars):
+        """
+        Note: run_program_op's InferShape requires `X`/`Out` not to be null.
+        Since empty inputs or outputs are common in dy2static, a fake VarBase
+        is created to handle the problem.
+        """
+        return vars if vars else self.__fake_vars
+
 
-def valid_vars(vars):
+def _create_fake_var():
     """
-    Note: run_program_op.InferShape requires `X`/'Out' not be null.
-    But it's common in dy2static, fake varBase is created to handle the
-    problem.
+    Create a fake_var (forced on CPU) to handle empty input or output.
     """
-    if vars:
-        return vars
     return [
-        core.VarBase(
-            value=[1],
-            name='Fake_var',
-            place=framework._current_expected_place())
+        core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
+                     core.VarDesc.VarType.RAW, False)
    ]
--
GitLab
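
A quick illustration for reviewers of the `stop_gradient` plumbing in function_spec.py above: `append_backward` emits grad ops for every feed variable unless it is marked `stop_gradient`, so the patch copies the dygraph tensor's flag onto its `InputSpec` (and forces `True` for plain numpy inputs). The sketch below reproduces that idea in isolation; it is a minimal sketch rather than the patched code itself, `spec_from_input` is a hypothetical helper name, and it assumes a Paddle 2.x install where `paddle.static.InputSpec` provides `from_numpy`/`from_tensor`.

import numpy as np
import paddle
from paddle.static import InputSpec

def spec_from_input(value):
    """Build an InputSpec that carries stop_gradient, mirroring the patch."""
    if isinstance(value, np.ndarray):
        # Plain numpy data is never differentiated w.r.t., so stop gradients.
        spec = InputSpec.from_numpy(value)
        spec.stop_gradient = True
    else:
        # A dygraph Tensor: preserve whatever the user configured.
        stop_gradient = value.stop_gradient
        spec = InputSpec.from_tensor(value)
        spec.stop_gradient = stop_gradient
    return spec

x = paddle.to_tensor([[1.0, 2.0]])
x.stop_gradient = False
print(spec_from_input(x).stop_gradient)                # False
print(spec_from_input(np.ones([1, 2])).stop_gradient)  # True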
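
And an end-to-end usage sketch of the code path this patch refines: `paddle.jit.to_static` builds a `FunctionSpec` from the given `input_spec`, and each call then runs the traced program through `PartialProgramLayer.__call__` via the `run_program` op. Illustrative only; `SimpleNet` is an arbitrary example layer, and exact behavior depends on the Paddle version (this patch targets the 2.1 era).

import paddle
from paddle.static import InputSpec

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

net = SimpleNet()
# An explicit InputSpec exercises the feed-var creation path patched above.
static_net = paddle.jit.to_static(
    net, input_spec=[InputSpec([None, 4], 'float32', name='x')])

x = paddle.randn([8, 4])
x.stop_gradient = False
loss = static_net(x).mean()
loss.backward()  # the grad op of run_program fills parameter gradients
print(net.linear.weight.grad is not None)  # True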