未验证 提交 4af0f140 编写于 作者: G gouzil 提交者: GitHub

[static op generation] tril_triu (#54033)

* [phi] autogen code tril_triu

* [phi][api]fix tril_triu_grad args

* [fluid] clean cmake; [phi] fix infer_meta
上级 adca3654
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
// Forward operator for tril/triu. Only the inherited OperatorWithKernel
// machinery is needed here: shape inference is supplied externally by
// TrilTriuInferShapeFunctor (declared and attached at registration below).
class TrilTriuOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
// Describes the tril_triu op's proto: one input tensor X, one output Out,
// plus the `diagonal` (which diagonal to keep, default 0) and `lower`
// (lower- vs upper-triangular) attributes.
class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Tensor, the input of tril_triu op");
    AddOutput("Out",
              "Tensor, the output tensor, with the same shape and data type as "
              "input(x)");
    AddAttr<int>("diagonal", "int number, the diagonal to consider.")
        .SetDefault(0);
    // Fixed typo in the attribute description: "boolnumber" -> "bool number".
    AddAttr<bool>("lower",
                  "bool number, lower triangular or upper triangular.");
    AddComment(R"DOC(
TrilTriu Operator.
The tril operator returns the lower triangular part of the matrix (2-D tensor)
or batch of matrices $input$. The lower triangular part of the matrix is defined
as the elements on and below the diagonal.
The triu operator returns the upper triangular part of a matrix (2-D tensor)
or batch of matrices $input$. The upper triangular part of the matrix is defined
as the elements on and above the diagonal.
The other elements of the result tensor out are set to 0.
The argument diagonal controls which diagonal to consider, default value is 0.
)DOC");
  }
};
// Backward operator for tril/triu. The gradient w.r.t. X has the same
// dimensions as the incoming Out gradient, so InferShape just forwards
// Out@GRAD's dims to X@GRAD.
class TrilTriuGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Error messages fixed for consistency: use the canonical "@GRAD" suffix
    // everywhere and name the op this check actually belongs to.
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Out")),
        true,
        platform::errors::NotFound(
            "Input(Out@GRAD) of TrilTriuGradOp should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput(framework::GradVarName("X")),
        true,
        platform::errors::NotFound(
            "Output(X@GRAD) of TrilTriuGradOp should not be null"));
    ctx->SetOutputDim(framework::GradVarName("X"),
                      ctx->GetInputDim(framework::GradVarName("Out")));
  }
};
// Builds the grad op description for autodiff: tril_triu_grad consumes the
// gradient of Out and produces the gradient of X, reusing the forward op's
// attributes (diagonal, lower) unchanged.
template <typename T>
class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("tril_triu_grad");
// The grad kernel only needs dOut (plus attrs) -- X itself is not an input.
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Bind the phi InferMeta function as this op's shape-inference functor, so the
// fluid op shares shape logic with the phi kernel instead of duplicating it.
DECLARE_INFER_SHAPE_FUNCTOR(tril_triu,
TrilTriuInferShapeFunctor,
PD_INFER_META(phi::TrilTriuInferMeta));
// Register the forward op with its proto maker, grad-op makers for both the
// static graph (OpDesc) and imperative (OpBase) modes, and the shape functor.
REGISTER_OPERATOR(tril_triu,
ops::TrilTriuOp,
ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>,
TrilTriuInferShapeFunctor);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
......@@ -288,7 +288,6 @@ register_unity_group(
transpose_op.cc
mkldnn/transpose_mkldnn_op.cc
tree_conv_op.cc
tril_triu_op.cc
unbind_op.cc
unfold_op.cc)
register_unity_group(
......
......@@ -2390,6 +2390,13 @@
outputs :
out : Out
- op : tril_triu
backward : tril_triu_grad
inputs :
{x : X}
outputs :
{out : Out}
- op : trilinear_interp (trilinear_interp_v2)
backward : trilinear_interp_grad (trilinear_interp_v2_grad)
inputs :
......
......@@ -65,6 +65,16 @@
func : softmax_grad
composite : softmax_grad(out, out_grad, axis, x_grad)
- backward_op : tril_triu_grad
forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out)
args : (Tensor out_grad, int diagonal, bool lower)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : tril_triu_grad
- backward_op: unpool_grad
forward: unpool (Tensor x, Tensor indices, int[] ksize, str unpooling_type, int[] strides = {1,1}, int[] paddings ={0,0} ,IntArray output_size = {0,0}, str data_format="NCHW") -> Tensor(out)
args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] paddings, IntArray output_size, str data_format)
......
......@@ -356,6 +356,15 @@
param : [rows, cols, offset, dtype]
data_type : dtype
- op : tril_triu
args : (Tensor x, int diagonal = 0, bool lower = false)
output : Tensor(out)
infer_meta :
func : TrilTriuInferMeta
kernel :
func : tril_triu
backward : tril_triu_grad
- op : triu_indices
args : (int row = 0, int col = 0, int offset = 0, DataType dtype = DataType::INT64)
output : Tensor(out)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the legacy fluid "tril_triu" op onto the phi kernel of the same name:
// input X, attributes {diagonal, lower}, output Out.
KernelSignature TrilTriuOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  KernelSignature forward_sig(
      "tril_triu", {"X"}, {"diagonal", "lower"}, {"Out"});
  return forward_sig;
}
// Maps the legacy fluid "tril_triu_grad" op onto the phi grad kernel:
// input Out@GRAD, attributes {diagonal, lower}, output X@GRAD.
KernelSignature TrilTriuGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  KernelSignature backward_sig(
      "tril_triu_grad", {"Out@GRAD"}, {"diagonal", "lower"}, {"X@GRAD"});
  return backward_sig;
}
} // namespace phi
// Register the mappers so the framework can translate the fluid op arguments
// into phi kernel signatures at runtime.
PD_REGISTER_ARG_MAPPING_FN(tril_triu, phi::TrilTriuOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tril_triu_grad, phi::TrilTriuGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册