未验证 提交 202b0eaf 编写于 作者: Y Yiqun Liu 提交者: GitHub

Unset ReserveSpace of batch_norm for inference program. (#32493)

* Unset ReserveSpace for inference program.

* Support training from an inference program.
上级 41bfec8d
......@@ -447,6 +447,11 @@ void OpDesc::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args;
}
void OpDesc::RemoveOutput(const std::string &name) {
outputs_.erase(name);
need_update_ = true;
}
bool OpDesc::HasProtoAttr(const std::string &name) const {
auto &op_info = OpInfoMap::Instance();
if (op_info.Has(desc_.type())) {
......
......@@ -65,6 +65,7 @@ class OpDesc {
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
void RemoveOutput(const std::string &name);
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
......
......@@ -235,6 +235,7 @@ void BindOpDesc(pybind11::module *m) {
const std::vector<std::string> &vec_var_name) {
self.SetOutput(name, vec_var_name);
})
.def("remove_output", &pd::OpDesc::RemoveOutput)
.def("input_arg_names", &pd::OpDesc::InputArgumentNames)
.def("output_arg_names", &pd::OpDesc::OutputArgumentNames)
.def("_rename_input", &pd::OpDesc::RenameInput)
......
......@@ -413,6 +413,23 @@ class _ProgramHolder(object):
# Therefore, in order to reuse the method of backward.py, build the program here.
program = _build_program_by_desc(program_desc_copy)
# 3. Add the outputs that are only used for training and are not saved in
# the inference program.
for block_idx in six.moves.range(program.num_blocks):
block = program.block(block_idx)
for op in block.ops:
if op.type == "batch_norm":
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
targets = []
for out in self._output_descs:
targets.append(program.global_block().var(out.name()))
......
......@@ -5021,6 +5021,9 @@ class Program(object):
op = block.op(j)
if op.has_attr('is_test'):
op._set_attr('is_test', True)
if op.type() == "batch_norm":
# Remove the ReserveSpace output of batch_norm if it exists.
op.remove_output("ReserveSpace")
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册