未验证 提交 4f33f44b 编写于 作者: 张春乔 提交者: GitHub

[static op generation] lstsq (#53290)

上级 7f39bcd1
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class LstsqOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
// The output of lstsq is always complex-valued even for real-valued inputs
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (dtype != framework::proto::VarType::FP32 &&
dtype != framework::proto::VarType::FP64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported data type: %s!", dtype));
}
return phi::KernelKey(dtype, ctx.GetPlace());
}
};
class LstsqOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), A real-valued tensor with shape (*, m, n). "
"The accepted datatype is one of float32, float64");
AddInput("Y",
"(Tensor), A real-valued tensor with shape (*, m, k). "
"The accepted datatype is one of float32, float64");
AddAttr<float>(
"rcond",
"(float, default 0.0), A float value used to determine the effective "
"rank of A.")
.SetDefault(0.0f);
AddAttr<std::string>("driver",
"(string, default \"gels\"). "
"name of the LAPACK method to be used.")
.SetDefault("gels");
AddOutput("Solution",
"(Tensor), The output Solution tensor with shape (*, n, k).");
AddOutput("Residuals",
"(Tensor), The output Residuals tensor with shape (*, k).")
.AsDispensable();
AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*).");
AddOutput(
"SingularValues",
"(Tensor), The output SingularValues tensor with shape (*, min(m,n)).");
AddComment(R"DOC(
Lstsq Operator.
This API processes Lstsq functor for general matrices.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(lstsq,
LstsqInferShapeFunctor,
PD_INFER_META(phi::LstsqInferMeta));
REGISTER_OPERATOR(lstsq,
ops::LstsqOp,
ops::LstsqOpMaker,
LstsqInferShapeFunctor);
REGISTER_OP_VERSION(lstsq).AddCheckpoint(
R"ROC(
Upgrade lstsq, add 1 outputs [Residuals].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"Residuals",
"Output tensor of lstsq operator, "
"meaning the squared residuals of the calculated solutions."));
......@@ -669,15 +669,6 @@
func : logsumexp
backward : logsumexp_grad
- op : lstsq
args : (Tensor x, Tensor y, Scalar rcond, str driver)
output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values)
infer_meta :
func : LstsqInferMeta
kernel :
func : lstsq
data_type : x
- op : matmul
args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
output : Tensor
......
......@@ -1415,6 +1415,16 @@
extra :
attrs : [bool use_mkldnn = false, bool is_test = false]
- op : lstsq
inputs :
{x : X, y : Y}
outputs :
{solution : Solution, residuals : Residuals, rank : Rank, singular_values : SingularValues}
scalar :
rcond :
data_type : float
support_tensor : true
- op : lu_unpack
backward : lu_unpack_grad
inputs :
......
......@@ -200,6 +200,13 @@
comment : In order to change output data type
default : 5
- op : lstsq
version :
- checkpoint : Upgrade lstsq, add 1 outputs [Residuals].
action :
- add_output : Residuals
comment : Output tensor of lstsq operator, meaning the squared residuals of the calculated solutions.
- op : matrix_nms
version :
- checkpoint : Upgrade matrix_nms, add a new output [RoisNum].
......
......@@ -1231,6 +1231,16 @@
func : logsigmoid
backward : logsigmoid_grad
- op : lstsq
args : (Tensor x, Tensor y, Scalar rcond=0.0f, str driver="gels")
output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values)
infer_meta :
func : LstsqInferMeta
kernel :
func : lstsq
data_type : x
optional : residuals
- op : lu
args : (Tensor x, bool pivot = true)
output : Tensor(out), Tensor(pivots), Tensor(infos)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LstsqOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lstsq",
{"X", "Y"},
{"rcond", "driver"},
{"Solution", "Residuals", "Rank", "SingularValues"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lstsq, phi::LstsqOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册