未验证 提交 fafc7be2 编写于 作者: Z zyfncg 提交者: GitHub

Clip intermediate output of op when save inference model (#48026)

* clip extra and intermediate output of op

* fix bug

* fix bug

* polich code

* polich log
上级 ef51bbfd
......@@ -511,7 +511,7 @@ void OperatorBase::CheckAllInputOutputSet() const {
}
for (auto& out : info_->Proto().outputs()) {
if (!out.dispensable() && !out.extra()) {
if (!out.dispensable() && !out.extra() && !out.intermediate()) {
PADDLE_ENFORCE_NE(
outputs_.find(out.name()),
outputs_.end(),
......
......@@ -158,12 +158,14 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
bias_dim[0]));
}
ctx->SetOutputDim("Y", x_dims);
ctx->ShareLoD("X", "Y");
VLOG(4) << x_dims;
ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C});
if (!test_mode) {
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
if (ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
......
......@@ -518,10 +518,7 @@ class Reshape2Op : public ReshapeOp {
const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"),
true,
platform::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
if (ctx->HasOutput("XShape")) {
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
......@@ -530,7 +527,7 @@ class Reshape2Op : public ReshapeOp {
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
ReshapeOp::InferShape(ctx);
}
};
......
......@@ -25,7 +25,7 @@ from paddle.fluid.dygraph import layers
from paddle.fluid.layers import nn
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import _non_static_mode, OpProtoHolder
from paddle.fluid.executor import (
_is_enable_standalone_executor,
_is_dy2st_enable_standalone_executor,
......@@ -563,6 +563,35 @@ class _ProgramHolder:
stop_gradient=True,
)
op.desc.set_output("ReserveSpace", [reserve_space.name])
continue
proto = OpProtoHolder.instance().get_op_proto(op.type)
has_create_intermediate_out = False
for output_proto in proto.outputs:
if output_proto.intermediate:
intermediate_name = output_proto.name
if intermediate_name not in op.output_names:
has_create_intermediate_out = True
intermediate_var = block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(
[
op.type + '_' + intermediate_name,
'tmp',
]
)
),
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True,
)
op.desc.set_output(
intermediate_name, [intermediate_var.name]
)
if has_create_intermediate_out:
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
return program
@switch_to_static_graph
......
......@@ -6175,8 +6175,8 @@ class Program:
if not find:
remove_output_list.append(name)
# The extra output of op will be removed in the future
# for name in remove_output_list:
# op.remove_output(name)
for name in remove_output_list:
op.remove_output(name)
op_quant_name = (
core.op_proto_and_checker_maker.kOpWithQuantAttrName()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册