Unverified commit 93404a61 authored by cyberslack_lee, committed by GitHub

support auto generate for eigvalsh (#52687)

Parent 0e776965
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {

class EigvalshOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
};
class EigvalshOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor), Hermitian or real symmetric matrices. "
             "Its shape should be [*, N, N] where * is zero or "
             "more batch dimensions. The data type is float32, "
             "float64, complex64, complex128.");
    AddOutput("Eigenvalues",
              "(Tensor), The eigenvalues in ascending order. "
              "The data type is float32 or float64.");
    AddOutput(
        "Eigenvectors",
        "(Tensor), Each column is the normalized eigenvector "
        "corresponding to an eigenvalue. The data type is the same as ``X``. "
        "Eigenvectors are required to calculate the gradient in backward.");
    AddAttr<std::string>(
        "UPLO",
        "(string, default 'L'), 'L' represents the lower triangular matrix, "
        "'U' represents the upper triangular matrix.")
        .SetDefault("L");
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, "
                  "false for training.")
        .SetDefault(false);
    AddComment(R"DOC(
Eigvalsh Operator.

Computes the eigenvalues of a complex Hermitian
(conjugate symmetric) or a real symmetric matrix.
)DOC");
  }
};
class EigvalshGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(
        OperatorWithKernel::IndicateVarDataType(ctx, "Eigenvectors"),
        ctx.device_context().GetPlace());
  }
};
template <typename T>
class EigvalshGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Eigenvectors", this->Output("Eigenvectors"));
    op->SetInput(framework::GradVarName("Eigenvalues"),
                 this->OutputGrad("Eigenvalues"));
    op->SetAttrMap(this->Attrs());
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eigvalsh,
                            EigvalshInferShapeFunctor,
                            PD_INFER_META(phi::EigvalshInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(eigvalsh_grad,
                            EigvalshGradInferShapeFunctor,
                            PD_INFER_META(phi::EigvalshGradInferMeta));

REGISTER_OPERATOR(eigvalsh,
                  ops::EigvalshOp,
                  ops::EigvalshOpMaker,
                  ops::EigvalshGradOpMaker<paddle::framework::OpDesc>,
                  ops::EigvalshGradOpMaker<paddle::imperative::OpBase>,
                  EigvalshInferShapeFunctor);
REGISTER_OPERATOR(eigvalsh_grad,
                  ops::EigvalshGradOp,
                  EigvalshGradInferShapeFunctor);
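
For reference, a minimal usage sketch of the Python API this operator backs (assuming the public paddle.linalg.eigvalsh entry point; the example matrix is illustrative and not part of this diff):

import paddle

# A real symmetric 2x2 matrix; eigvalsh returns its eigenvalues in
# ascending order. Eigenvectors are kept internally for the backward pass.
x = paddle.to_tensor([[2.0, 1.0],
                      [1.0, 2.0]])
vals = paddle.linalg.eigvalsh(x, UPLO='L')
print(vals)  # [1., 3.]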
@@ -456,6 +456,16 @@
    func : eigh_grad
    data_type : out_v
- backward_op : eigvalsh_grad
  forward : eigvalsh (Tensor x, str uplo = "L", bool is_test = false) -> Tensor(eigenvalues), Tensor(eigenvectors)
  args : (Tensor eigenvectors, Tensor eigenvalues_grad, str uplo, bool is_test)
  output : Tensor(x_grad)
  infer_meta :
    func : EigvalshGradInferMeta
  kernel :
    func : eigvalsh_grad
    data_type : eigenvectors
- backward_op : elu_double_grad
  forward : elu_grad (Tensor x, Tensor out, Tensor grad_out, float alpha) -> Tensor(grad_x)
  args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, float alpha)
......
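
As a sanity check of the backward signature above (eigenvectors and eigenvalues_grad in, x_grad out), a small sketch of differentiating through the op; since the sum of eigenvalues equals the trace, the gradient with respect to a symmetric input is the identity (illustrative, not part of this diff):

import paddle

x = paddle.to_tensor([[2.0, 1.0],
                      [1.0, 2.0]], stop_gradient=False)
vals = paddle.linalg.eigvalsh(x)
vals.sum().backward()
print(x.grad)  # identity matrix: d(trace)/dx for symmetric x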
@@ -340,18 +340,6 @@
  kernel :
    func : dropout_grad
- backward_op : eigvalsh_grad
  forward : eigvalsh (Tensor x, str uplo, bool is_test) -> Tensor(eigenvalues), Tensor(eigenvectors)
  args : (Tensor eigenvectors, Tensor eigenvalues_grad, str uplo, bool is_test)
  output : Tensor(x_grad)
  infer_meta :
    func : EigvalshGradInferMeta
  kernel :
    func : eigvalsh_grad
    data_type : eigenvectors
  data_transform :
    skip_transform : eigenvalues_grad
- backward_op : einsum_grad
  forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
  args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
......
@@ -424,15 +424,6 @@
    data_type : DataType::FLOAT32
  optional : hypslength, refslength
- op : eigvalsh
  args : (Tensor x, str uplo, bool is_test)
  output : Tensor(eigenvalues), Tensor(eigenvectors)
  infer_meta :
    func : EigvalshInferMeta
  kernel :
    func : eigvalsh
  backward : eigvalsh_grad
- op : einsum
  args : (Tensor[] x, str equation)
  output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
......
@@ -614,6 +614,15 @@
  outputs :
    out : Out
- op : eigvalsh
  backward : eigvalsh_grad
  inputs :
    {x : X}
  outputs :
    {eigenvalues : Eigenvalues, eigenvectors : Eigenvectors}
  attrs :
    uplo : UPLO
- op : elementwise_pow
  backward : elementwise_pow_grad
  extra :
......
@@ -486,6 +486,16 @@
  kernel :
    func : eigvals
- op : eigvalsh
  args : (Tensor x, str uplo = "L", bool is_test = false)
  output : Tensor(eigenvalues), Tensor(eigenvectors)
  infer_meta :
    func : EigvalshInferMeta
  kernel :
    func : eigvalsh
    data_type : x
  backward : eigvalsh_grad
- op : elu
  args : (Tensor x, float alpha = 1.0f)
  output : Tensor(out)
......
@@ -26,4 +26,6 @@ PD_REGISTER_KERNEL(eigvalsh_grad,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
  kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
@@ -26,4 +26,6 @@ PD_REGISTER_KERNEL(eigvalsh_grad,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
  kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
@@ -26,4 +26,6 @@ PD_REGISTER_KERNEL(eigvalsh, // cuda_only
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
  kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
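
The added SetDataType line declares the dtype of input 1 separately from the kernel dtype: in the grad kernel, input 1 is the incoming eigenvalues_grad, and since eigenvalues of a Hermitian matrix are real, it carries the real counterpart of the registered dtype (float for complex<float>, double for complex<double>). A small illustration from the Python side (illustrative, not part of this diff):

import paddle

# Hermitian complex64 input: the eigenvalues come back as float32,
# so gradients flowing into them are float32 as well.
x = paddle.to_tensor([[2.0 + 0.0j, 1.0j],
                      [-1.0j, 2.0 + 0.0j]], dtype='complex64')
vals = paddle.linalg.eigvalsh(x)
print(vals.dtype)  # paddle.float32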
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EigvalshOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "eigvalsh", {"X"}, {"UPLO", "is_test"}, {"Eigenvalues", "Eigenvectors"});
}

KernelSignature EigvalshGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("eigvalsh_grad",
                         {"Eigenvectors", "Eigenvalues@GRAD"},
                         {"UPLO", "is_test"},
                         {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eigvalsh, phi::EigvalshOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(eigvalsh_grad, phi::EigvalshGradOpArgumentMapping);