未验证 提交 18968e7e 编写于 作者: G gouzil 提交者: GitHub

[static op generation] triangular_solve (#53328)

* [static op generation] triangular_solve

* [phi] mv triangular_solve_grad to static_backward

* [phi] fix import

* [phi] mv to ops.yaml、 backward.yaml

* fix forward attr

* [phi] fix triangular_solve_grad args
上级 9ab14865
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class TriangularSolveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
class TriangularSolveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The first input tensor of triangular solve op, which "
"is the triangular coefficient matrix.");
AddInput("Y",
"(Tensor), The second input tensor of triangular solve op, which "
"is multiple right-hand.");
AddOutput("Out", "(Tensor), The solution tensor of triangular solve op.");
AddAttr<bool>("upper",
"whether to solve the upper-triangular or the "
"lower-triangular system of equations")
.SetDefault(true);
AddAttr<bool>("transpose", "whether X should be transposed firstly.")
.SetDefault(false);
AddAttr<bool>("unitriangular", "whether X is unit triangular.")
.SetDefault(false);
AddComment(R"DOC(
Triangular Solve Operator.
This operator is used to computes the solution of equations with a triangular coefficient matrix.
The equation is:
$$Out = X^-1 * Y$$
)DOC");
}
};
class TriangularSolveOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class TriangularSolveGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "triangular_solve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "triangular_solve");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "triangular_solve");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"triangular_solve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("triangular_solve_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput("Out", this->Output("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(triangular_solve,
TriangularSolveInferShapeFunctor,
PD_INFER_META(phi::TriangularSolveInferMeta));
REGISTER_OPERATOR(triangular_solve,
ops::TriangularSolveOp,
ops::TriangularSolveOpMaker,
ops::TriangularSolveOpInferVarType,
ops::TriangularSolveOpGradMaker<paddle::framework::OpDesc>,
ops::TriangularSolveOpGradMaker<paddle::imperative::OpBase>,
TriangularSolveInferShapeFunctor);
REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp);
......@@ -1950,6 +1950,16 @@
data_type : out_grad
no_need_buffer : x
- backward_op : triangular_solve_grad
forward : triangular_solve (Tensor x, Tensor y, bool upper=true, bool transpose=false, bool unitriangular=false) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool transpose, bool unitriangular)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : triangular_solve_grad
- backward_op : trilinear_interp_grad
forward : trilinear_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout="NCHW", int out_d=0, int out_h=0, int out_w=0, float[] scale={}, str interp_method="bilinear", bool align_corners=true, int align_mode=1) -> Tensor(output)
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, Tensor output_grad, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode)
......
......@@ -1034,16 +1034,6 @@
backward : transpose_double_grad
composite: transpose_grad(out_grad, perm, x_grad)
- backward_op : triangular_solve_grad
forward : triangular_solve (Tensor x, Tensor y, bool upper, bool tranpose, bool unitriangular) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool tranpose, bool unitriangular)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : triangular_solve_grad
- backward_op : tril_grad
forward : tril(Tensor x, int diagonal) -> Tensor(out)
args : (Tensor out_grad, int diagonal)
......
......@@ -1140,16 +1140,6 @@
func : transpose
backward : transpose_grad
- op : triangular_solve
args : (Tensor x, Tensor y, bool upper, bool transpose, bool unitriangular)
output : Tensor
infer_meta :
func : TriangularSolveInferMeta
kernel :
func : triangular_solve
data_type : x
backward : triangular_solve_grad
- op : tril
args : (Tensor x, int diagonal)
output : Tensor(out)
......
......@@ -2334,6 +2334,13 @@
outputs : [XShape]
attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", str mkldnn_data_type = "float32"]
- op : triangular_solve
backward : triangular_solve_grad
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : trilinear_interp (trilinear_interp_v2)
backward : trilinear_interp_grad (trilinear_interp_v2_grad)
inputs :
......
......@@ -2062,6 +2062,16 @@
func : trace
backward : trace_grad
- op : triangular_solve
args : (Tensor x, Tensor y, bool upper=true, bool transpose=false, bool unitriangular=false)
output : Tensor
infer_meta :
func : TriangularSolveInferMeta
kernel :
func : triangular_solve
data_type : x
backward : triangular_solve_grad
- op : trilinear_interp
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout="NCHW", int out_d=0, int out_h=0, int out_w=0, float[] scale={}, str interp_method="bilinear", bool align_corners=true, int align_mode=1)
output : Tensor(output)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TriangularSolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("triangular_solve_grad",
{"X", "Y", "Out", "Out@GRAD"},
{"upper", "transpose", "unitriangular"},
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(triangular_solve_grad,
phi::TriangularSolveGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册