diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index 6aaa251ead9f9362c6d18f4500051f0903854315..530b3560bb878a3b2ff5c1655e4d5de012df9f02 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -283,123 +283,6 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
   }
 };
 
-class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Out"), "Output", "Out", "FlattenContiguousRange");
-    const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
-    const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");
-
-    // Construct MetaTensor for InferMeta Func
-    using CompatMetaTensor = framework::CompatMetaTensor;
-    CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
-    CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
-    std::unique_ptr<CompatMetaTensor> xshape(nullptr);
-    if (ctx->HasOutput("XShape")) {
-      xshape = std::move(std::unique_ptr<CompatMetaTensor>(new CompatMetaTensor(
-          ctx->GetOutputVarPtrs("XShape")[0], ctx->IsRuntime())));
-    }
-    phi::FlattenWithXShapeInferMeta(
-        x, start_axis, stop_axis, &out, xshape.get());
-  }
-};
-
-class FlattenContiguousRangeOpMaker : public FlattenOpMaker {
- public:
-  void Make() override {
-    AddInput("X", "(Tensor) A tensor of rank >= axis.");
-    AddOutput("Out",
-              "A 2D tensor is reshaped input tensor. The input dimensions"
-              "up to axis are flattened to the outer dimension of the output"
-              "and the remaining input dimensions are flattened into the inner"
-              "dimension of the output.");
-    AddAttr<int>("start_axis",
-                 "(int)"
-                 "Indicate the input start dimension (exclusive) to flatten")
-        .SetDefault(1);
-    AddAttr<int>("stop_axis",
-                 "(int)"
-                 "Indicate the input stop dimension (exclusive) to flatten")
-        .SetDefault(1);
-    AddComment(R"DOC(
-Flatten Operator
-
-Flattens the input tensor into a new matrix according to start_axis and stop_axis.
-
-Examples:
-Case 1:
-  Given
-    X.shape = (3, 100, 100, 4)
-  and
-    start_axis = 2, stop_axis = -1
-  We get:
-    Out.shape = (3, 100, 400)
-
-Case 2:
-  Given
-    X.shape = (3, 100, 100, 4)
-  and
-    start_axis = 0, stop_axis = -1
-  We get:
-    Out.shape = (3 * 100 * 100 * 4)
-)DOC");
-    AddOutput("XShape",
-              "XShape is just used to store the shape and lod of X, which will "
-              "be used in FlattenGradOp.")
-        .AsIntermediate()
-        .AsExtra();
-  }
-};
-
-template <typename T>
-class FlattenContiguousRangeGradOpMaker
-    : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
-  void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("flatten_contiguous_range_grad");
-    grad_op->SetInput("XShape", this->Output("XShape"));
-    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    grad_op->SetAttrMap(this->Attrs());
-  }
-};
-
-class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *context) const override {
-    OP_INOUT_CHECK(context->HasInput("XShape"),
-                   "Input",
-                   "XShape",
-                   "FlattenContiguousRangeGrad");
-    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   framework::GradVarName("Out"),
-                   "FlattenContiguousRangeGrad");
-    // Construct MetaTensor for InferMeta Func
-    using CompatMetaTensor = framework::CompatMetaTensor;
-    CompatMetaTensor xshape(context->GetInputVarPtrs("XShape")[0],
-                            context->IsRuntime());
-    CompatMetaTensor dx(
-        context->GetOutputVarPtrs(framework::GradVarName("X"))[0],
-        context->IsRuntime());
-    phi::KernelWithXShapeInferMeta(xshape, &dx);
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
-                              ctx, framework::GradVarName("Out")),
-                          ctx.GetPlace());
-  }
-};
 
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
                            {framework::GradVarName("Out"),
@@ -431,17 +314,6 @@ REGISTER_OPERATOR(flatten2_grad,
                   ops::Flatten2GradOp,
                   ops::FlattenGradInplaceInferer);
 
-REGISTER_OPERATOR(
-    flatten_contiguous_range,
-    ops::FlattenContiguousRangeOp,
-    ops::FlattenContiguousRangeOpMaker,
-    ops::FlattenContiguousRangeGradOpMaker<paddle::framework::OpDesc>,
-    ops::FlattenContiguousRangeGradOpMaker<paddle::imperative::OpBase>,
-    ops::FlattenOpInplaceInferer);
-REGISTER_OPERATOR(flatten_contiguous_range_grad,
-                  ops::FlattenContiguousRangeGradOp,
-                  ops::FlattenGradInplaceInferer);
-
 REGISTER_OP_CPU_KERNEL(flatten,
                        ops::FlattenKernel<phi::CPUContext, float>,
                        ops::FlattenKernel<phi::CPUContext, double>,
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index f4608f008535cacf0fc4f82cb8d836d424383c06..7bf3b5cd2fcd89a6ba13b4c6990074765335c33d 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -627,6 +627,18 @@
     func : flash_attn_unpadded_grad
     data_type: q
 
+- backward_op : flatten_grad
+  forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
+  args : (Tensor xshape, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : KernelWithXShapeInferMeta
+    param : [xshape]
+  kernel :
+    func : flatten_grad
+    data_type : out_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : flip_grad
   forward : flip (Tensor x, int[] axis) -> Tensor(out)
   args : (Tensor out_grad, int[] axis)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 181b819cde96ee2f1b1005abfcc466a236e522e1..4ba99b1b8131207b220095d6fdd438a0c697005e 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -397,20 +397,6 @@
     func : fill_grad
   inplace : (out_grad -> x_grad)
 
-- backward_op : flatten_grad
-  forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
-  args : (Tensor xshape, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : KernelWithXShapeInferMeta
-    param : [xshape]
-  kernel :
-    func : flatten_grad
-    data_type: out_grad
-    backend: out_grad
-    layout: out_grad
-  inplace : (out_grad -> x_grad)
-
 - backward_op : fmax_grad
   forward : fmax(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 89aef2203ccaf2d6b78fed94858f26a23b0e20c9..53ae099e762ead64a57d4fe360e25caecb6716e7 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -508,19 +508,6 @@
   inplace : (x -> out)
   backward: fill_grad
 
-- op : flatten
-  args : (Tensor x, int start_axis, int stop_axis)
-  output : Tensor(out), Tensor(xshape)
-  infer_meta :
-    func : FlattenWithXShapeInferMeta
-  kernel :
-    func : flatten
-    backend : x
-  inplace : (x -> out)
-  view : (x -> out)
-  intermediate : xshape
-  backward : flatten_grad
-
 - op : floor_divide
   args : (Tensor x, Tensor y)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 8e36affbb7596b1878315b377dea7bc57821da57..f807a3d748ba101f5c4cb9da898e3639623e1288 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -806,12 +806,16 @@
     out : Out
 
 - op : flatten (flatten_contiguous_range)
+  backward : flatten_grad (flatten_contiguous_range_grad)
   inputs :
     x : X
   outputs :
     {out : Out, xshape : XShape}
   attrs :
     {start_axis : start_axis, stop_axis : stop_axis}
+  extra :
+    outputs : [xshape]
+  manual_signature : [flatten, flatten_grad]
 
 - op : flip
   inputs :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 8de00fc785ca0a64266cf65ad5923ebc5998512f..3afbf00c049e640e52bb51171dc6a1c410199018 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -660,6 +660,19 @@
   intermediate : softmax_lse, seed_offset
   backward : flash_attn_unpadded_grad
 
+- op : flatten
+  args : (Tensor x, int start_axis = 1, int stop_axis = 1)
+  output : Tensor(out), Tensor(xshape)
+  infer_meta :
+    func : FlattenWithXShapeInferMeta
+  kernel :
+    func : flatten
+    data_type : x
+  inplace : (x -> out)
+  view : (x -> out)
+  intermediate : xshape
+  backward : flatten_grad
+
 - op : flip
   args : (Tensor x, int[] axis)
   output : Tensor (out)
diff --git a/paddle/phi/ops/compat/flatten_sig.cc b/paddle/phi/ops/compat/flatten_sig.cc
index b225dc625240b9f5526dcedd4f7c0d06f7dbcf8f..cd3ccd136de29a5fcc71e07777436f17e45a3038 100644
--- a/paddle/phi/ops/compat/flatten_sig.cc
+++ b/paddle/phi/ops/compat/flatten_sig.cc
@@ -17,6 +17,10 @@ limitations under the License. */
 namespace phi {
 
 KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsForInferShape()) {
+    return KernelSignature(
+        "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
+  }
   if (ctx.HasOutput("XShape")) {
     return KernelSignature(
         "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
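For quick reference, the forward semantics this diff moves from the C++ DOC comment into ops.yaml can be exercised through the public Python API. A minimal sketch, assuming a Paddle build that includes this change; the two cases mirror the examples in the removed AddComment block:

    import paddle

    x = paddle.rand([3, 100, 100, 4])

    # Case 1 from the DOC string: flatten dims 2..-1
    out1 = paddle.flatten(x, start_axis=2, stop_axis=-1)
    print(out1.shape)  # [3, 100, 400]

    # Case 2 from the DOC string: flatten all dims into one
    out2 = paddle.flatten(x, start_axis=0, stop_axis=-1)
    print(out2.shape)  # [120000], i.e. 3 * 100 * 100 * 4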
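The flatten_grad entry added to backward.yaml keeps the contract of the deleted FlattenContiguousRangeGradOp: xshape records the input's original shape, KernelWithXShapeInferMeta infers x_grad's shape from it, and out_grad may be reused in place as x_grad. A hedged end-to-end check of that behavior in dynamic graph mode (again assuming a build with this change):

    import paddle

    x = paddle.rand([3, 100, 100, 4])
    x.stop_gradient = False

    out = paddle.flatten(x, start_axis=2, stop_axis=-1)
    out.sum().backward()

    # The gradient is reshaped back to the shape recorded in xshape.
    print(x.grad.shape)  # [3, 100, 100, 4]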