From 410e25fb9090a31f2ff91a816432317fdda87379 Mon Sep 17 00:00:00 2001
From: RedContritio
Date: Tue, 11 Apr 2023 11:08:39 +0800
Subject: [PATCH] support auto generate for flatten (flatten_contiguous_range)
 (#52512)

* support auto generate for flatten (flatten_contiguous_range)

* add data_type for flatten_grad
---
 paddle/fluid/operators/flatten_op.cc     | 128 -----------------------
 paddle/phi/api/yaml/backward.yaml        |  12 +++
 paddle/phi/api/yaml/legacy_backward.yaml |  14 ---
 paddle/phi/api/yaml/legacy_ops.yaml      |  13 ---
 paddle/phi/api/yaml/op_compat.yaml       |   4 +
 paddle/phi/api/yaml/ops.yaml             |  13 +++
 paddle/phi/ops/compat/flatten_sig.cc     |   4 +
 7 files changed, 33 insertions(+), 155 deletions(-)

diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index 6aaa251ead9..530b3560bb8 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -283,123 +283,6 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
   }
 };
 
-class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Out"), "Output", "Out", "FlattenContiguousRange");
-    const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
-    const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");
-
-    // Construct MetaTensor for InferMeta Func
-    using CompatMetaTensor = framework::CompatMetaTensor;
-    CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
-    CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
-    std::unique_ptr<CompatMetaTensor> xshape(nullptr);
-    if (ctx->HasOutput("XShape")) {
-      xshape = std::move(std::unique_ptr<CompatMetaTensor>(new CompatMetaTensor(
-          ctx->GetOutputVarPtrs("XShape")[0], ctx->IsRuntime())));
-    }
-    phi::FlattenWithXShapeInferMeta(
-        x, start_axis, stop_axis, &out, xshape.get());
-  }
-};
-
-class FlattenContiguousRangeOpMaker : public FlattenOpMaker {
- public:
-  void Make() override {
-    AddInput("X", "(Tensor) A tensor of rank >= axis.");
-    AddOutput("Out",
-              "A 2D tensor is reshaped input tensor. The input dimensions"
-              "up to axis are flattened to the outer dimension of the output"
-              "and the remaining input dimensions are flattened into the inner"
-              "dimension of the output.");
-    AddAttr<int>("start_axis",
-                 "(int)"
-                 "Indicate the input start dimension (exclusive) to flatten")
-        .SetDefault(1);
-    AddAttr<int>("stop_axis",
-                 "(int)"
-                 "Indicate the input stop dimension (exclusive) to flatten")
-        .SetDefault(1);
-    AddComment(R"DOC(
-Flatten Operator
-
-Flattens the input tensor into a new matrix according to start_axis and stop_axis.
-
-Examples:
-Case 1:
-  Given
-    X.shape = (3, 100, 100, 4)
-  and
-    start_axis = 2, stop_axis = -1
-  We get:
-    Out.shape = (3, 100, 400)
-
-Case 2:
-  Given
-    X.shape = (3, 100, 100, 4)
-  and
-    start_axis = 0, stop_axis = -1
-  We get:
-    Out.shape = (3 * 100 * 100 * 4)
-)DOC");
-    AddOutput("XShape",
-              "XShape is just used to store the shape and lod of X, which will "
-              "be used in FlattenGradOp.")
-        .AsIntermediate()
-        .AsExtra();
-  }
-};
-
-template <typename T>
-class FlattenContiguousRangeGradOpMaker
-    : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
-  void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("flatten_contiguous_range_grad");
-    grad_op->SetInput("XShape", this->Output("XShape"));
-    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    grad_op->SetAttrMap(this->Attrs());
-  }
-};
-
-class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *context) const override {
-    OP_INOUT_CHECK(context->HasInput("XShape"),
-                   "Input",
-                   "XShape",
-                   "FlattenContiguousRangeGrad");
-    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   framework::GradVarName("Out"),
-                   "FlattenContiguousRangeGrad");
-    // Construct MetaTensor for InferMeta Func
-    using CompatMetaTensor = framework::CompatMetaTensor;
-    CompatMetaTensor xshape(context->GetInputVarPtrs("XShape")[0],
-                            context->IsRuntime());
-    CompatMetaTensor dx(
-        context->GetOutputVarPtrs(framework::GradVarName("X"))[0],
-        context->IsRuntime());
-    phi::KernelWithXShapeInferMeta(xshape, &dx);
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
-                              ctx, framework::GradVarName("Out")),
-                          ctx.GetPlace());
-  }
-};
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
                            {framework::GradVarName("Out"),
@@ -431,17 +314,6 @@ REGISTER_OPERATOR(flatten2_grad,
                   ops::Flatten2GradOp,
                   ops::FlattenGradInplaceInferer);
 
-REGISTER_OPERATOR(
-    flatten_contiguous_range,
-    ops::FlattenContiguousRangeOp,
-    ops::FlattenContiguousRangeOpMaker,
-    ops::FlattenContiguousRangeGradOpMaker<paddle::framework::OpDesc>,
-    ops::FlattenContiguousRangeGradOpMaker<paddle::imperative::OpBase>,
-    ops::FlattenOpInplaceInferer);
-REGISTER_OPERATOR(flatten_contiguous_range_grad,
-                  ops::FlattenContiguousRangeGradOp,
-                  ops::FlattenGradInplaceInferer);
-
 REGISTER_OP_CPU_KERNEL(flatten,
                        ops::FlattenKernel<phi::CPUContext, float>,
                        ops::FlattenKernel<phi::CPUContext, double>,
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index f4608f00853..7bf3b5cd2fc 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -627,6 +627,18 @@
     func : flash_attn_unpadded_grad
     data_type: q
 
+- backward_op : flatten_grad
+  forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
+  args : (Tensor xshape, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : KernelWithXShapeInferMeta
+    param : [xshape]
+  kernel :
+    func : flatten_grad
+    data_type : out_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : flip_grad
   forward : flip (Tensor x, int[] axis) -> Tensor(out)
   args : (Tensor out_grad, int[] axis)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 181b819cde9..4ba99b1b813 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -397,20 +397,6 @@
     func : fill_grad
   inplace : (out_grad -> x_grad)
 
-- backward_op : flatten_grad
-  forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
-  args : (Tensor xshape, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : KernelWithXShapeInferMeta
-    param : [xshape]
-  kernel :
-    func : flatten_grad
-    data_type: out_grad
-    backend: out_grad
-    layout: out_grad
-  inplace : (out_grad -> x_grad)
-
 - backward_op : fmax_grad
   forward : fmax(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 89aef2203cc..53ae099e762 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -508,19 +508,6 @@
   inplace : (x -> out)
   backward: fill_grad
 
-- op : flatten
-  args : (Tensor x, int start_axis, int stop_axis)
-  output : Tensor(out), Tensor(xshape)
-  infer_meta :
-    func : FlattenWithXShapeInferMeta
-  kernel :
-    func : flatten
-    backend : x
-  inplace : (x -> out)
-  view : (x -> out)
-  intermediate : xshape
-  backward : flatten_grad
-
 - op : floor_divide
   args : (Tensor x, Tensor y)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 8e36affbb75..f807a3d748b 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -806,12 +806,16 @@
     out : Out
 
 - op : flatten (flatten_contiguous_range)
+  backward : flatten_grad (flatten_contiguous_range_grad)
   inputs :
     x : X
   outputs :
     {out : Out, xshape : XShape}
   attrs :
     {start_axis : start_axis, stop_axis : stop_axis}
+  extra :
+    outputs : [xshape]
+  manual_signature : [flatten, flatten_grad]
 
 - op : flip
   inputs :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 8de00fc785c..3afbf00c049 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -660,6 +660,19 @@
   intermediate : softmax_lse, seed_offset
   backward : flash_attn_unpadded_grad
 
+- op : flatten
+  args : (Tensor x, int start_axis = 1, int stop_axis = 1)
+  output : Tensor(out), Tensor(xshape)
+  infer_meta :
+    func : FlattenWithXShapeInferMeta
+  kernel :
+    func : flatten
+    data_type : x
+  inplace : (x -> out)
+  view : (x -> out)
+  intermediate : xshape
+  backward : flatten_grad
+
 - op : flip
   args : (Tensor x, int[] axis)
   output : Tensor (out)
diff --git a/paddle/phi/ops/compat/flatten_sig.cc b/paddle/phi/ops/compat/flatten_sig.cc
index b225dc62524..cd3ccd136de 100644
--- a/paddle/phi/ops/compat/flatten_sig.cc
+++ b/paddle/phi/ops/compat/flatten_sig.cc
@@ -17,6 +17,10 @@ limitations under the License. */
 namespace phi {
 
 KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsForInferShape()) {
+    return KernelSignature(
+        "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
+  }
   if (ctx.HasOutput("XShape")) {
     return KernelSignature(
         "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
--
GitLab
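
Both the hand-written FlattenContiguousRangeOp deleted above and the auto-generated op defined by the new ops.yaml entry derive the output shape through FlattenWithXShapeInferMeta: the dimensions from start_axis through stop_axis (inclusive, with negative axes counted from the end) collapse into a single dimension, and all other dimensions pass through unchanged. The standalone C++ sketch below illustrates that rule against the two cases from the operator DOC; it is illustrative only, and FlattenShape is a hypothetical helper, not part of this patch or of Paddle.

// Illustrative sketch only, not part of the patch: mirrors the flatten
// shape rule from the operator DOC. FlattenShape is a hypothetical helper.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<std::int64_t> FlattenShape(std::vector<std::int64_t> x_dims,
                                       int start_axis,
                                       int stop_axis) {
  const int rank = static_cast<int>(x_dims.size());
  // Negative axes count from the end, e.g. stop_axis = -1 is the last dim.
  if (start_axis < 0) start_axis += rank;
  if (stop_axis < 0) stop_axis += rank;
  assert(0 <= start_axis && start_axis <= stop_axis && stop_axis < rank);

  // Keep dims before start_axis, collapse [start_axis, stop_axis] into a
  // single dim, keep dims after stop_axis.
  std::vector<std::int64_t> out(x_dims.begin(), x_dims.begin() + start_axis);
  std::int64_t flattened = 1;
  for (int i = start_axis; i <= stop_axis; ++i) flattened *= x_dims[i];
  out.push_back(flattened);
  out.insert(out.end(), x_dims.begin() + stop_axis + 1, x_dims.end());
  return out;
}

int main() {
  // Case 1 from the DOC: start_axis = 2, stop_axis = -1 -> (3, 100, 400).
  assert((FlattenShape({3, 100, 100, 4}, 2, -1) ==
          std::vector<std::int64_t>{3, 100, 400}));
  // Case 2 from the DOC: start_axis = 0, stop_axis = -1 -> (3*100*100*4).
  assert((FlattenShape({3, 100, 100, 4}, 0, -1) ==
          std::vector<std::int64_t>{3 * 100 * 100 * 4}));
  std::printf("flatten shape checks passed\n");
  return 0;
}

Compiled with any C++11 compiler, both asserts pass: Case 1 yields (3, 100, 400) and Case 2 yields a single dimension of 3 * 100 * 100 * 4 = 120000, matching the DOC examples. The defaults start_axis = 1, stop_axis = 1 in the new YAML signatures correspond to a no-op collapse of the single dimension at axis 1.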