未验证 提交 98900b35 编写于 作者: Z zhouzj 提交者: GitHub

recover transpose's XShape attr. (#1625)

上级 be558e8d
......@@ -31,7 +31,7 @@ def _remove_fetch_node(program):
removed += 1
def _recover_reserve_space_with_bn(program):
def _recover_outputs_attr(program):
"""Add the outputs which is only used for training and not saved in
inference program."""
for block_idx in six.moves.range(program.num_blocks):
......@@ -49,6 +49,18 @@ def _recover_reserve_space_with_bn(program):
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
if op.type == 'transpose2':
if 'XShape' not in op.output_names:
xshape = block.create_var(
name=paddle.fluid.unique_name.
generate_with_ignorable_key(".".join(["xshape", 'tmp'
])),
dtype=block.var(op.input("X")[0]).dtype,
type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
shape=(0, ) + block.var(op.input("X")[0]).shape,
persistable=False,
stop_gradient=True)
op.desc.set_output("XShape", [xshape.name])
return program
......@@ -70,7 +82,7 @@ def recover_inference_program(inference_program):
""" recover inference program to train program which can be trained. """
_remove_fetch_node(inference_program)
inference_program = _recover_param_attr(inference_program)
inference_program = _recover_reserve_space_with_bn(inference_program)
inference_program = _recover_outputs_attr(inference_program)
for var in inference_program.list_vars():
var.stop_gradient = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册