diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index af10c65400ee2c90c8281faf5b06cc2d8a367626..0a9e66a5bb0b1b7f15928891f8eefcbc67ebffb5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -112,14 +112,7 @@ class PartialProgramLayer(layers.Layer): self._outputs = NestSequence(outputs, need_check=True) self._params = parameters if parameters is not None else [] - # Check all params from main program can be found in self._params: - # 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph. - # 2. parameter from transformed program shall be found in self._params. - # Because they share same data with ParamBase of original dygraph. - self._check_params_all_inited(main_program) - self._prune_unused_params(main_program) - - self._infer_program = main_program + self._infer_program = self._verify_program(main_program) self._train_program = self._append_backward_desc() # Switch infer or train by train() and eval() self._trace_program = None @@ -128,6 +121,20 @@ class PartialProgramLayer(layers.Layer): # Set default mode to train self.train() + def _verify_program(self, main_program): + """ + Verify that the program parameter is initialized, prune some unused params, + and remove redundant op callstack. + """ + # 1. Check all params from main program can be found in self._params + self._check_params_all_inited(main_program) + # 2. Prune the parameters not used anywhere in the program. + self._prune_unused_params(main_program) + # 3. Remove op's python call stack with redundant low-level error messages. + main_program = self._remove_op_call_stack(main_program) + + return main_program + @switch_to_static_graph def _append_backward_desc(self): program = self._infer_program.clone() @@ -295,6 +302,19 @@ class PartialProgramLayer(layers.Layer): continue param._set_grad_type(grad_var.type()) + def _remove_op_call_stack(self, main_program): + """ + Remove op's python call stack with redundant low-level error messages related to + transforamtions to avoid confusing users. + """ + assert isinstance(main_program, framework.Program) + for block in main_program.blocks: + for op in block.ops: + if op.has_attr("op_callstack"): + op._remove_attr("op_callstack") + + return main_program + def _check_params_all_inited(self, main_program): """ Check all params from main program are already initialized, see details as follows: