Unverified commit 3d88816e, authored by Lin Manhui, committed by GitHub

[PHI] Move lu to phi (#44605)

* Add kernel declarations

* Copy kernel implementation code

* Transfer implementation code

* Register new kernels

* Remove old kernels

* Fix code style

* Fix bugs

* mutable_data->HostAlloc

* Transfer infermeta

* Add yaml and update python api

* Add PADDLE_WITH_HIP check

* Update unittests

* Fix bugs

* Fix bugs

* Optimize directory structure

* Add output checks

* lu_impl.h->lu_kernel_impl.h
Co-authored-by: Bobholamovic <linmanhui@baidu.com>
Parent 88584396
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -39,39 +44,6 @@ class LUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LU");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "LU");
bool pivots = context->Attrs().Get<bool>("pivots");
auto x_dims = context->GetInputDim("X");
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_rank,
2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
context->SetOutputDim("Out", x_dims);
int m = x_dims[x_rank - 1];
int n = x_dims[x_rank - 2];
int min_mn = std::min(m, n);
auto dims_vec = phi::vectorize(x_dims);
OP_INOUT_CHECK(context->HasOutput("Infos"), "Output", "Infos", "LU");
if (x_rank == 2) {
auto Infos_dim = std::vector<int>(1);
context->SetOutputDim("Infos", phi::make_ddim(Infos_dim));
} else {
auto Infos_dim =
std::vector<int>(dims_vec.begin(), dims_vec.begin() + x_rank - 2);
context->SetOutputDim("Infos", phi::make_ddim(Infos_dim));
}
if (pivots) {
OP_INOUT_CHECK(context->HasOutput("Pivots"), "Output", "Pivots", "LU");
auto Pivots_dim =
std::vector<int>(dims_vec.begin(), dims_vec.begin() + x_rank - 1);
Pivots_dim[x_rank - 2] = min_mn;
context->SetOutputDim("Pivots", phi::make_ddim(Pivots_dim));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -99,57 +71,6 @@ class LUOpVarTypeInference : public framework::VarTypeInference {
}
};
template <typename T>
class LUKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto pivots = ctx.Attr<bool>("pivots");
auto *xin = ctx.Input<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *IpivT = ctx.Output<framework::Tensor>("Pivots");
auto *InfoT = ctx.Output<framework::Tensor>("Infos");
PADDLE_ENFORCE_EQ(pivots,
true,
platform::errors::InvalidArgument(
"lu without pivoting is not implemented on the CPU, "
"but got pivots=False"));
math::DeviceIndependenceTensorOperations<phi::CPUContext, T> helper(ctx);
*out = helper.Transpose(*xin);
auto outdims = out->dims();
auto outrank = outdims.size();
int m = static_cast<int>(outdims[outrank - 1]);
int n = static_cast<int>(outdims[outrank - 2]);
int lda = std::max(1, m);
auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1);
ipiv_dims[outrank - 2] = std::min(m, n);
IpivT->Resize(ipiv_dims);
auto ipiv_data = IpivT->mutable_data<int>(ctx.GetPlace());
auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2);
if (info_dims.size() == 0) {
info_dims = phi::make_ddim({1});
}
InfoT->Resize(info_dims);
auto info_data = InfoT->mutable_data<int>(ctx.GetPlace());
auto batchsize = product(info_dims);
batchsize = std::max(static_cast<int>(batchsize), 1);
auto out_data = out->mutable_data<T>(ctx.GetPlace());
for (int b = 0; b < batchsize; b++) {
auto out_data_item = &out_data[b * m * n];
int *info_data_item = &info_data[b];
int *ipiv_data_item = &ipiv_data[b * std::min(m, n)];
phi::funcs::lapackLu<T>(
m, n, out_data_item, lda, ipiv_data_item, info_data_item);
}
*out = helper.Transpose(*out);
}
};
template <typename T>
class LUOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -184,23 +105,6 @@ class LUGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "lu");
OP_INOUT_CHECK(ctx->HasInput("Pivots"), "Input", "Pivots", "lu");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"lu");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -219,19 +123,21 @@ DECLARE_INPLACE_OP_INFERER(LUGradOpInplaceInferer,
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(lu,
LUInferMetaFunctor,
PD_INFER_META(phi::LUInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(lu_grad,
LUGradInferMetaFunctor,
PD_INFER_META(phi::LUGradInferMeta));
REGISTER_OPERATOR(lu,
ops::LUOp,
ops::LUOpMaker,
ops::LUOpVarTypeInference,
ops::LUOpGradMaker<paddle::framework::OpDesc>,
ops::LUOpGradMaker<paddle::imperative::OpBase>,
ops::LUOpInplaceInferer);
LUInferMetaFunctor);
REGISTER_OPERATOR(lu_grad,
ops::LUGradOp,
ops::LUGradOpVarTypeInference,
ops::LUGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(lu, ops::LUKernel<float>, ops::LUKernel<double>);
REGISTER_OP_CPU_KERNEL(lu_grad,
ops::LUGradKernel<phi::CPUContext, float>,
ops::LUGradKernel<phi::CPUContext, double>);
LUGradInferMetaFunctor);
@@ -524,305 +524,5 @@ void Unpack_Pivot(const DeviceContext& dev_ctx,
}
}
template <typename DeviceContext, typename T>
class LUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto xin = ctx.Input<framework::Tensor>("X");
auto out = ctx.Input<framework::Tensor>("Out");
auto P = ctx.Input<framework::Tensor>("Pivots");
auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto xdims = xin->dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
framework::Tensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH,
grad_narrow;
LU_Unpack<DeviceContext, T>(dev_ctx, out, &L, &U);
Tensor_narrow<DeviceContext, T>(ctx, &L, &L_narrow, 0, k, 0, k);
Tensor_narrow<DeviceContext, T>(ctx, &U, &U_narrow, 0, k, 0, k);
Tensor_narrow<DeviceContext, T>(ctx, dout, &grad_narrow, 0, k, 0, k);
auto graddims = grad_narrow.dims();
Tensor_Conj<DeviceContext, T>(dev_ctx, L_narrow, &L_narrow_mH);
Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
L_narrow_mH = helper.Transpose(L_narrow_mH);
U_narrow_mH = helper.Transpose(U_narrow_mH);
auto LmHdims = L_narrow_mH.dims();
auto UmHdims = U_narrow_mH.dims();
framework::Tensor phi_L, phi_U, phi, psi;
phi_L.Resize(LmHdims);
phi_L.mutable_data<T>(ctx.GetPlace());
phi_U.Resize(UmHdims);
phi_U.mutable_data<T>(ctx.GetPlace());
auto mat_dim_l = phi::funcs::CreateMatrixDescriptor(LmHdims, 0, false);
auto mat_dim_u = phi::funcs::CreateMatrixDescriptor(UmHdims, 0, false);
auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(graddims, 0, false);
blas.MatMul(L_narrow_mH,
mat_dim_l,
grad_narrow,
mat_dim_g,
static_cast<T>(1),
&phi_L,
static_cast<T>(0));
blas.MatMul(grad_narrow,
mat_dim_g,
U_narrow_mH,
mat_dim_u,
static_cast<T>(1),
&phi_U,
static_cast<T>(0));
auto phil_rank = LmHdims.size();
auto phiu_rank = UmHdims.size();
platform::ForRange<DeviceContext> l_for_range(dev_ctx, phi_L.numel());
phi::funcs::TrilTriuCompute<T> tril_computer(phi_L.data<T>(),
-1,
true,
LmHdims[phil_rank - 2],
LmHdims[phil_rank - 1],
phi_L.data<T>());
l_for_range(tril_computer);
platform::ForRange<DeviceContext> u_for_range(dev_ctx, phi_U.numel());
phi::funcs::TrilTriuCompute<T> triu_computer(phi_U.data<T>(),
0,
false,
UmHdims[phiu_rank - 2],
UmHdims[phiu_rank - 1],
phi_U.data<T>());
u_for_range(triu_computer);
Tensor_Add<DeviceContext, T>(dev_ctx, phi_L, phi_U, &phi);
psi.Resize(xdims);
psi.mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> setter;
setter(dev_ctx, &psi, static_cast<T>(0));
std::vector<int64_t> axes = {xrank - 2, xrank - 1};
std::vector<int64_t> slice_starts(2, 0);
std::vector<int64_t> slice_ends(2, 0);
auto valuedims = vectorize(xdims);
framework::Tensor Pmat;
Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, &Pmat, m, k);
using Context =
typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
auto& phi_dev_ctx = static_cast<const Context&>(dev_ctx);
if (m <= n) {
if (k < n) {
framework::Tensor U_complement, U_grad_complement, phi_complement,
phi_complement_l;
Tensor_narrow<DeviceContext, T>(ctx, &U, &U_complement, 0, k, k, n);
Tensor_narrow<DeviceContext, T>(
ctx, dout, &U_grad_complement, 0, k, k, n);
framework::Tensor U_complement_mH = helper.Transpose(U_complement);
Tensor_Conj<DeviceContext, T>(
dev_ctx, U_complement_mH, &U_complement_mH);
auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(
U_grad_complement.dims(), 0, false);
auto mat_dim_u = phi::funcs::CreateMatrixDescriptor(
U_complement_mH.dims(), 0, false);
auto phidims = UmHdims;
phidims[UmHdims.size() - 2] = k;
phidims[UmHdims.size() - 1] = k;
phi_complement.Resize(phidims);
phi_complement.mutable_data<T>(ctx.GetPlace());
blas.MatMul(U_grad_complement,
mat_dim_g,
U_complement_mH,
mat_dim_u,
static_cast<T>(1),
&phi_complement,
static_cast<T>(0));
phi_complement_l.Resize(phidims);
phi_complement_l.mutable_data<T>(ctx.GetPlace());
const auto H = phidims[phidims.size() - 2];
const auto W = phidims[phidims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx,
phi_complement.numel());
phi::funcs::TrilTriuCompute<T> tril_computer(
phi_complement.data<T>(),
-1,
true,
H,
W,
phi_complement_l.data<T>());
x_for_range(tril_computer);
Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_l, &phi);
slice_starts[0] = 0;
slice_starts[1] = k;
slice_ends[0] = k;
slice_ends[1] = n;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = n - k;
SetValueCompute_dispatch<DeviceContext, T>(ctx,
&psi,
&U_grad_complement,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
}
framework::Tensor psi_principal, phi_mH, psi_tmp;
Tensor_Conj<DeviceContext, T>(dev_ctx, phi, &phi_mH);
phi_mH = helper.Transpose(phi_mH);
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, U_narrow, phi_mH, true, false, false, &psi_principal);
Tensor_Conj<DeviceContext, T>(dev_ctx, psi_principal, &psi_principal);
psi_principal = helper.Transpose(psi_principal);
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx,
&psi,
&psi_principal,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, L_narrow_mH, psi, true, false, true, &psi_tmp);
auto mat_dim_p =
phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(psi_tmp.dims(), 0, false);
blas.MatMul(Pmat,
mat_dim_p,
psi_tmp,
mat_dim_b,
static_cast<T>(1),
dx,
static_cast<T>(0));
} else {
framework::Tensor L_complement, L_grad_complement, phi_complement,
phi_complement_u;
Tensor_narrow<DeviceContext, T>(ctx, &L, &L_complement, k, m, 0, k);
Tensor_narrow<DeviceContext, T>(
ctx, dout, &L_grad_complement, k, m, 0, k);
framework::Tensor L_complement_mH = helper.Transpose(L_complement);
Tensor_Conj<DeviceContext, T>(dev_ctx, L_complement_mH, &L_complement_mH);
auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(
L_grad_complement.dims(), 0, false);
auto mat_dim_u =
phi::funcs::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false);
auto phidims = LmHdims;
phidims[LmHdims.size() - 2] = k;
phidims[LmHdims.size() - 1] = k;
phi_complement.Resize(phidims);
phi_complement.mutable_data<T>(ctx.GetPlace());
blas.MatMul(L_complement_mH,
mat_dim_u,
L_grad_complement,
mat_dim_g,
static_cast<T>(1),
&phi_complement,
static_cast<T>(0));
phi_complement_u.Resize(phidims);
phi_complement_u.mutable_data<T>(ctx.GetPlace());
const auto H = phidims[phidims.size() - 2];
const auto W = phidims[phidims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx,
phi_complement.numel());
phi::funcs::TrilTriuCompute<T> triu_computer(
phi_complement.data<T>(), 0, false, H, W, phi_complement_u.data<T>());
x_for_range(triu_computer);
Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_u, &phi);
slice_starts[0] = k;
slice_starts[1] = 0;
slice_ends[0] = m;
slice_ends[1] = k;
valuedims[xrank - 2] = m - k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx,
&psi,
&L_grad_complement,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH;
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, L_narrow_mH, phi, true, false, true, &psi_principal);
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx,
&psi,
&psi_principal,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
psi_tmp.Resize(psi.dims());
psi_tmp.mutable_data<T>(ctx.GetPlace());
auto mat_dim_p =
phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(psi.dims(), 0, false);
blas.MatMul(Pmat,
mat_dim_p,
psi,
mat_dim_b,
static_cast<T>(1),
&psi_tmp,
static_cast<T>(0));
psi_tmp = helper.Transpose(psi_tmp);
Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, U_narrow_mH, psi_tmp, true, false, false, &psi);
*dx = helper.Transpose(psi);
}
}
};
} // namespace operators
} // namespace paddle
@@ -1425,6 +1425,15 @@
func : logsumexp
backward : logsumexp_grad
- api : lu
args : (Tensor x, bool pivot)
output : Tensor(out), Tensor(pivots), Tensor(infos)
infer_meta :
func : LUInferMeta
kernel :
func : lu
backward : lu_grad
# masked_select
- api : masked_select
args : (Tensor x, Tensor mask)
......
@@ -1245,6 +1245,15 @@
kernel :
func : logsumexp_grad
- backward_api : lu_grad
forward : lu (Tensor x, bool pivot) -> Tensor(out), Tensor(pivots), Tensor(infos)
args : (Tensor x, Tensor out, Tensor pivots, Tensor out_grad, bool pivot)
output : Tensor(x_grad)
infer_meta :
func : LUGradInferMeta
kernel :
func : lu_grad
- backward_api : masked_select_grad
forward : masked_select (Tensor x, Tensor mask) -> Tensor(out)
args : (Tensor x, Tensor mask, Tensor out_grad)
......
@@ -442,6 +442,20 @@ void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
dx->share_lod(xshape);
}
void LUGradInferMeta(const MetaTensor& x,
const MetaTensor& out,
const MetaTensor& pivots,
const MetaTensor& out_grad,
bool pivot,
MetaTensor* x_grad) {
auto x_dims = x.dims();
if (x_grad) {
x_grad->set_dims(x_dims);
x_grad->set_dtype(x.dtype());
}
}
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
@@ -200,6 +200,13 @@ void InverseGradInferMeta(const MetaTensor& out,
void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx);
void LUGradInferMeta(const MetaTensor& x,
const MetaTensor& out,
const MetaTensor& pivots,
const MetaTensor& out_grad,
bool pivot,
MetaTensor* x_grad);
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
@@ -1379,15 +1379,59 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
out->set_dtype(x.dtype());
}
void LUInferMeta(const MetaTensor& x,
bool pivot,
MetaTensor* out,
MetaTensor* pivots,
MetaTensor* infos) {
auto x_dims = x.dims();
int x_rank = x_dims.size();
PADDLE_ENFORCE_NOT_NULL(
out, phi::errors::InvalidArgument("Output(Out) should not be nullptr."));
PADDLE_ENFORCE_GE(
x_rank,
2,
phi::errors::InvalidArgument("The rank of input must greater than 2."));
out->set_dims(x_dims);
out->set_dtype(x.dtype());
int m = x_dims[x_rank - 1];
int n = x_dims[x_rank - 2];
int min_mn = std::min(m, n);
auto dims_vec = phi::vectorize(x_dims);
PADDLE_ENFORCE_NOT_NULL(
infos,
phi::errors::InvalidArgument("Output(Infos) should not be nullptr."));
if (x_rank == 2) {
auto Infos_dim = std::vector<int>(1);
infos->set_dims(phi::make_ddim(Infos_dim));
} else {
auto Infos_dim =
std::vector<int>(dims_vec.begin(), dims_vec.begin() + x_rank - 2);
infos->set_dims(phi::make_ddim(Infos_dim));
}
infos->set_dtype(DataType::INT32);
if (pivot) {
PADDLE_ENFORCE_NOT_NULL(
pivots,
phi::errors::InvalidArgument("Output(Pivots) should not be nullptr."));
auto Pivots_dim =
std::vector<int>(dims_vec.begin(), dims_vec.begin() + x_rank - 1);
Pivots_dim[x_rank - 2] = min_mn;
pivots->set_dims(phi::make_ddim(Pivots_dim));
pivots->set_dtype(DataType::INT32);
}
}
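A quick sketch of the shapes LUInferMeta above produces, assuming the updated paddle.linalg.lu API shown later in this commit: for an input of shape [2, 3, 4], Out keeps the input shape, Pivots gets the batch dims plus a trailing min(m, n) = 3, and Infos keeps only the batch dims.
import paddle
x = paddle.randn([2, 3, 4])
lu, piv, info = paddle.linalg.lu(x, get_infos=True)
print(lu.shape, piv.shape, info.shape)  # expected: [2, 3, 4] [2, 3] [2]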
void MatrixRankInferMeta(const MetaTensor& x,
bool use_default_tol,
bool hermitian,
MetaTensor* out) {
auto dim_x = x.dims();
PADDLE_ENFORCE_GE(
dim_x.size(),
2,
phi::errors::InvalidArgument("The dims of input must be greater than 2"));
PADDLE_ENFORCE_GE(dim_x.size(),
2,
phi::errors::InvalidArgument(
"The dims of input must be greater than 2."));
if (hermitian) {
int rows = dim_x[dim_x.size() - 2];
@@ -1418,11 +1462,11 @@ void MaxOutInferMeta(const MetaTensor& x,
axis == 1 || axis == -1 || axis == 3,
true,
phi::errors::InvalidArgument(
"axis only supported 1, -1 or 3, but recevied axis is: %d", axis));
"axis only supported 1, -1 or 3, but recevied axis is: %d.", axis));
PADDLE_ENFORCE_EQ(in_x_dims.size(),
4,
phi::errors::InvalidArgument(
"x's dims should be 4, but received x's dims is: %d",
"x's dims should be 4, but received x's dims is: %d.",
in_x_dims.size()));
if (axis < 0) {
......
@@ -181,6 +181,12 @@ void LogsumexpInferMeta(const MetaTensor& input,
bool reduce_all,
MetaTensor* out);
void LUInferMeta(const MetaTensor& x,
bool pivot,
MetaTensor* out,
MetaTensor* pivots,
MetaTensor* infos);
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
void MatrixRankInferMeta(const MetaTensor& x,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lu_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_grad_kernel.h"
PD_REGISTER_KERNEL(lu_grad, CPU, ALL_LAYOUT, phi::LUGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
#include "paddle/phi/kernels/lu_kernel.h"
namespace phi {
template <typename T, typename Context>
void LUKernel(const Context& dev_ctx,
const DenseTensor& x,
bool pivot,
DenseTensor* out,
DenseTensor* pivots,
DenseTensor* infos) {
PADDLE_ENFORCE_EQ(pivot,
true,
errors::InvalidArgument(
"lu without pivoting is not implemented on the CPU, "
"but got pivots=False"));
*out = Transpose2DTo6D<Context, T>(dev_ctx, x);
auto outdims = out->dims();
auto outrank = outdims.size();
int m = static_cast<int>(outdims[outrank - 1]);
int n = static_cast<int>(outdims[outrank - 2]);
int lda = std::max(1, m);
auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1);
ipiv_dims[outrank - 2] = std::min(m, n);
pivots->Resize(ipiv_dims);
dev_ctx.template Alloc<int>(pivots);
auto ipiv_data = pivots->data<int>();
auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2);
if (info_dims.size() == 0) {
info_dims = phi::make_ddim({1});
}
infos->Resize(info_dims);
dev_ctx.template Alloc<int>(infos);
auto info_data = infos->data<int>();
auto batchsize = product(info_dims);
batchsize = std::max(static_cast<int>(batchsize), 1);
dev_ctx.template Alloc<T>(out);
auto out_data = out->data<T>();
for (int b = 0; b < batchsize; b++) {
auto out_data_item = &out_data[b * m * n];
int* info_data_item = &info_data[b];
int* ipiv_data_item = &ipiv_data[b * std::min(m, n)];
phi::funcs::lapackLu<T>(
m, n, out_data_item, lda, ipiv_data_item, info_data_item);
}
*out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
}
} // namespace phi
PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {}
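A minimal end-to-end sanity sketch for this CPU kernel, assuming the companion paddle.linalg.lu_unpack API (not part of this diff) is available to rebuild P, L and U from the packed getrf output:
import paddle
x = paddle.randn([5, 3], dtype='float64')
lu, piv = paddle.linalg.lu(x)               # packed LU plus 1-based pivots from lapackLu
P, L, U = paddle.linalg.lu_unpack(lu, piv)  # companion API, assumed available
print(paddle.allclose(P @ L @ U, x))        # expected: True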
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lu_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_grad_kernel.h"
PD_REGISTER_KERNEL(lu_grad, GPU, ALL_LAYOUT, phi::LUGradKernel, float, double) {
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP does not support cusolver
#include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace operators {
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
#include "paddle/phi/kernels/lu_kernel.h"
using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
namespace phi {
template <typename T>
void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH,
@@ -49,8 +49,8 @@ void cusolver_bufferSize<float>(const cusolverDnHandle_t& cusolverH,
float* d_A,
int lda,
int* lwork) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgetrf_bufferSize(
cusolverH, m, n, d_A, lda, lwork));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnSgetrf_bufferSize(cusolverH, m, n, d_A, lda, lwork));
}
template <>
@@ -60,8 +60,8 @@ void cusolver_bufferSize<double>(const cusolverDnHandle_t& cusolverH,
double* d_A,
int lda,
int* lwork) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgetrf_bufferSize(
cusolverH, m, n, d_A, lda, lwork));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnDgetrf_bufferSize(cusolverH, m, n, d_A, lda, lwork));
}
template <>
@@ -73,7 +73,7 @@ void cusolver_getrf<float>(const cusolverDnHandle_t& cusolverH,
float* d_work,
int* d_Ipiv,
int* d_info) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgetrf(
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgetrf(
cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info));
}
@@ -86,27 +86,26 @@ void cusolver_getrf<double>(const cusolverDnHandle_t& cusolverH,
double* d_work,
int* d_Ipiv,
int* d_info) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgetrf(
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrf(
cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info));
}
template <typename T>
void lu_decomposed_kernel(int m,
template <typename T, typename Context>
void lu_decomposed_kernel(const Context& dev_ctx,
int m,
int n,
T* d_A,
int lda,
int* d_Ipiv,
int* d_info,
const framework::ExecutionContext& ctx) {
int* d_info) {
/* step 1: get cusolver handle*/
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto cusolverH = dev_ctx.cusolver_dn_handle();
/* step 2: query working space of getrf */
int lwork;
cusolver_bufferSize(cusolverH, m, n, d_A, lda, &lwork);
auto work_buff = memory::Alloc(dev_ctx, lwork * sizeof(T));
auto work_buff = paddle::memory::Alloc(dev_ctx, lwork * sizeof(T));
T* d_work = reinterpret_cast<T*>(work_buff->ptr());
/* step 3: LU factorization */
@@ -118,77 +117,69 @@ void lu_decomposed_kernel(int m,
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}
template <typename T>
class LUCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
template <typename T, typename Context>
void LUKernel(const Context& dev_ctx,
const DenseTensor& x,
bool pivot,
DenseTensor* out,
DenseTensor* pivots,
DenseTensor* infos) {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
const int64_t kMaxBlockDim = 512;
#endif
auto* xin = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* IpivT = ctx.Output<framework::Tensor>("Pivots");
auto* InfoT = ctx.Output<framework::Tensor>("Infos");
auto pivots = ctx.Attr<bool>("pivots");
math::DeviceIndependenceTensorOperations<
paddle::platform::CUDADeviceContext,
T>
helper(ctx);
*out = helper.Transpose(*xin);
auto outdims = out->dims();
auto outrank = outdims.size();
int m = static_cast<int>(outdims[outrank - 1]);
int n = static_cast<int>(outdims[outrank - 2]);
int lda = std::max(1, m);
if (pivots) {
auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1);
ipiv_dims[outrank - 2] = std::min(m, n);
IpivT->Resize(ipiv_dims);
}
auto ipiv_data = IpivT->mutable_data<int>(ctx.GetPlace());
auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2);
if (info_dims.size() == 0) {
info_dims = phi::make_ddim({1});
}
InfoT->Resize(info_dims);
auto info_data = InfoT->mutable_data<int>(ctx.GetPlace());
auto batchsize = product(info_dims);
batchsize = std::max(static_cast<int>(batchsize), 1);
auto out_data = out->mutable_data<T>(ctx.GetPlace());
for (int b = 0; b < batchsize; b++) {
auto out_data_item = &out_data[b * m * n];
int* info_data_item = &info_data[b];
if (pivots) {
auto ipiv_data_item = &ipiv_data[b * std::min(m, n)];
lu_decomposed_kernel(
m, n, out_data_item, lda, ipiv_data_item, info_data_item, ctx);
} else {
lu_decomposed_kernel(
m, n, out_data_item, lda, NULL, info_data_item, ctx);
}
}
*out = helper.Transpose(*out);
*out = Transpose2DTo6D<Context, T>(dev_ctx, x);
auto outdims = out->dims();
auto outrank = outdims.size();
int m = static_cast<int>(outdims[outrank - 1]);
int n = static_cast<int>(outdims[outrank - 2]);
int lda = std::max(1, m);
if (pivot) {
auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1);
ipiv_dims[outrank - 2] = std::min(m, n);
pivots->Resize(ipiv_dims);
}
};
dev_ctx.template Alloc<int>(pivots);
auto ipiv_data = pivots->data<int>();
} // namespace operators
} // namespace paddle
auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2);
if (info_dims.size() == 0) {
info_dims = phi::make_ddim({1});
}
infos->Resize(info_dims);
dev_ctx.template Alloc<int>(infos);
auto info_data = infos->data<int>();
auto batchsize = product(info_dims);
batchsize = std::max(static_cast<int>(batchsize), 1);
dev_ctx.template Alloc<T>(out);
auto out_data = out->data<T>();
for (int b = 0; b < batchsize; b++) {
auto out_data_item = &out_data[b * m * n];
int* info_data_item = &info_data[b];
if (pivot) {
auto ipiv_data_item = &ipiv_data[b * std::min(m, n)];
lu_decomposed_kernel(
dev_ctx, m, n, out_data_item, lda, ipiv_data_item, info_data_item);
} else {
lu_decomposed_kernel(
dev_ctx, m, n, out_data_item, lda, NULL, info_data_item);
}
}
*out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
}
namespace ops = paddle::operators;
namespace plat = paddle::platform;
} // namespace phi
REGISTER_OP_CUDA_KERNEL(lu,
ops::LUCUDAKernel<float>,
ops::LUCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(lu_grad,
ops::LUGradKernel<plat::CUDADeviceContext, float>,
ops::LUGradKernel<plat::CUDADeviceContext, double>);
PD_REGISTER_KERNEL(lu, // cuda_only
GPU,
ALL_LAYOUT,
phi::LUKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void LUGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& pivots,
const DenseTensor& out_grad,
bool pivot,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
auto xdims = x.dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
DenseTensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH, grad_narrow;
LU_Unpack<Context, T>(dev_ctx, &out, &L, &U);
Tensor_narrow<Context, T>(dev_ctx, &L, &L_narrow, 0, k, 0, k);
Tensor_narrow<Context, T>(dev_ctx, &U, &U_narrow, 0, k, 0, k);
Tensor_narrow<Context, T>(dev_ctx, &out_grad, &grad_narrow, 0, k, 0, k);
auto graddims = grad_narrow.dims();
Tensor_Conj<Context, T>(dev_ctx, L_narrow, &L_narrow_mH);
Tensor_Conj<Context, T>(dev_ctx, U_narrow, &U_narrow_mH);
L_narrow_mH = Transpose2DTo6D<Context, T>(dev_ctx, L_narrow_mH);
U_narrow_mH = Transpose2DTo6D<Context, T>(dev_ctx, U_narrow_mH);
auto LmHdims = L_narrow_mH.dims();
auto UmHdims = U_narrow_mH.dims();
DenseTensor phi_L, phi_U, phi, psi;
phi_L.Resize(LmHdims);
dev_ctx.template Alloc<T>(&phi_L);
phi_U.Resize(UmHdims);
dev_ctx.template Alloc<T>(&phi_U);
auto mat_dim_l = phi::funcs::CreateMatrixDescriptor(LmHdims, 0, false);
auto mat_dim_u = phi::funcs::CreateMatrixDescriptor(UmHdims, 0, false);
auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(graddims, 0, false);
blas.MatMul(L_narrow_mH,
mat_dim_l,
grad_narrow,
mat_dim_g,
static_cast<T>(1),
&phi_L,
static_cast<T>(0));
blas.MatMul(grad_narrow,
mat_dim_g,
U_narrow_mH,
mat_dim_u,
static_cast<T>(1),
&phi_U,
static_cast<T>(0));
auto phil_rank = LmHdims.size();
auto phiu_rank = UmHdims.size();
phi::funcs::ForRange<Context> l_for_range(dev_ctx, phi_L.numel());
phi::funcs::TrilTriuCompute<T> tril_computer(phi_L.data<T>(),
-1,
true,
LmHdims[phil_rank - 2],
LmHdims[phil_rank - 1],
phi_L.data<T>());
l_for_range(tril_computer);
phi::funcs::ForRange<Context> u_for_range(dev_ctx, phi_U.numel());
phi::funcs::TrilTriuCompute<T> triu_computer(phi_U.data<T>(),
0,
false,
UmHdims[phiu_rank - 2],
UmHdims[phiu_rank - 1],
phi_U.data<T>());
u_for_range(triu_computer);
Tensor_Add<Context, T>(dev_ctx, phi_L, phi_U, &phi);
psi.Resize(xdims);
dev_ctx.template Alloc<T>(&psi);
phi::funcs::SetConstant<Context, T> setter;
setter(dev_ctx, &psi, static_cast<T>(0));
std::vector<int64_t> axes = {xrank - 2, xrank - 1};
std::vector<int64_t> slice_starts(2, 0);
std::vector<int64_t> slice_ends(2, 0);
auto valuedims = vectorize(xdims);
DenseTensor Pmat;
Unpack_Pivot<Context, T>(dev_ctx, pivots, &Pmat, m, k);
if (m <= n) {
if (k < n) {
DenseTensor U_complement, U_grad_complement, phi_complement,
phi_complement_l;
Tensor_narrow<Context, T>(dev_ctx, &U, &U_complement, 0, k, k, n);
Tensor_narrow<Context, T>(
dev_ctx, &out_grad, &U_grad_complement, 0, k, k, n);
DenseTensor U_complement_mH =
Transpose2DTo6D<Context, T>(dev_ctx, U_complement);
Tensor_Conj<Context, T>(dev_ctx, U_complement_mH, &U_complement_mH);
auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(
U_grad_complement.dims(), 0, false);
auto mat_dim_u =
phi::funcs::CreateMatrixDescriptor(U_complement_mH.dims(), 0, false);
auto phidims = UmHdims;
phidims[UmHdims.size() - 2] = k;
phidims[UmHdims.size() - 1] = k;
phi_complement.Resize(phidims);
dev_ctx.template Alloc<T>(&phi_complement);
blas.MatMul(U_grad_complement,
mat_dim_g,
U_complement_mH,
mat_dim_u,
static_cast<T>(1),
&phi_complement,
static_cast<T>(0));
phi_complement_l.Resize(phidims);
dev_ctx.template Alloc<T>(&phi_complement_l);
const auto H = phidims[phidims.size() - 2];
const auto W = phidims[phidims.size() - 1];
phi::funcs::ForRange<Context> x_for_range(dev_ctx,
phi_complement.numel());
phi::funcs::TrilTriuCompute<T> tril_computer(
phi_complement.data<T>(), -1, true, H, W, phi_complement_l.data<T>());
x_for_range(tril_computer);
Tensor_Sub<Context, T>(dev_ctx, phi, phi_complement_l, &phi);
slice_starts[0] = 0;
slice_starts[1] = k;
slice_ends[0] = k;
slice_ends[1] = n;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = n - k;
SetValueCompute_dispatch<Context, T>(dev_ctx,
&psi,
&U_grad_complement,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
}
DenseTensor psi_principal, phi_mH, psi_tmp;
Tensor_Conj<Context, T>(dev_ctx, phi, &phi_mH);
phi_mH = Transpose2DTo6D<Context, T>(dev_ctx, phi_mH);
phi::TriangularSolveKernel<T, Context>(
dev_ctx, U_narrow, phi_mH, true, false, false, &psi_principal);
Tensor_Conj<Context, T>(dev_ctx, psi_principal, &psi_principal);
psi_principal = Transpose2DTo6D<Context, T>(dev_ctx, psi_principal);
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<Context, T>(dev_ctx,
&psi,
&psi_principal,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
phi::TriangularSolveKernel<T, Context>(
dev_ctx, L_narrow_mH, psi, true, false, true, &psi_tmp);
auto mat_dim_p = phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(psi_tmp.dims(), 0, false);
blas.MatMul(Pmat,
mat_dim_p,
psi_tmp,
mat_dim_b,
static_cast<T>(1),
x_grad,
static_cast<T>(0));
} else {
DenseTensor L_complement, L_grad_complement, phi_complement,
phi_complement_u;
Tensor_narrow<Context, T>(dev_ctx, &L, &L_complement, k, m, 0, k);
Tensor_narrow<Context, T>(
dev_ctx, &out_grad, &L_grad_complement, k, m, 0, k);
DenseTensor L_complement_mH =
Transpose2DTo6D<Context, T>(dev_ctx, L_complement);
Tensor_Conj<Context, T>(dev_ctx, L_complement_mH, &L_complement_mH);
auto mat_dim_g =
phi::funcs::CreateMatrixDescriptor(L_grad_complement.dims(), 0, false);
auto mat_dim_u =
phi::funcs::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false);
auto phidims = LmHdims;
phidims[LmHdims.size() - 2] = k;
phidims[LmHdims.size() - 1] = k;
phi_complement.Resize(phidims);
dev_ctx.template Alloc<T>(&phi_complement);
blas.MatMul(L_complement_mH,
mat_dim_u,
L_grad_complement,
mat_dim_g,
static_cast<T>(1),
&phi_complement,
static_cast<T>(0));
phi_complement_u.Resize(phidims);
dev_ctx.template Alloc<T>(&phi_complement_u);
const auto H = phidims[phidims.size() - 2];
const auto W = phidims[phidims.size() - 1];
phi::funcs::ForRange<Context> x_for_range(dev_ctx, phi_complement.numel());
phi::funcs::TrilTriuCompute<T> triu_computer(
phi_complement.data<T>(), 0, false, H, W, phi_complement_u.data<T>());
x_for_range(triu_computer);
Tensor_Sub<Context, T>(dev_ctx, phi, phi_complement_u, &phi);
slice_starts[0] = k;
slice_starts[1] = 0;
slice_ends[0] = m;
slice_ends[1] = k;
valuedims[xrank - 2] = m - k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<Context, T>(dev_ctx,
&psi,
&L_grad_complement,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
DenseTensor psi_principal, phi_mH, psi_tmp, U_narrow_mH;
phi::TriangularSolveKernel<T, Context>(
dev_ctx, L_narrow_mH, phi, true, false, true, &psi_principal);
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<Context, T>(dev_ctx,
&psi,
&psi_principal,
&psi,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
psi_tmp.Resize(psi.dims());
dev_ctx.template Alloc<T>(&psi_tmp);
auto mat_dim_p = phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(psi.dims(), 0, false);
blas.MatMul(Pmat,
mat_dim_p,
psi,
mat_dim_b,
static_cast<T>(1),
&psi_tmp,
static_cast<T>(0));
psi_tmp = Transpose2DTo6D<Context, T>(dev_ctx, psi_tmp);
Tensor_Conj<Context, T>(dev_ctx, U_narrow, &U_narrow_mH);
phi::TriangularSolveKernel<T, Context>(
dev_ctx, U_narrow_mH, psi_tmp, true, false, false, &psi);
*x_grad = Transpose2DTo6D<Context, T>(dev_ctx, psi);
}
}
} // namespace phi
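A hedged numerical check of the backward path above, assuming dygraph mode and double precision; it compares paddle.grad against a central finite-difference estimate of a scalar loss (valid as long as the tiny perturbation does not change the pivot pattern):
import numpy as np
import paddle
paddle.seed(2022)
x = paddle.randn([3, 3], dtype='float64')
x.stop_gradient = False
lu, _ = paddle.linalg.lu(x)
loss = (lu * lu).sum()
g_analytic, = paddle.grad(loss, [x])
def loss_at(a):  # the same scalar loss evaluated on a plain numpy matrix
    lu_a, _ = paddle.linalg.lu(paddle.to_tensor(a))
    return float((lu_a * lu_a).sum())
eps, g_fd, base = 1e-6, np.zeros((3, 3)), x.numpy()
for i in range(3):
    for j in range(3):
        ap, am = base.copy(), base.copy()
        ap[i, j] += eps
        am[i, j] -= eps
        g_fd[i, j] = (loss_at(ap) - loss_at(am)) / (2 * eps)
print(np.allclose(g_analytic.numpy(), g_fd, atol=1e-5))  # expected: True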
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
namespace phi {
template <typename T>
using SubFunctor = phi::funcs::SubtractFunctor<T>;
template <typename Context, typename T, size_t D>
void SetValueCompute(const Context& dev_ctx,
DenseTensor* in,
DenseTensor* value_tensor,
DenseTensor* out,
const std::vector<int64_t>& axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
const std::vector<int64_t>& shape) {
std::vector<int64_t> steps = {1, 1};
std::vector<int64_t> decrease_axes = {};
std::vector<int64_t> none_axes = {};
auto dtype = in->dtype();
auto in_dims = in->dims();
phi::funcs::CheckAndUpdateSliceAttrs<int64_t>(
in_dims, axes, starts, ends, &steps);
auto slice_dims =
phi::funcs::GetSliceDims(in_dims, axes, *starts, *ends, &steps);
auto decrease_slice_dims =
phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
none_axes[none_axes_cur] <= i) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
if (decrease_axes_cur < decrease_axes.size() &&
decrease_axes[decrease_axes_cur] == i) {
decrease_axes_cur++;
} else {
slice_dims_with_none.push_back(slice_dims[i]);
}
}
while (none_axes_cur < none_axes.size()) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
}
auto place = dev_ctx.GetPlace();
auto& eigen_place = *dev_ctx.eigen_device();
// Here copy data from input to avoid data loss at PE and Graph level.
// TODO(liym27): Speed up in the future version.
// - Q: Why don't call ShareDataWith to speed up?
// - A: Because it's not supported to ShareDataWith on OP's input and output
// https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
// - Q: Why don't delete Input, after all, the input and output are the same
// Tensor at program level?
// - A: If deleting Input, the graph will be complex, such as there will
// be two ops points to the output in graph: op1 -> output <- set_value.
// In this case, we have to find a way to handle the running order of
// set_value is what we want.
phi::Copy(dev_ctx, *in, place, false, out);
DenseTensor slice_tensor(dtype), pad_tensor(dtype);
slice_tensor.Resize(slice_dims);
dev_ctx.template Alloc<T>(&slice_tensor);
pad_tensor.Resize(in_dims);
dev_ctx.template Alloc<T>(&pad_tensor);
auto pad_e = EigenTensor<T, D>::From(pad_tensor, in_dims);
auto out_e = EigenTensor<T, D>::From(*out);
auto slice_e = EigenTensor<T, D>::From(slice_tensor, slice_dims);
// Step 1: Set the value of out at `_index` to zero
slice_e.device(eigen_place) = slice_e.constant(T(0));
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
starts_indices[i] = 0;
ends_indices[i] = slice_dims[i];
strides_indices[i] = 1;
}
for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i];
starts_indices[axis_index] = (*starts)[i];
ends_indices[axis_index] = (*ends)[i];
strides_indices[axis_index] = steps[i];
if ((*starts)[i] ==
(*ends)[i]) { // slice is empty, data will not be changed
return;
}
}
out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set slice tensor with value
// NOTE(liym27): [ Why resize slice_tensor here? ]
// A: When do broadcasting on slice_tensor and value_tensor, the shape of
// slice_tensor should be decreased dims.
// e.g.
// x[:,0] = value_tensor
// x's shape = [3, 4], value_tensor's shape = [3]
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
// If do broadcasting on Tensor with shape [3, 1] and [3], the result's
// shape is [3, 3], which cross the border;
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right.
slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
phi::funcs::ElementwiseCompute<SubFunctor<T>, T, T>(dev_ctx,
slice_tensor,
*value_tensor,
-1,
SubFunctor<T>(),
&slice_tensor);
} else {
DenseTensor value_t(dtype);
auto value_dims = phi::make_ddim(shape);
CheckIsDimsMatch(slice_dims_for_assign, value_dims);
value_t.Resize(value_dims);
dev_ctx.template Alloc<T>(&value_t);
phi::funcs::ElementwiseCompute<SubFunctor<T>, T, T>(
dev_ctx, slice_tensor, value_t, -1, SubFunctor<T>(), &slice_tensor);
}
slice_tensor.Resize(slice_dims);
// - Step 2.2 Pad slice tensor with 0
pad_e.device(eigen_place) = pad_e.constant(T(0));
pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 3: Set out tensor with value_tensor
out_e.device(eigen_place) = out_e - pad_e;
}
template <typename Context, typename T>
void SetValueCompute_dispatch(const Context& dev_ctx,
DenseTensor* in,
DenseTensor* value_tensor,
DenseTensor* out,
const std::vector<int64_t>& axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
const std::vector<int64_t>& shape,
int rank) {
switch (rank) {
case 1:
SetValueCompute<Context, T, 1>(
dev_ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 2:
SetValueCompute<Context, T, 2>(
dev_ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 3:
SetValueCompute<Context, T, 3>(
dev_ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 4:
SetValueCompute<Context, T, 4>(
dev_ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 5:
SetValueCompute<Context, T, 5>(
dev_ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 6:
SetValueCompute<Context, T, 6>(
dev_ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
template <typename Context, typename T>
void Tensor_Conj(const Context& dev_ctx,
const DenseTensor& tensor,
DenseTensor* out) {
out->Resize(tensor.dims());
phi::funcs::ForRange<Context> out_for_range(dev_ctx, tensor.numel());
dev_ctx.template Alloc<T>(out);
phi::funcs::ConjFunctor<T> out_functor(
tensor.data<T>(), tensor.numel(), out->data<T>());
out_for_range(out_functor);
}
template <typename Context, typename T>
void Tensor_Add(const Context& dev_ctx,
const DenseTensor& src1,
const DenseTensor& src2,
DenseTensor* out) {
out->Resize(src1.dims());
dev_ctx.template Alloc<T>(out);
phi::AddRawKernel<T, Context>(dev_ctx, src1, src2, -1, out);
}
template <typename Context, typename T>
void Tensor_Sub(const Context& dev_ctx,
const DenseTensor& src1,
const DenseTensor& src2,
DenseTensor* out) {
out->Resize(src1.dims());
dev_ctx.template Alloc<T>(out);
phi::SubtractRawKernel<T, Context>(dev_ctx, src1, src2, -1, out);
}
template <typename Context, typename T, size_t D>
void SliceCompute(const Context& dev_ctx,
const DenseTensor* in,
DenseTensor* out,
const std::vector<int>& axes_int,
const std::vector<int>& starts_int,
const std::vector<int>& ends_int) {
std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int> decrease_axis = {};
std::vector<int> infer_flags = {};
PADDLE_ENFORCE_EQ(
starts.size(),
axes.size(),
phi::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(ends.size(),
axes.size(),
phi::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
// Step 2: Compute output
auto in_dims = in->dims();
auto out_dims = out->dims();
auto slice_dims = out_dims;
// 2.1 Infer output dims
for (size_t i = 0; i < axes.size(); ++i) {
// when start == -1 && end == start+1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = in_dims[axes[i]];
}
}
}
phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims = phi::funcs::GetSliceDims<int64_t>(
in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = slice_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
offsets[axes[i]] = starts[i];
}
out->Resize(slice_dims);
dev_ctx.template Alloc<T>(out);
auto in_t = EigenTensor<T, D>::From(*in, in_dims);
auto out_t = EigenTensor<T, D>::From(*out, slice_dims);
auto& eigen_place = *dev_ctx.eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
// if element number less than INT_MAX, change the type of index to int
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place,
To32BitIndex(out_t),
To32BitIndex(in_t),
offsets_32bit,
extents_32bit);
} else {
funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
}
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
}
template <typename Context, typename T>
void Tensor_narrow(const Context& dev_ctx,
const DenseTensor* src,
DenseTensor* out,
int row_s,
int row_e,
int col_s,
int col_e) {
auto rank = src->dims().size();
std::vector<int> axes_int = {rank - 2, rank - 1};
std::vector<int> starts_int = {row_s, col_s};
std::vector<int> ends_int = {row_e, col_e};
switch (rank) {
case 1:
SliceCompute<Context, T, 1>(
dev_ctx, src, out, axes_int, starts_int, ends_int);
break;
case 2:
SliceCompute<Context, T, 2>(
dev_ctx, src, out, axes_int, starts_int, ends_int);
break;
case 3:
SliceCompute<Context, T, 3>(
dev_ctx, src, out, axes_int, starts_int, ends_int);
break;
case 4:
SliceCompute<Context, T, 4>(
dev_ctx, src, out, axes_int, starts_int, ends_int);
break;
case 5:
SliceCompute<Context, T, 5>(
dev_ctx, src, out, axes_int, starts_int, ends_int);
break;
case 6:
SliceCompute<Context, T, 6>(
dev_ctx, src, out, axes_int, starts_int, ends_int);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
template <typename Context>
void arange(const Context& dev_ctx,
DenseTensor* tmp,
int w,
int batchsize = 1,
int h = 1) {
tmp->Resize(phi::make_ddim({batchsize * w}));
dev_ctx.template HostAlloc<int32_t>(tmp);
auto tmpdata = tmp->data<int32_t>();
for (int b = 0; b < batchsize; b++) {
for (int i = 0; i < w; i++) {
tmpdata[b * w + i] = static_cast<int32_t>(b * h + i);
}
}
}
template <typename T>
struct OneFunctor {
OneFunctor(T* output, int* idtptr, int w, int dim)
: output_(output), idtptr_(idtptr), w_(w), dim_(dim) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[w_ * idtptr_[idx] + idx % dim_] = static_cast<T>(1);
}
T* output_;
int* idtptr_;
int w_;
int dim_;
};
template <typename Context, typename T>
void LU_Unpack(const Context& dev_ctx,
const DenseTensor* LU,
DenseTensor* L,
DenseTensor* U) {
const auto udims = LU->dims();
L->Resize(udims);
U->Resize(udims);
const auto H = udims[udims.size() - 2];
const auto W = udims[udims.size() - 1];
dev_ctx.template Alloc<T>(L);
auto L_dataptr = L->data<T>();
phi::funcs::ForRange<Context> x_for_range(dev_ctx, LU->numel());
phi::funcs::TrilTriuCompute<T> tril_computer(
LU->data<T>(), -1, true, H, W, L_dataptr);
x_for_range(tril_computer);
dev_ctx.template Alloc<T>(U);
phi::funcs::TrilTriuCompute<T> triu_computer(
LU->data<T>(), 0, false, H, W, U->data<T>());
x_for_range(triu_computer);
// set L's diagonal 1
auto dim = std::min(H, W);
DenseTensor rowtensor, rt_dev;
auto batchsize = product(phi::slice_ddim(udims, 0, udims.size() - 2));
batchsize = std::max(static_cast<int>(batchsize), 1);
arange<Context>(dev_ctx, &rowtensor, dim, batchsize, H);
auto idtptr = rowtensor.data<int32_t>();
if (phi::AllocationType::GPU == dev_ctx.GetPlace().GetType()) {
phi::Copy(dev_ctx, rowtensor, dev_ctx.GetPlace(), false, &rt_dev);
idtptr = rt_dev.data<int32_t>();
}
phi::funcs::ForRange<Context> for_range(dev_ctx, rowtensor.numel());
OneFunctor<T> functor(L_dataptr, idtptr, W, dim);
for_range(functor);
}
template <typename Context, typename T>
void scatterpivot(
const Context& dev_ctx, T* out_data, DenseTensor* idlst, int w, int dim) {
DenseTensor idlst_tmp;
idlst_tmp.Resize(idlst->dims());
dev_ctx.template Alloc<int32_t>(&idlst_tmp);
phi::Copy(dev_ctx, *idlst, dev_ctx.GetPlace(), false, &idlst_tmp);
auto idtptr = idlst_tmp.data<int32_t>();
phi::funcs::ForRange<Context> for_range(dev_ctx, idlst_tmp.numel());
OneFunctor<T> functor(out_data, idtptr, w, dim);
for_range(functor);
}
template <typename Context, typename T>
void Unpack_Pivot(const Context& dev_ctx,
const DenseTensor& Pivot,
DenseTensor* P,
int h,
int w) {
auto dims = Pivot.dims();
auto Pdimvec = vectorize(dims);
auto prank = Pdimvec.size();
auto Pnum = dims[prank - 1];
DenseTensor Pivot_cpu;
phi::CPUPlace cpu;
phi::Copy(dev_ctx, Pivot, cpu, false, &Pivot_cpu);
auto pdataptr = Pivot_cpu.data<int32_t>();
Pdimvec[prank - 1] = h;
Pdimvec.emplace_back(h);
auto Pdim = phi::make_ddim(Pdimvec);
P->Resize(Pdim);
dev_ctx.template Alloc<T>(P);
auto pdata = P->data<T>();
phi::funcs::SetConstant<Context, T> setter;
setter(dev_ctx, P, static_cast<T>(0));
auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1));
batchsize = std::max(static_cast<int>(batchsize), 1);
DenseTensor idt;
for (int i = 0; i < batchsize; i++) {
arange<Context>(dev_ctx, &idt, h);
auto idlst = idt.data<int32_t>();
for (int j = 0; j < Pnum; j++) {
if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue;
auto temp = idlst[j];
idlst[j] = idlst[pdataptr[i * Pnum + j] - 1];
idlst[pdataptr[i * Pnum + j] - 1] = temp;
}
scatterpivot(dev_ctx, &(pdata[i * h * h]), &idt, h, h);
}
}
template <typename Context, typename T>
DenseTensor Transpose2DTo6D(const Context& dev_ctx, const DenseTensor& x) {
// transpose the last two dimensions
DenseTensor ret;
auto x_dim = x.dims();
auto x_vec = phi::vectorize<int>(x_dim);
int rank = x_vec.size();
std::swap(x_vec[rank - 1], x_vec[rank - 2]);
std::vector<int> out_shape = x_vec;
std::vector<int> axis(rank);
for (int i = 0; i < rank; ++i) {
axis[i] = i;
}
std::swap(axis[rank - 1], axis[rank - 2]);
ret.Resize(phi::make_ddim(x_vec));
dev_ctx.template Alloc<T>(&ret);
switch (rank) {
case 2: {
phi::funcs::Transpose<Context, T, 2> trans;
trans(dev_ctx, x, &ret, axis);
break;
}
case 3: {
phi::funcs::Transpose<Context, T, 3> trans;
trans(dev_ctx, x, &ret, axis);
break;
}
case 4: {
phi::funcs::Transpose<Context, T, 4> trans;
trans(dev_ctx, x, &ret, axis);
break;
}
case 5: {
phi::funcs::Transpose<Context, T, 5> trans;
trans(dev_ctx, x, &ret, axis);
break;
}
case 6: {
phi::funcs::Transpose<Context, T, 6> trans;
trans(dev_ctx, x, &ret, axis);
break;
}
default: {
PADDLE_THROW(phi::errors::InvalidArgument(
"Invalid Rank number, "
"currently only support rank between 2~6"));
}
}
return ret;
}
} // namespace phi
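For readers less familiar with LAPACK's pivot encoding, here is a small NumPy mirror of what Unpack_Pivot does for a single matrix (unpack_pivot is a hypothetical helper written for illustration, not part of the diff): the 1-based row-swap indices from getrf are applied sequentially to arange(h), and OneFunctor then scatters ones at P[perm[j], j].
import numpy as np
def unpack_pivot(pivots, h):
    # pivots: 1-based row-swap indices from getrf, length min(m, n)
    perm = np.arange(h)
    for j, p in enumerate(pivots):
        perm[j], perm[p - 1] = perm[p - 1], perm[j]  # sequential swaps, as in the C++ loop
    P = np.zeros((h, h))
    P[perm, np.arange(h)] = 1.0                      # OneFunctor: P[perm[j], j] = 1
    return P
# e.g. unpack_pivot([2, 2, 3], 3) swaps the first two rows:
# [[0., 1., 0.],
#  [1., 0., 0.],
#  [0., 0., 1.]]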
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LUGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& pivots,
const DenseTensor& out_grad,
bool pivot,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LUKernel(const Context& dev_ctx,
const DenseTensor& x,
bool pivot,
DenseTensor* out,
DenseTensor* pivots,
DenseTensor* infos);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LUOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lu", {"X"}, {"pivots"}, {"Out", "Pivots", "Infos"});
}
KernelSignature LUGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"lu_grad", {"X", "Out", "Pivots", "Out@GRAD"}, {"pivots"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lu, phi::LUOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(lu_grad, phi::LUGradOpArgumentMapping);
@@ -128,6 +128,8 @@ class TestLUOp(OpTest):
def setUp(self):
self.op_type = "lu"
self.python_api = paddle.tensor.linalg.lu
self.python_out_sig = ["Out", "Pivots"]
self.config()
self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
@@ -140,10 +142,10 @@ }
}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], ['Out'])
self.check_grad(['X'], ['Out'], check_eager=True)
# m = n 2D
......
@@ -2102,27 +2102,27 @@ def lu(x, pivot=True, get_infos=False, name=None):
# one can verify : X = P @ L @ U ;
"""
if paddle.in_dynamic_mode():
LU, Piv, Info = _C_ops.lu(x, 'pivots', pivot)
if get_infos:
return LU, Piv, Info
else:
return LU, Piv
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu')
helper = LayerHelper('lu', **locals())
lu = helper.create_variable_for_type_inference(dtype=x.dtype)
p = helper.create_variable_for_type_inference(dtype='int')
info = helper.create_variable_for_type_inference(dtype='int')
attrs = dict()
attrs['pivots'] = pivot
helper.append_op(type='lu',
inputs={'X': x},
outputs={
'Out': lu,
'Pivots': p,
'Infos': info
},
attrs=attrs)
if in_dygraph_mode():
lu, p, info = _C_ops.final_state_lu(x, pivot)
elif paddle.in_dynamic_mode():
lu, p, info = _C_ops.lu(x, 'pivot', pivot)
else:
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu')
helper = LayerHelper('lu', **locals())
lu = helper.create_variable_for_type_inference(dtype=x.dtype)
p = helper.create_variable_for_type_inference(dtype='int')
info = helper.create_variable_for_type_inference(dtype='int')
attrs = dict()
attrs['pivot'] = pivot
helper.append_op(type='lu',
inputs={'X': x},
outputs={
'Out': lu,
'Pivots': p,
'Infos': info
},
attrs=attrs)
if get_infos:
return lu, p, info
else:
......
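A usage sketch of the updated dispatch above: the eager branch routes through _C_ops.final_state_lu, while the static-graph branch builds the lu op via LayerHelper with the renamed 'pivot' attribute (assuming the new phi kernels from this commit are registered).
import paddle
# dynamic graph
x = paddle.to_tensor([[4.0, 3.0], [6.0, 3.0]])
lu, p = paddle.linalg.lu(x)
lu, p, info = paddle.linalg.lu(x, get_infos=True)
# static graph
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    xs = paddle.static.data('xs', shape=[2, 2], dtype='float32')
    lu_s, p_s = paddle.linalg.lu(xs)
paddle.disable_static()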