Unverified commit ccb47076, authored by zyfncg and committed by GitHub

Generate static graph code for some ops by yaml (part2) (#47752)

* generate static graph code of some op

* polish code

* fix bug

* update default value
Parent commit: fd80288e
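Context: this part of the migration replaces handwritten C++ operator registration with code generated from YAML. As a representative example, the ops.yaml hunk later in this diff adds entries such as:

- op : eigvals
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : EigvalsInferMeta
  kernel :
    func : eigvals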
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class AsComplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class AsComplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of view_as_complex op.");
AddOutput("Out", "(Tensor), The output tensor of view_as_complex op.");
AddComment(R"DOC(
As_complex Operator.
This operator is used to return a complex tensor represented
by an old-fashioned real tensor. The size of the last dimension of
the input tensor should be 2, which corresponds to 'real' and
'complex', respectively.
)DOC");
}
};
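// Note: as_complex and as_real are pure reinterpreting views, so each op's
// gradient is simply the other op applied to the output gradient, as the two
// grad makers below encode via SetType("as_real") / SetType("as_complex").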
template <typename T>
class AsComplexGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("as_real");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput("Out", this->InputGrad("X"));
}
};
class AsRealOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class AsRealOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of as_real op.");
AddOutput("Out", "(Tensor), The output tensor of as_real op.");
AddComment(R"DOC(
AsReal Operator.
This operator returns an old-fashioned real tensor from a
complex tensor. The size of the last dimension of the output tensor is 2,
which corresponds to the real and imaginary parts, respectively.
)DOC");
}
};
template <typename T>
class AsRealGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("as_complex");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput("Out", this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(as_real,
AsRealInferShapeFunctor,
PD_INFER_META(phi::AsRealInferMeta));
REGISTER_OPERATOR(as_real,
ops::AsRealOp,
ops::AsRealOpMaker,
AsRealInferShapeFunctor,
ops::AsRealGradMaker<paddle::framework::OpDesc>,
ops::AsRealGradMaker<paddle::imperative::OpBase>);
DECLARE_INFER_SHAPE_FUNCTOR(as_complex,
AsComplexInferShapeFunctor,
PD_INFER_META(phi::AsComplexInferMeta));
REGISTER_OPERATOR(as_complex,
ops::AsComplexOp,
ops::AsComplexOpMaker,
AsComplexInferShapeFunctor,
ops::AsComplexGradMaker<paddle::framework::OpDesc>,
ops::AsComplexGradMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput(
"X",
string::Sprintf("the left hand operand of %s operator", comment.type));
AddInput(
"Y",
string::Sprintf("the right hand operand of %s operator", comment.type));
AddOutput(
"Out",
string::Sprintf("tensor with a bool element. If all "
"element %s, the Out tensor is [True], else [False]",
comment.equation));
AddComment(string::Sprintf(R"DOC(
It operates element-wise on X and Y, and returns the Out. X, Y is a
N-dim tensor, which could be any type. If all element $%s$, the Out tensor
is [True], else [False]
)DOC",
comment.equation));
}
};
template <typename OpComment>
class CompareReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
} // namespace operators
} // namespace paddle
#define REGISTER_COMPARE_REDUCE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
DECLARE_INFER_SHAPE_FUNCTOR(op_type, \
op_type##_InferShapeFunctor, \
PD_INFER_META(phi::CompareAllInferMeta)); \
REGISTER_OPERATOR( \
op_type, \
::paddle::operators::CompareReduceOp<_##op_type##Comment>, \
::paddle::operators::CompareReduceOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, \
op_type##_InferShapeFunctor);
REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y");
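// For reference, the single use above expands roughly to:
//
//   struct _equal_allComment {
//     static char type[];
//     static char equation[];
//   };
//   char _equal_allComment::type[]{"equal_all"};
//   char _equal_allComment::equation[]{"X == Y"};
//   DECLARE_INFER_SHAPE_FUNCTOR(equal_all,
//                               equal_all_InferShapeFunctor,
//                               PD_INFER_META(phi::CompareAllInferMeta));
//   REGISTER_OPERATOR(
//       equal_all,
//       ::paddle::operators::CompareReduceOp<_equal_allComment>,
//       ::paddle::operators::CompareReduceOpProtoMaker<_equal_allComment>,
//       ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
//       ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
//       equal_all_InferShapeFunctor);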
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class DiagEmbedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor. Must be at least 1-dimensional.");
AddOutput("Out", "A matrix whose certain 2D planes is diagonal matrix.");
AddAttr<int>(
"offset",
R"DOC((int, default 0), which diagonal to consider. Default: 0 (main diagonal).
)DOC")
.SetDefault(0);
AddAttr<int>(
"dim1",
R"DOC((int, default -2), first dimension with respect to which to take diagonal. Default: -2.
)DOC")
.SetDefault(-2);
AddAttr<int>(
"dim2",
R"DOC((int, default -1), second dimension with respect to which to take diagonal. Default: -1.
)DOC")
.SetDefault(-1);
AddComment(R"DOC(Creates a tensor whose diagonals of certain 2D planes
(specified by dim1 and dim2) are filled by input.
To facilitate creating batched diagonal matrices,
the 2D planes formed by the last two dimensions of the returned tensor
are chosen by default.
)DOC");
}
};
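// (Worked example, not from the original source: with the default attrs, an
// Input of shape [2, 3] produces an Out of shape [2, 3, 3], where each 3x3
// plane holds one row of Input on its main diagonal.)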
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(diag_embed,
DiagEmbedInferShapeFunctor,
PD_INFER_META(phi::DiagEmbedInferMeta));
REGISTER_OPERATOR(
diag_embed,
ops::DiagEmbedOp,
ops::DiagEmbedOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
DiagEmbedInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/eig_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class EigOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
// The output of eig is always complex-valued even for real-valued inputs
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (dtype != framework::proto::VarType::FP32 &&
dtype != framework::proto::VarType::FP64 &&
dtype != framework::proto::VarType::COMPLEX64 &&
dtype != framework::proto::VarType::COMPLEX128) {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported data type: %s!", dtype));
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
class EigOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor), A complex-valued or real-valued tensor with shape (*, "
"n, n). The accepted datatype is one of float32, float64, complex64 "
"or complex128");
AddOutput("Eigenvalues",
"(Tensor), The output eigenvalues tensor with shape (*, n). The "
"datatype is complex64 or complex128");
AddOutput("Eigenvectors",
"(Tensor), The output eigenvectors tensor with shape (*, n, n). "
"The datatype is complex64 or complex128");
AddComment(R"DOC(
Eig Operator.
This operator performs eigendecomposition of general square matrices.
)DOC");
}
};
class EigGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Eigenvectors")),
ctx.device_context());
}
};
template <typename T>
class EigGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Eigenvalues", this->Output("Eigenvalues"));
op->SetInput("Eigenvectors", this->Output("Eigenvectors"));
op->SetInput(framework::GradVarName("Eigenvalues"),
this->OutputGrad("Eigenvalues"));
op->SetInput(framework::GradVarName("Eigenvectors"),
this->OutputGrad("Eigenvectors"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eig,
EigInferShapeFunctor,
PD_INFER_META(phi::EigInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(eig_grad,
EigGradInferShapeFunctor,
PD_INFER_META(phi::EigGradInferMeta));
REGISTER_OPERATOR(eig,
ops::EigOp,
ops::EigOpMaker,
ops::EigGradOpMaker<paddle::framework::OpDesc>,
ops::EigGradOpMaker<paddle::imperative::OpBase>,
EigInferShapeFunctor);
REGISTER_OPERATOR(eig_grad, ops::EigGradOp, EigGradInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#define EPSILON 1e-6
namespace paddle {
namespace operators {
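// Number of matrices in a batched input: the product of all dimensions except
// the trailing two (e.g. dims [2, 3, 4, 4] give a batch count of 2 * 3 = 6).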
inline int BatchCount(const phi::DenseTensor& matrix) {
int count = 1;
int num_dims = matrix.dims().size();
for (int i = 0; i < num_dims - 2; ++i) {
count *= matrix.dims()[i];
}
return count;
}
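// Element stride between consecutive matrices in the batch: the product of
// the trailing two dimensions (e.g. dims [2, 3, 4, 4] give a stride of 16).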
inline int MatrixStride(const phi::DenseTensor& matrix) {
framework::DDim dims_list = matrix.dims();
int num_dims = dims_list.size();
return dims_list[num_dims - 1] * dims_list[num_dims - 2];
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class EighOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class EighOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), Hermitian or real symmetric matrices."
"Its shape should be [*, N, N] where * is zero or"
"more batch dimensions. The data type is float32 ,"
"float64, complex64, complex128.");
AddOutput("Eigenvalues",
"(Tensor), The eigenvalues in ascending order."
"The data type is float32 or float64.");
AddOutput(
"Eigenvectors",
"(Tensor), The column is the normalized eigenvector "
"corresponding to the eigenvalue. The data type is the same as ``X``.");
AddAttr<std::string>(
"UPLO",
"(string, default 'L'), 'L' represents the lower triangular matrix,"
"'U' represents the upper triangular matrix.")
.SetDefault("L");
AddComment(R"DOC(
Eigh Operator.
Computes the eigenvalues and eigenvectors of a complex Hermitian
(conjugate symmetric) or a real symmetric matrix.
)DOC");
}
};
class EighGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", "EighGrad");
OP_INOUT_CHECK(
ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EighGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
"Input",
"Eigenvalues@GRAD",
"EighGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")),
"Input",
"Eigenvectors@GRAD",
"EighGrad");
auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Eigenvectors")),
ctx.device_context());
}
};
template <typename T>
class EighGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Eigenvalues", this->Output("Eigenvalues"));
op->SetInput("Eigenvectors", this->Output("Eigenvectors"));
op->SetInput(framework::GradVarName("Eigenvalues"),
this->OutputGrad("Eigenvalues"));
op->SetInput(framework::GradVarName("Eigenvectors"),
this->OutputGrad("Eigenvectors"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eigh,
EighInferShapeFunctor,
PD_INFER_META(phi::EighInferMeta));
REGISTER_OPERATOR(eigh,
ops::EighOp,
ops::EighOpMaker,
ops::EighGradOpMaker<paddle::framework::OpDesc>,
ops::EighGradOpMaker<paddle::imperative::OpBase>,
EighInferShapeFunctor);
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), A complex- or real-valued tensor with shape (*, n, n)"
"where * is zero or more batch dimensions");
AddOutput("Out",
"(Tensor) The output tensor with shape (*,n) cointaining the "
"eigenvalues of X.");
AddComment(R"DOC(eigvals operator
Return the eigenvalues of one or more square matrices. The eigenvalues are complex even when the input matrices are real.
)DOC");
}
};
class EigvalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eigvals,
EigvalsInferShapeFunctor,
PD_INFER_META(phi::EigvalsInferMeta));
REGISTER_OPERATOR(eigvals,
ops::EigvalsOp,
ops::EigvalsOpMaker,
EigvalsInferShapeFunctor);
@@ -42,6 +42,18 @@
data_type : out_grad
no_need_buffer : x
- backward_op : as_complex_grad
forward : as_complex (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : as_real(out_grad)
- backward_op : as_real_grad
forward : as_real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : as_complex(out_grad)
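# (Note: `invoke` reuses an existing forward op as the backward implementation;
# since as_complex and as_real are reinterpreting views, each one's gradient is
# just the other op applied to out_grad.)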
- backward_op : asin_grad
forward : asin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
@@ -259,6 +271,27 @@
func : dot_grad
data_type : out_grad
- backward_op : eig_grad
forward : eig (Tensor x) -> Tensor(out_w), Tensor(out_v)
args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
output : Tensor(x_grad)
infer_meta :
func : EigGradInferMeta
kernel :
func : eig_grad
data_type : out_v
- backward_op : eigh_grad
forward : eigh (Tensor x, str UPLO) -> Tensor(out_w), Tensor(out_v)
args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_v]
kernel :
func : eigh_grad
data_type : out_v
- backward_op : elu_double_grad
forward : elu_grad (Tensor x, Tensor out, Tensor grad_out, float alpha)-> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, float alpha)
......
@@ -98,18 +98,6 @@
kernel :
func : amin_grad
- backward_op : as_complex_grad
forward : as_complex (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : as_real(out_grad)
- backward_op : as_real_grad
forward : as_real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : as_complex(out_grad)
- backward_op : assign_grad
forward : assign (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
@@ -469,32 +457,6 @@
kernel :
func : dropout_grad
- backward_op : eig_grad
forward : eig (Tensor x) -> Tensor(out_w), Tensor(out_v)
args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_v]
kernel :
func : eig_grad
data_type : out_v
data_transform:
skip_transform : out_w, out_w_grad
- backward_op : eigh_grad
forward : eigh (Tensor x, str UPLO) -> Tensor(out_w), Tensor(out_v)
args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_v]
kernel :
func : eigh_grad
data_type : out_v
data_transform:
skip_transform : out_w, out_w_grad
- backward_op : eigvalsh_grad
forward : eigvalsh (Tensor x, str uplo, bool is_test) -> Tensor(eigenvalues), Tensor(eigenvectors)
args : (Tensor eigenvectors, Tensor eigenvalues_grad, str uplo, bool is_test)
......
@@ -182,24 +182,6 @@
kernel :
func : arg_min
- op : as_complex
args : (Tensor x)
output : Tensor
infer_meta :
func : AsComplexInferMeta
kernel :
func : as_complex
backward : as_complex_grad
- op : as_real
args : (Tensor x)
output : Tensor
infer_meta :
func : AsRealInferMeta
kernel :
func : as_real
backward : as_real_grad
- op : assign
args : (Tensor x)
output : Tensor
@@ -561,14 +543,6 @@
func : depthwise_conv2d_transpose
backward : depthwise_conv2d_transpose_grad
- op : diag_embed
args : (Tensor input, int offset, int dim1, int dim2)
output : Tensor(out)
infer_meta :
func : DiagEmbedInferMeta
kernel :
func : diag_embed
- op : distribute_fpn_proposals
args : (Tensor fpn_rois, Tensor rois_num, int min_level, int max_level, int refer_level, int refer_scale, bool pixel_offset)
output : Tensor[](multi_fpn_rois){max_level - min_level + 1}, Tensor[](multi_level_rois_num){max_level - min_level + 1}, Tensor(restore_index)
@@ -609,23 +583,6 @@
data_type: DataType::FLOAT32
optional : hypslength, refslength
- op : eigh
args : (Tensor x, str UPLO)
output : Tensor(out_w), Tensor(out_v)
infer_meta :
func : EighInferMeta
kernel :
func : eigh
backward : eigh_grad
- op : eigvals
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : EigvalsInferMeta
kernel :
func : eigvals
- op : eigvalsh
args : (Tensor x, str uplo, bool is_test)
output : Tensor(eigenvalues), Tensor(eigenvectors)
@@ -699,14 +656,6 @@
kernel :
func : equal
- op : equal_all
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : CompareAllInferMeta
kernel :
func : equal_all
- op : expand
args : (Tensor x, IntArray shape)
output : Tensor
@@ -2545,15 +2494,6 @@
kernel:
func: dirichlet
- op: eig
args: (Tensor x)
output: Tensor(out_w), Tensor(out_v)
infer_meta:
func: EigInferMeta
kernel:
func: eig
backward: eig_grad
- op: fold
args: (Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
output: Tensor(out)
......
# All the configuration in this file is only for existing operators,
# which cannot be modified in principle. There is no need to configure
# this file for new operators.
#
# This file is used for two purposes:
# 1. Configure the mapping of parameter names between the operators
# in ops.yaml and the old operators defined in fluid.
# 2. Save the extra parameters in the OpMaker of operators temporarily,
# which will be removed in the future.
# - op : rnn
# backward : rnn_grad
# extra :
@@ -59,6 +70,18 @@
out : Out
indices : Indices
- op : as_complex
inputs :
x : X
outputs :
out : Out
- op : as_real
inputs :
x : X
outputs :
out : Out
- op : asin
inputs :
x : X
@@ -272,6 +295,12 @@
outputs :
out : Out
- op : diag_embed
inputs :
input : Input
outputs :
out : Out
- op : diagonal
inputs :
x : Input
@@ -316,6 +345,26 @@
extra :
attrs : [bool fix_seed = false, int seed = 0]
- op : eig
inputs :
x : X
outputs :
out_w : Eigenvalues
out_v : Eigenvectors
- op : eigh
inputs :
x : X
outputs :
out_w : Eigenvalues
out_v : Eigenvectors
- op : eigvals
inputs :
x : X
outputs :
out : Out
- op : elementwise_pow
backward : elementwise_pow_grad
extra :
@@ -338,6 +387,12 @@
int trainer_id = 0, int slot = 0, 'int64_t[] height_sections = {}', 'str[] epmap = {}',
'str[] table_names = {}']
- op : equal_all
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : erf
inputs :
x : X
......
@@ -34,6 +34,24 @@
func : argsort
backward : argsort_grad
- op : as_complex
args : (Tensor x)
output : Tensor
infer_meta :
func : AsComplexInferMeta
kernel :
func : as_complex
backward : as_complex_grad
- op : as_real
args : (Tensor x)
output : Tensor
infer_meta :
func : AsRealInferMeta
kernel :
func : as_real
backward : as_real_grad
- op : asin
args : (Tensor x)
output : Tensor
@@ -180,6 +198,14 @@
func : diag
backward : diag_grad
- op : diag_embed
args : (Tensor input, int offset = 0, int dim1 = -2, int dim2 = -1)
output : Tensor(out)
infer_meta :
func : DiagEmbedInferMeta
kernel :
func : diag_embed
- op : diagonal
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor
@@ -217,6 +243,32 @@
data_type : x
backward : dot_grad
- op : eig
args: (Tensor x)
output: Tensor(out_w), Tensor(out_v)
infer_meta:
func: EigInferMeta
kernel:
func: eig
backward: eig_grad
- op : eigh
args : (Tensor x, str UPLO = "L")
output : Tensor(out_w), Tensor(out_v)
infer_meta :
func : EighInferMeta
kernel :
func : eigh
backward : eigh_grad
- op : eigvals
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : EigvalsInferMeta
kernel :
func : eigvals
- op : elu
args : (Tensor x, float alpha = 1.0f)
output : Tensor(out)
@@ -228,6 +280,14 @@
inplace : (x -> out)
backward : elu_grad
- op : equal_all
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : CompareAllInferMeta
kernel :
func : equal_all
- op : erf
args : (Tensor x)
output : Tensor
......
@@ -47,4 +47,7 @@ PD_REGISTER_KERNEL(eig_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->InputAt(2).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
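// (Inputs 0 and 2 of the grad kernel are the eigenvalue tensors out_w and
// out_w_grad; declaring them with the real counterpart of the kernel dtype
// presumably replaces the `skip_transform : out_w, out_w_grad` entries removed
// from the legacy YAML above. The same pattern appears in the eigh_grad
// registrations below.)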
@@ -25,4 +25,7 @@ PD_REGISTER_KERNEL(eigh_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->InputAt(2).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
@@ -25,4 +25,7 @@ PD_REGISTER_KERNEL(eigh_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->InputAt(2).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
@@ -40,10 +40,6 @@ KernelSignature NotEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("not_equal", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature EqualAllArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("equal_all", {"X", "Y"}, {}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(less_than, phi::LessThanArgumentMapping);
@@ -52,5 +48,3 @@ PD_REGISTER_ARG_MAPPING_FN(greater_than, phi::GreaterThanArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(greater_equal, phi::GreaterEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(equal, phi::EqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(not_equal, phi::NotEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(equal_all, phi::EqualAllArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EigGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"eig_grad",
{"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
{},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eig_grad, phi::EigGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"eigh_grad",
{"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
{},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eigh_grad, phi::EighGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EigvalsOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigvals", {"X"}, {}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eigvals, phi::EigvalsOpArgumentMapping);