diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index 778bab9f4dd26829a6fa9b76b4381a41c75a3280..0858a43838b964f049a9df4b431cba6dfbe693f6 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -188,8 +188,7 @@ class Flatten2Op : public framework::OperatorWithKernel {
       // are the same.
       ctx->ShareLoD("X", "Out");
     }
-
-    OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
+    if (!ctx->HasOutput("XShape")) return;
     std::vector<int64_t> xshape_dims(in_dims.size() + 1);
     xshape_dims[0] = 0;
     for (int i = 0; i < in_dims.size(); ++i) {
@@ -207,7 +206,8 @@ class Flatten2OpMaker : public FlattenOpMaker {
     AddOutput("XShape",
               "XShape is just used to store the shape and lod of X, which will "
               "be used in FlattenGradOp.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
   }
 };
 
@@ -281,8 +280,7 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
       // are the same.
       ctx->ShareLoD("X", "Out");
     }
-
-    OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
+    if (!ctx->HasOutput("XShape")) return;
     std::vector<int64_t> xshape_dims(in_dims.size() + 1);
     xshape_dims[0] = 0;
     for (int i = 0; i < in_dims.size(); ++i) {
@@ -361,7 +360,8 @@ Case 2:
     AddOutput("XShape",
               "XShape is just used to store the shape and lod of X, which will "
               "be used in FlattenGradOp.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
   }
 };
 