Unverified · Commit fafc7be2 · authored by zyfncg · committed by GitHub

Clip intermediate output of op when save inference model (#48026)

* clip extra and intermediate output of op

* fix bug

* fix bug

* polish code

* polish log
Parent: ef51bbfd
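This commit makes the inference-save path drop op outputs that are declared `intermediate` in the op proto (e.g. `batch_norm`'s `SavedMean`/`SavedVariance`, `reshape2`'s `XShape`) and teaches the relevant ops and the load path to tolerate their absence. A minimal sketch of the save path this affects; the `Net` layer and the `./bn_net` path are illustrative, not part of the commit:

```python
import paddle

class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.bn = paddle.nn.BatchNorm1D(4)

    def forward(self, x):
        return self.bn(x)

net = Net()
net.eval()
# Saving for inference goes through the clipping changed below: intermediate
# outputs such as batch_norm's SavedMean/SavedVariance should no longer
# appear in the saved program.
paddle.jit.save(
    net,
    './bn_net',
    input_spec=[paddle.static.InputSpec([None, 4], 'float32')],
)
```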
@@ -511,7 +511,7 @@ void OperatorBase::CheckAllInputOutputSet() const {
   }
   for (auto& out : info_->Proto().outputs()) {
-    if (!out.dispensable() && !out.extra()) {
+    if (!out.dispensable() && !out.extra() && !out.intermediate()) {
       PADDLE_ENFORCE_NE(
           outputs_.find(out.name()),
           outputs_.end(),
...
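With this hunk, `CheckAllInputOutputSet` also skips outputs flagged `intermediate`, so an op instance whose intermediate outputs were clipped still validates. The flags live in the op proto and can be inspected from Python; a small probe (attribute names follow `OpProto.Var` in `framework.proto`):

```python
from paddle.fluid.framework import OpProtoHolder

proto = OpProtoHolder.instance().get_op_proto('batch_norm')
for out in proto.outputs:
    # SavedMean / SavedVariance are declared intermediate in the proto,
    # which is exactly what the relaxed check above keys on.
    print(out.name, 'intermediate =', out.intermediate)
```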
@@ -158,12 +158,14 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
                       bias_dim[0]));
   }
   ctx->SetOutputDim("Y", x_dims);
+  ctx->ShareLoD("X", "Y");
   VLOG(4) << x_dims;
   ctx->SetOutputDim("MeanOut", {C});
   ctx->SetOutputDim("VarianceOut", {C});
-  ctx->SetOutputDim("SavedMean", {C});
-  ctx->SetOutputDim("SavedVariance", {C});
-  ctx->ShareLoD("X", "Y");
+  if (!test_mode) {
+    ctx->SetOutputDim("SavedMean", {C});
+    ctx->SetOutputDim("SavedVariance", {C});
+  }
   if (ctx->HasOutput("ReserveSpace")) {
     ctx->SetOutputDim("ReserveSpace", {-1});
   }
...
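`SavedMean`/`SavedVariance` are now given shapes only when `!test_mode`, matching the fact that inference never reads them, and `ShareLoD("X", "Y")` moves next to the `Y` output it describes. A quick static-graph probe of which outputs a training-time `batch_norm` op still carries (illustrative):

```python
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', [8, 4], 'float32')
    y = paddle.static.nn.batch_norm(x)

for op in main.global_block().ops:
    if op.type == 'batch_norm':
        # In a training program the op still lists SavedMean/SavedVariance;
        # clipping only happens on the inference-save path.
        print(op.output_names)
```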
@@ -518,10 +518,7 @@ class Reshape2Op : public ReshapeOp {
                     const framework::AttributeMap &attrs)
       : ReshapeOp(type, inputs, outputs, attrs) {}
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Output(XShape) of ReshapeOp should not be null."));
+    if (ctx->HasOutput("XShape")) {
     const auto &x_dims = ctx->GetInputDim("X");
     std::vector<int64_t> xshape_dims(x_dims.size() + 1);
     xshape_dims[0] = 0;
@@ -530,7 +527,7 @@ class Reshape2Op : public ReshapeOp {
     }
     ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
     ctx->ShareLoD("X", /*->*/ "XShape");
-
+    }
     ReshapeOp::InferShape(ctx);
   }
 };
...
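`XShape` exists only so `reshape2`'s grad op can recover the input shape, so the hard `PADDLE_ENFORCE_EQ` becomes a plain `if`: a clipped inference op without `XShape` now passes `InferShape`. To see the output in an unclipped training program (illustrative):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [2, 6], 'float32')
    y = paddle.reshape(x, [3, 4])

for op in main.global_block().ops:
    # reshape2 lists both Out and the intermediate XShape here.
    print(op.type, op.output_names)
```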
@@ -25,7 +25,7 @@ from paddle.fluid.dygraph import layers
 from paddle.fluid.layers import nn
 from paddle.fluid.layers.utils import _hash_with_id
 from paddle.fluid.dygraph.base import switch_to_static_graph
-from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.framework import _non_static_mode, OpProtoHolder
 from paddle.fluid.executor import (
     _is_enable_standalone_executor,
     _is_dy2st_enable_standalone_executor,
@@ -563,6 +563,35 @@ class _ProgramHolder:
                         stop_gradient=True,
                     )
                     op.desc.set_output("ReserveSpace", [reserve_space.name])
+                    continue
+
+                proto = OpProtoHolder.instance().get_op_proto(op.type)
+                has_create_intermediate_out = False
+                for output_proto in proto.outputs:
+                    if output_proto.intermediate:
+                        intermediate_name = output_proto.name
+                        if intermediate_name not in op.output_names:
+                            has_create_intermediate_out = True
+                            intermediate_var = block.create_var(
+                                name=unique_name.generate_with_ignorable_key(
+                                    ".".join(
+                                        [
+                                            op.type + '_' + intermediate_name,
+                                            'tmp',
+                                        ]
+                                    )
+                                ),
+                                type=core.VarDesc.VarType.LOD_TENSOR,
+                                persistable=False,
+                                stop_gradient=True,
+                            )
+                            op.desc.set_output(
+                                intermediate_name, [intermediate_var.name]
+                            )
+
+                if has_create_intermediate_out:
+                    op.desc.infer_var_type(block.desc)
+                    op.desc.infer_shape(block.desc)
         return program

     @switch_to_static_graph
...
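The load path mirrors the save-time clipping: for every op output that the proto marks `intermediate` but the loaded program no longer names, `_ProgramHolder` creates a throwaway variable, wires it up with `op.desc.set_output`, and re-runs var-type and shape inference on the patched op. From the user's side loading stays transparent; reusing the illustrative `./bn_net` path from the earlier sketch:

```python
import paddle

# A model whose intermediate outputs were clipped at save time loads and
# runs as before; the patching above restores the outputs it needs.
layer = paddle.jit.load('./bn_net')
out = layer(paddle.randn([2, 4]))
print(out.shape)
```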
@@ -6175,8 +6175,8 @@ class Program:
                 if not find:
                     remove_output_list.append(name)
             # The extra output of op will be removed in the future
-            # for name in remove_output_list:
-            #     op.remove_output(name)
+            for name in remove_output_list:
+                op.remove_output(name)
             op_quant_name = (
                 core.op_proto_and_checker_maker.kOpWithQuantAttrName()
...
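Finally, the previously commented-out removal is switched on, so the names collected in `remove_output_list` are actually stripped from the op desc. This runs inside `Program._remove_training_info(clip_extra=True)`, a private helper on the inference-save path; a hedged sketch of inspecting its result (internal API, may change between releases):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [2, 6], 'float32')
    y = paddle.reshape(x, [3, 4])

pruned = main._remove_training_info(clip_extra=True)
block = pruned.desc.block(0)
for i in range(block.op_size()):
    op_desc = block.op(i)
    # After this commit, reshape2 here should no longer list XShape.
    print(op_desc.type(), op_desc.output_names())
```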