未验证 提交 cf3c51a6 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Remove op call stack in PartialProgram (#25420)

* remove op call stack test=develop

* fix typo test=develop
上级 848aca7a
......@@ -112,14 +112,7 @@ class PartialProgramLayer(layers.Layer):
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []
# Check all params from main program can be found in self._params:
# 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
# 2. parameter from transformed program shall be found in self._params.
# Because they share same data with ParamBase of original dygraph.
self._check_params_all_inited(main_program)
self._prune_unused_params(main_program)
self._infer_program = main_program
self._infer_program = self._verify_program(main_program)
self._train_program = self._append_backward_desc()
# Switch infer or train by train() and eval()
self._trace_program = None
......@@ -128,6 +121,20 @@ class PartialProgramLayer(layers.Layer):
# Set default mode to train
self.train()
def _verify_program(self, main_program):
"""
Verify that the program parameter is initialized, prune some unused params,
and remove redundant op callstack.
"""
# 1. Check all params from main program can be found in self._params
self._check_params_all_inited(main_program)
# 2. Prune the parameters not used anywhere in the program.
self._prune_unused_params(main_program)
# 3. Remove op's python call stack with redundant low-level error messages.
main_program = self._remove_op_call_stack(main_program)
return main_program
@switch_to_static_graph
def _append_backward_desc(self):
program = self._infer_program.clone()
......@@ -295,6 +302,19 @@ class PartialProgramLayer(layers.Layer):
continue
param._set_grad_type(grad_var.type())
def _remove_op_call_stack(self, main_program):
"""
Remove op's python call stack with redundant low-level error messages related to
transforamtions to avoid confusing users.
"""
assert isinstance(main_program, framework.Program)
for block in main_program.blocks:
for op in block.ops:
if op.has_attr("op_callstack"):
op._remove_attr("op_callstack")
return main_program
def _check_params_all_inited(self, main_program):
"""
Check all params from main program are already initialized, see details as follows:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册