未验证 提交 98900b35 编写于 作者: Z zhouzj 提交者: GitHub

recover transpose's XShape attr. (#1625)

上级 be558e8d
...@@ -31,7 +31,7 @@ def _remove_fetch_node(program): ...@@ -31,7 +31,7 @@ def _remove_fetch_node(program):
removed += 1 removed += 1
def _recover_reserve_space_with_bn(program): def _recover_outputs_attr(program):
"""Add the outputs which is only used for training and not saved in """Add the outputs which is only used for training and not saved in
inference program.""" inference program."""
for block_idx in six.moves.range(program.num_blocks): for block_idx in six.moves.range(program.num_blocks):
...@@ -49,6 +49,18 @@ def _recover_reserve_space_with_bn(program): ...@@ -49,6 +49,18 @@ def _recover_reserve_space_with_bn(program):
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name]) op.desc.set_output("ReserveSpace", [reserve_space.name])
if op.type == 'transpose2':
if 'XShape' not in op.output_names:
xshape = block.create_var(
name=paddle.fluid.unique_name.
generate_with_ignorable_key(".".join(["xshape", 'tmp'
])),
dtype=block.var(op.input("X")[0]).dtype,
type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
shape=(0, ) + block.var(op.input("X")[0]).shape,
persistable=False,
stop_gradient=True)
op.desc.set_output("XShape", [xshape.name])
return program return program
...@@ -70,7 +82,7 @@ def recover_inference_program(inference_program): ...@@ -70,7 +82,7 @@ def recover_inference_program(inference_program):
""" recover inference program to train program which can be trained. """ """ recover inference program to train program which can be trained. """
_remove_fetch_node(inference_program) _remove_fetch_node(inference_program)
inference_program = _recover_param_attr(inference_program) inference_program = _recover_param_attr(inference_program)
inference_program = _recover_reserve_space_with_bn(inference_program) inference_program = _recover_outputs_attr(inference_program)
for var in inference_program.list_vars(): for var in inference_program.list_vars():
var.stop_gradient = False var.stop_gradient = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册