Unverified commit 410e25fb authored by RedContritio, committed by GitHub

support auto generate for flatten (flatten_contiguous_range) (#52512)

* support auto generate for flatten (flatten_contiguous_range)

* add data_type for flatten_grad
Parent: da0c7e14
@@ -283,123 +283,6 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
  }
};
class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
    OP_INOUT_CHECK(
        ctx->HasOutput("Out"), "Output", "Out", "FlattenContiguousRange");
    const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
    const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");

    // Construct MetaTensor for the InferMeta func
    using CompatMetaTensor = framework::CompatMetaTensor;
    CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
    CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
    std::unique_ptr<CompatMetaTensor> xshape(nullptr);
    if (ctx->HasOutput("XShape")) {
      xshape.reset(new CompatMetaTensor(ctx->GetOutputVarPtrs("XShape")[0],
                                        ctx->IsRuntime()));
    }
    phi::FlattenWithXShapeInferMeta(
        x, start_axis, stop_axis, &out, xshape.get());
  }
};
class FlattenContiguousRangeOpMaker : public FlattenOpMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) A tensor of rank >= axis.");
    AddOutput("Out",
              "A 2D tensor reshaped from the input tensor. The input "
              "dimensions up to axis are flattened into the outer dimension "
              "of the output, and the remaining input dimensions are "
              "flattened into the inner dimension of the output.");
    AddAttr<int>("start_axis",
                 "(int) "
                 "Indicate the input start dimension (inclusive) to flatten.")
        .SetDefault(1);
    AddAttr<int>("stop_axis",
                 "(int) "
                 "Indicate the input stop dimension (inclusive) to flatten.")
        .SetDefault(1);
    AddComment(R"DOC(
Flatten Operator

Flattens the input tensor into a new tensor according to start_axis and stop_axis.

Examples:

Case 1:
  Given
    X.shape = (3, 100, 100, 4)
  and
    start_axis = 2, stop_axis = -1
  We get:
    Out.shape = (3, 100, 400)

Case 2:
  Given
    X.shape = (3, 100, 100, 4)
  and
    start_axis = 0, stop_axis = -1
  We get:
    Out.shape = (3 * 100 * 100 * 4,)
)DOC");
    AddOutput("XShape",
              "XShape is just used to store the shape and lod of X, which "
              "will be used in FlattenGradOp.")
        .AsIntermediate()
        .AsExtra();
  }
};
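The DOC cases above follow a single rule: dimensions start_axis through stop_axis, inclusive, with negative axes counted from the end, collapse into one dimension. A minimal standalone sketch of that shape computation (flatten_shape is our illustrative helper, not a Paddle API):

#include <cstdint>
#include <vector>

// Merge dims [start_axis, stop_axis] (inclusive) into one; negative axes
// wrap around the rank, matching the DOC examples above.
std::vector<int64_t> flatten_shape(const std::vector<int64_t> &in,
                                   int start_axis, int stop_axis) {
  const int rank = static_cast<int>(in.size());
  if (start_axis < 0) start_axis += rank;
  if (stop_axis < 0) stop_axis += rank;
  std::vector<int64_t> out(in.begin(), in.begin() + start_axis);
  int64_t merged = 1;
  for (int i = start_axis; i <= stop_axis; ++i) merged *= in[i];
  out.push_back(merged);
  out.insert(out.end(), in.begin() + stop_axis + 1, in.end());
  return out;
}

// flatten_shape({3, 100, 100, 4}, 2, -1) -> {3, 100, 400}  (Case 1)
// flatten_shape({3, 100, 100, 4}, 0, -1) -> {120000}       (Case 2)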
template <typename T>
class FlattenContiguousRangeGradOpMaker
    : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("flatten_contiguous_range_grad");
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};
class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("XShape"),
                   "Input",
                   "XShape",
                   "FlattenContiguousRangeGrad");
    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
                   "FlattenContiguousRangeGrad");

    // Construct MetaTensor for the InferMeta func
    using CompatMetaTensor = framework::CompatMetaTensor;
    CompatMetaTensor xshape(context->GetInputVarPtrs("XShape")[0],
                            context->IsRuntime());
    CompatMetaTensor dx(
        context->GetOutputVarPtrs(framework::GradVarName("X"))[0],
        context->IsRuntime());
    phi::KernelWithXShapeInferMeta(xshape, &dx);
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("Out")),
                          ctx.GetPlace());
  }
};
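Note that the grad op reads only XShape and Out@GRAD; the forward X itself is never needed. KernelWithXShapeInferMeta can recover dX's shape from XShape alone because, by the Paddle convention we assume here, XShape stores a dummy leading 0 followed by X's real dimensions. A hedged sketch of that recovery (shape_from_xshape is our illustrative name):

#include <cstdint>
#include <vector>

// Assumes xshape_dims == {0, x_dim0, x_dim1, ...}; dropping the leading
// sentinel yields the shape dX must take.
std::vector<int64_t> shape_from_xshape(const std::vector<int64_t> &xshape_dims) {
  return std::vector<int64_t>(xshape_dims.begin() + 1, xshape_dims.end());
}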
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
                           {framework::GradVarName("Out"),
@@ -431,17 +314,6 @@ REGISTER_OPERATOR(flatten2_grad,
                  ops::Flatten2GradOp,
                  ops::FlattenGradInplaceInferer);
REGISTER_OPERATOR(
    flatten_contiguous_range,
    ops::FlattenContiguousRangeOp,
    ops::FlattenContiguousRangeOpMaker,
    ops::FlattenContiguousRangeGradOpMaker<paddle::framework::OpDesc>,
    ops::FlattenContiguousRangeGradOpMaker<paddle::imperative::OpBase>,
    ops::FlattenOpInplaceInferer);
REGISTER_OPERATOR(flatten_contiguous_range_grad,
                  ops::FlattenContiguousRangeGradOp,
                  ops::FlattenGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(flatten,
                       ops::FlattenKernel<phi::CPUContext, float>,
                       ops::FlattenKernel<phi::CPUContext, double>,
...
@@ -627,6 +627,18 @@
    func : flash_attn_unpadded_grad
    data_type: q
- backward_op : flatten_grad
  forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
  args : (Tensor xshape, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : KernelWithXShapeInferMeta
    param : [xshape]
  kernel :
    func : flatten_grad
    data_type : out_grad
  inplace : (out_grad -> x_grad)
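The data_type : out_grad line (the "add data_type for flatten_grad" follow-up in the commit message) tells the generator to deduce the kernel's dtype from out_grad, standing in for the handwritten override deleted from flatten_op.cc above. Roughly, the generated equivalent is:

// Sketch of what data_type : out_grad amounts to; mirrors the removed
// GetExpectedKernelType of FlattenContiguousRangeGradOp.
phi::KernelKey GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                            ctx, framework::GradVarName("Out")),
                        ctx.GetPlace());
}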
- backward_op : flip_grad
  forward : flip (Tensor x, int[] axis) -> Tensor(out)
  args : (Tensor out_grad, int[] axis)
...
@@ -397,20 +397,6 @@
    func : fill_grad
  inplace : (out_grad -> x_grad)
- backward_op : flatten_grad
  forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
  args : (Tensor xshape, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : KernelWithXShapeInferMeta
    param : [xshape]
  kernel :
    func : flatten_grad
    data_type: out_grad
    backend: out_grad
    layout: out_grad
  inplace : (out_grad -> x_grad)
- backward_op : fmax_grad
  forward : fmax(Tensor x, Tensor y) -> Tensor(out)
  args : (Tensor x, Tensor y, Tensor out_grad)
...
@@ -508,19 +508,6 @@
  inplace : (x -> out)
  backward: fill_grad
- op : flatten
  args : (Tensor x, int start_axis, int stop_axis)
  output : Tensor(out), Tensor(xshape)
  infer_meta :
    func : FlattenWithXShapeInferMeta
  kernel :
    func : flatten
    backend : x
  inplace : (x -> out)
  view : (x -> out)
  intermediate : xshape
  backward : flatten_grad
- op : floor_divide
  args : (Tensor x, Tensor y)
  output : Tensor(out)
...
@@ -806,12 +806,16 @@
    out : Out
- op : flatten (flatten_contiguous_range)
  backward : flatten_grad (flatten_contiguous_range_grad)
  inputs :
    x : X
  outputs :
    {out : Out, xshape : XShape}
  attrs :
    {start_axis : start_axis, stop_axis : stop_axis}
  extra :
    outputs : [xshape]
  manual_signature : [flatten, flatten_grad]
- op : flip
  inputs :
...
@@ -660,6 +660,19 @@
  intermediate : softmax_lse, seed_offset
  backward : flash_attn_unpadded_grad
- op : flatten
  args : (Tensor x, int start_axis = 1, int stop_axis = 1)
  output : Tensor(out), Tensor(xshape)
  infer_meta :
    func : FlattenWithXShapeInferMeta
  kernel :
    func : flatten
    data_type : x
  inplace : (x -> out)
  view : (x -> out)
  intermediate : xshape
  backward : flatten_grad
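Each field of this entry replaces a handwritten piece removed from flatten_op.cc: infer_meta stands in for the InferShape override that called FlattenWithXShapeInferMeta, intermediate : xshape for the .AsIntermediate() mark on the XShape output, data_type : x for dtype deduction from X, and inplace : (x -> out) for the inplace inferer previously passed to REGISTER_OPERATOR by hand:

// Previously declared and registered manually in flatten_op.cc:
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});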
- op : flip
  args : (Tensor x, int[] axis)
  output : Tensor (out)
...
@@ -17,6 +17,10 @@ limitations under the License. */
namespace phi {

KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsForInferShape()) {
    return KernelSignature(
        "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
  }
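A likely reading of this new branch: when the signature is resolved for static-graph shape inference, the fluid flatten_contiguous_range op always carries the XShape output, so the mapping pins the XShape-bearing signature up front instead of relying on the HasOutput probe below.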
if (ctx.HasOutput("XShape")) { if (ctx.HasOutput("XShape")) {
return KernelSignature( return KernelSignature(
"flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"}); "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
......