未验证 提交 2ce91c33 编写于 作者: Z zhiboniu 提交者: GitHub

add new API paddle.linalg.lu/lu_unpack (#38617)

上级 89ce6db8
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lu_unpack_op.h"
namespace paddle {
namespace operators {
class LU_UnpackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(Unpack L U and P to single matrix tensor,
unpack L and U matrix from LU, unpack permutation matrix Pmat from Pivtos .
)DOC");
AddInput("X", "(Tensor) The input LU tensor, shape of (*,m,n)");
AddInput("Pivots",
"(Tensor) The input Pivots tensor, shape of (*,min(m,n))");
AddOutput(
"Pmat",
"(Tensor) The output permutation matrix tensor, shape of (*, m, m)");
AddOutput("L", "(Tensor) The output lower triangular matrix tensor");
AddOutput("U", "(Tensor) The output upper triangular matrix tensor");
AddAttr<bool>("unpack_ludata", "Whether to unpack L and U")
.SetDefault(true);
AddAttr<bool>("unpack_pivots", "Whether to unpack permutation matrix")
.SetDefault(true);
}
};
class LU_UnpackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LU_Unpack");
OP_INOUT_CHECK(context->HasInput("Pivots"), "Input", "Pivots", "LU_Unpack");
OP_INOUT_CHECK(context->HasOutput("L"), "Output", "L", "LU_Unpack");
OP_INOUT_CHECK(context->HasOutput("U"), "Output", "U", "LU_Unpack");
OP_INOUT_CHECK(context->HasOutput("Pmat"), "Output", "Pmat", "LU_Unpack");
bool unpack_ludata = context->Attrs().Get<bool>("unpack_ludata");
bool unpack_pivots = context->Attrs().Get<bool>("unpack_pivots");
auto x_dims = context->GetInputDim("X");
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_rank, 2, platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
// context->SetOutputDim("Out", x_dims);
int m = x_dims[x_rank - 1];
int n = x_dims[x_rank - 2];
int min_mn = std::min(m, n);
if (unpack_ludata) {
auto ldims = x_dims;
auto udims = x_dims;
if (m >= n) {
udims[x_rank - 2] = min_mn;
} else {
ldims[x_rank - 1] = min_mn;
}
context->SetOutputDim("U", udims);
context->SetOutputDim("L", ldims);
}
if (unpack_pivots) {
auto pdims = x_dims;
pdims[x_rank - 1] = m;
context->SetOutputDim("Pmat", pdims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class LU_UnpackOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = ctx->GetInputType("X", 0);
auto data_type = ctx->GetInputDataType("X", 0);
ctx->SetOutputType("L", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("L", data_type, framework::ALL_ELEMENTS);
ctx->SetOutputType("U", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("U", data_type, framework::ALL_ELEMENTS);
ctx->SetOutputType("Pmat", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("Pmat", data_type, framework::ALL_ELEMENTS);
}
};
template <typename T>
class LU_UnpackOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("lu_unpack_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Pivots", this->Input("Pivots"));
retv->SetInput("L", this->Output("L"));
retv->SetInput("U", this->Output("U"));
retv->SetInput("Pmat", this->Output("Pmat"));
retv->SetInput(framework::GradVarName("L"), this->OutputGrad("L"));
retv->SetInput(framework::GradVarName("U"), this->OutputGrad("U"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
class LU_UnpackGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = ctx->GetInputType("X", 0);
auto data_type = ctx->GetInputDataType("X", 0);
ctx->SetOutputType(framework::GradVarName("X"), var_type,
framework::ALL_ELEMENTS);
ctx->SetOutputDataType(framework::GradVarName("X"), data_type,
framework::ALL_ELEMENTS);
}
};
class LU_UnpackGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu_unpack");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("L")), "Input",
"L@GRAD", "lu_unpack");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("U")), "Input",
"U@GRAD", "lu_unpack");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(lu_unpack, ops::LU_UnpackOp, ops::LU_UnpackOpMaker,
ops::LU_UnpackOpVarTypeInference,
ops::LU_UnpackOpGradMaker<paddle::framework::OpDesc>,
ops::LU_UnpackOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lu_unpack_grad, ops::LU_UnpackGradOp,
ops::LU_UnpackGradOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(lu_unpack,
ops::LU_UnpackKernel<plat::CPUDeviceContext, float>,
ops::LU_UnpackKernel<plat::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
lu_unpack_grad, ops::LU_UnpackGradKernel<plat::CPUDeviceContext, float>,
ops::LU_UnpackGradKernel<plat::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/lu_unpack_op.h"
namespace paddle {
namespace operators {} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lu_unpack,
ops::LU_UnpackKernel<plat::CUDADeviceContext, float>,
ops::LU_UnpackKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
lu_unpack_grad, ops::LU_UnpackGradKernel<plat::CUDADeviceContext, float>,
ops::LU_UnpackGradKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/operators/tril_triu_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensorArray = framework::LoDTensorArray;
template <typename DeviceContext, typename T>
class LU_UnpackKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto xin = ctx.Input<framework::Tensor>("X");
auto P = ctx.Input<framework::Tensor>("Pivots");
auto ltensor = ctx.Output<framework::Tensor>("L");
auto utensor = ctx.Output<framework::Tensor>("U");
auto ptensor = ctx.Output<framework::Tensor>("Pmat");
auto unpack_ludata = ctx.Attr<bool>("unpack_ludata");
auto unpack_pivots = ctx.Attr<bool>("unpack_pivots");
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto xdims = xin->dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
if (unpack_ludata) {
ltensor->mutable_data<T>(ctx.GetPlace());
utensor->mutable_data<T>(ctx.GetPlace());
framework::Tensor L, U;
LU_Unpack<DeviceContext, T>(dev_ctx, xin, &L, &U);
if (m >= n) {
framework::TensorCopy(L, ctx.GetPlace(), ltensor);
Tensor_narrow<DeviceContext, T>(ctx, &U, utensor, 0, k, 0, k);
} else {
framework::TensorCopy(U, ctx.GetPlace(), utensor);
Tensor_narrow<DeviceContext, T>(ctx, &L, ltensor, 0, k, 0, k);
}
}
if (unpack_pivots) {
ptensor->mutable_data<T>(ctx.GetPlace());
Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, ptensor, m, k);
}
}
};
template <typename DeviceContext, typename T>
class LU_UnpackGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto dl = ctx.Input<framework::Tensor>(framework::GradVarName("L"));
auto du = ctx.Input<framework::Tensor>(framework::GradVarName("U"));
auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
framework::Tensor dl_tril, du_triu;
const auto ldims = dl->dims();
dl_tril.Resize(ldims);
auto H = ldims[ldims.size() - 2];
auto W = ldims[ldims.size() - 1];
auto L_dataptr = dl_tril.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> l_for_range(dev_ctx, dl->numel());
TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H, W, L_dataptr);
l_for_range(tril_computer);
const auto udims = du->dims();
du_triu.Resize(udims);
H = udims[udims.size() - 2];
W = udims[udims.size() - 1];
auto U_dataptr = du_triu.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> u_for_range(dev_ctx, du->numel());
TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H, W, U_dataptr);
u_for_range(triu_computer);
auto xdims = dx->dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
std::vector<int64_t> axes = {xrank - 2, xrank - 1};
std::vector<int64_t> slice_starts(2, 0);
std::vector<int64_t> slice_ends(2, 0);
auto valuedims = vectorize(xdims);
math::SetConstant<DeviceContext, T> setter;
setter(dev_ctx, dx, static_cast<T>(0));
if (m <= n) {
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx, dx, &dl_tril, dx, axes,
&slice_starts, &slice_ends,
valuedims, xrank);
Tensor_Add<DeviceContext, T>(dev_ctx, *dx, du_triu, dx);
} else {
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx, dx, &du_triu, dx, axes,
&slice_starts, &slice_ends,
valuedims, xrank);
Tensor_Add<DeviceContext, T>(dev_ctx, *dx, dl_tril, dx);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -170,5 +170,116 @@ class TestLUOp3(TestLUOp):
self.dtype = "float64"
class TestLUAPI(unittest.TestCase):
def test_dygraph(self):
def run_lu_dygraph(shape, dtype):
if dtype == "float32":
np_dtype = np.float32
elif dtype == "float64":
np_dtype = np.float64
a = np.random.rand(*shape).astype(np_dtype)
m = a.shape[-2]
n = a.shape[-1]
min_mn = min(m, n)
pivot = True
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
paddle.disable_static(place)
batch_size = a.size // (a.shape[-1] * a.shape[-2])
x = paddle.to_tensor(a, dtype=dtype)
sP, sl, sU = scipy_lu(a, pivot)
sL = np.tril(sl, -1)
LU, P, Info = paddle.linalg.lu(x, pivot=pivot, get_infos=True)
m, n = LU.shape[-2], LU.shape[-1]
tril = np.tril(LU, -1)[..., :m, :m]
triu = np.triu(LU)[..., :n, :n]
mtp = Pmat_to_perm(sP, min(m, n))
nP = perm_to_Pmat(P, sP.shape[-1])
self.assertTrue(np.allclose(sU, triu, atol=1e-5))
self.assertTrue(np.allclose(sL, tril, atol=1e-5))
self.assertTrue(np.allclose(P, mtp, atol=1e-5))
self.assertTrue(np.allclose(nP, sP, atol=1e-5))
tensor_shapes = [
(3, 5),
(5, 5),
(5, 3), # 2-dim Tensors
(2, 3, 5),
(3, 5, 5),
(4, 5, 3), # 3-dim Tensors
(2, 5, 3, 5),
(3, 5, 5, 5),
(4, 5, 5, 3) # 4-dim Tensors
]
dtypes = ["float32", "float64"]
for tensor_shape, dtype in itertools.product(tensor_shapes, dtypes):
run_lu_dygraph(tensor_shape, dtype)
def test_static(self):
paddle.enable_static()
def run_lu_static(shape, dtype):
if dtype == "float32":
np_dtype = np.float32
elif dtype == "float64":
np_dtype = np.float64
a = np.random.rand(*shape).astype(np_dtype)
m = a.shape[-2]
n = a.shape[-1]
min_mn = min(m, n)
pivot = True
places = []
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.program_guard(fluid.Program(), fluid.Program()):
batch_size = a.size // (a.shape[-1] * a.shape[-2])
sP, sl, sU = scipy_lu(a, pivot)
sL = np.tril(sl, -1)
ashape = np.array(a.shape)
lshape = np.array(sL.shape)
ushape = np.array(sU.shape)
lpad = (len(sL.shape) - 2) * [(0, 0)] + list((
(0, (ashape - lshape)[-2]), (0, (ashape - lshape)[-1])))
upad = (len(sU.shape) - 2) * [(0, 0)] + list((
(0, (ashape - ushape)[-2]), (0, (ashape - ushape)[-1])))
NsL = np.pad(sL, lpad)
NsU = np.pad(sU, upad)
NLU = NsL + NsU
x = paddle.fluid.data(
name="input", shape=shape, dtype=dtype)
lu, p = paddle.linalg.lu(x, pivot=pivot)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": a},
fetch_list=[lu, p])
self.assertTrue(np.allclose(fetches[0], NLU, atol=1e-5))
tensor_shapes = [
(3, 5),
(5, 5),
(5, 3), # 2-dim Tensors
(2, 3, 5),
(3, 5, 5),
(4, 5, 3), # 3-dim Tensors
(2, 5, 3, 5),
(3, 5, 5, 5),
(4, 5, 5, 3) # 4-dim Tensors
]
dtypes = ["float32", "float64"]
for tensor_shape, dtype in itertools.product(tensor_shapes, dtypes):
run_lu_static(tensor_shape, dtype)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from op_test import OpTest
import unittest
import itertools
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import scipy
import scipy.linalg
import copy
def scipy_lu_unpack(A):
shape = A.shape
if len(shape) == 2:
return scipy.linalg.lu(A)
else:
preshape = shape[:-2]
batchsize = np.product(shape) // (shape[-2] * shape[-1])
Plst = []
Llst = []
Ulst = []
NA = A.reshape((-1, shape[-2], shape[-1]))
for b in range(batchsize):
As = NA[b]
P, L, U = scipy.linalg.lu(As)
pshape = P.shape
lshape = L.shape
ushape = U.shape
Plst.append(P)
Llst.append(L)
Ulst.append(U)
return np.array(Plst).reshape(preshape + pshape), np.array(
Llst).reshape(preshape + lshape), np.array(Ulst).reshape(preshape +
ushape)
def Pmat_to_perm(Pmat_org, cut):
Pmat = copy.deepcopy(Pmat_org)
shape = Pmat.shape
rows = shape[-2]
cols = shape[-1]
batchsize = max(1, np.product(shape[:-2]))
P = Pmat.reshape(batchsize, rows, cols)
permmat = []
for b in range(batchsize):
permlst = []
sP = P[b]
for c in range(min(rows, cols)):
idx = np.argmax(sP[:, c])
permlst.append(idx)
tmp = copy.deepcopy(sP[c, :])
sP[c, :] = sP[idx, :]
sP[idx, :] = tmp
permmat.append(permlst)
Pivot = np.array(permmat).reshape(list(shape[:-2]) + [rows, ]) + 1
return Pivot[..., :cut]
def perm_to_Pmat(perm, dim):
pshape = perm.shape
bs = int(np.product(perm.shape[:-1]).item())
perm = perm.reshape((bs, pshape[-1]))
oneslst = []
for i in range(bs):
idlst = np.arange(dim)
perm_item = perm[i, :]
for idx, p in enumerate(perm_item - 1):
temp = idlst[idx]
idlst[idx] = idlst[p]
idlst[p] = temp
ones = paddle.eye(dim)
nmat = paddle.scatter(ones, paddle.to_tensor(idlst), ones)
oneslst.append(nmat)
return np.array(oneslst).reshape(list(pshape[:-1]) + [dim, dim])
# m > n
class TestLU_UnpackOp(OpTest):
"""
case 1
"""
def config(self):
self.x_shape = [2, 12, 10]
self.unpack_ludata = True
self.unpack_pivots = True
self.dtype = "float64"
def set_output(self, A):
sP, sL, sU = scipy_lu_unpack(A)
self.L = sL
self.U = sU
self.P = sP
def setUp(self):
self.op_type = "lu_unpack"
self.config()
x = np.random.random(self.x_shape).astype(self.dtype)
if paddle.in_dynamic_mode():
xt = paddle.to_tensor(x)
lu, pivots = paddle.linalg.lu(xt)
lu = lu.numpy()
pivots = pivots.numpy()
else:
with fluid.program_guard(fluid.Program(), fluid.Program()):
place = fluid.CPUPlace()
if core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
xv = paddle.fluid.data(
name="input", shape=self.x_shape, dtype=self.dtype)
lu, p = paddle.linalg.lu(xv)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": x},
fetch_list=[lu, p])
lu, pivots = fetches[0], fetches[1]
self.inputs = {'X': lu, 'Pivots': pivots}
self.attrs = {
'unpack_ludata': self.unpack_ludata,
'unpack_pivots': self.unpack_pivots
}
self.set_output(x)
self.outputs = {
'Pmat': self.P,
'L': self.L,
'U': self.U,
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], ['L', 'U'])
# m = n
class TestLU_UnpackOp2(TestLU_UnpackOp):
"""
case 2
"""
def config(self):
self.x_shape = [2, 10, 10]
self.unpack_ludata = True
self.unpack_pivots = True
self.dtype = "float64"
# m < n
class TestLU_UnpackOp3(TestLU_UnpackOp):
"""
case 3
"""
def config(self):
self.x_shape = [2, 10, 12]
self.unpack_ludata = True
self.unpack_pivots = True
self.dtype = "float64"
class TestLU_UnpackAPI(unittest.TestCase):
def test_dygraph(self):
def run_lu_unpack_dygraph(shape, dtype):
if dtype == "float32":
np_dtype = np.float32
elif dtype == "float64":
np_dtype = np.float64
a = np.random.rand(*shape).astype(np_dtype)
m = a.shape[-2]
n = a.shape[-1]
min_mn = min(m, n)
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
paddle.disable_static(place)
x = paddle.to_tensor(a, dtype=dtype)
sP, sL, sU = scipy_lu_unpack(a)
LU, P = paddle.linalg.lu(x)
pP, pL, pU = paddle.linalg.lu_unpack(LU, P)
self.assertTrue(np.allclose(sU, pU, atol=1e-5))
self.assertTrue(np.allclose(sL, pL, atol=1e-5))
self.assertTrue(np.allclose(sP, pP, atol=1e-5))
tensor_shapes = [
(3, 5),
(5, 5),
(5, 3), # 2-dim Tensors
(2, 3, 5),
(3, 5, 5),
(4, 5, 3), # 3-dim Tensors
(2, 5, 3, 5),
(3, 5, 5, 5),
(4, 5, 5, 3) # 4-dim Tensors
]
dtypes = ["float32", "float64"]
for tensor_shape, dtype in itertools.product(tensor_shapes, dtypes):
run_lu_unpack_dygraph(tensor_shape, dtype)
def test_static(self):
paddle.enable_static()
def run_lu_static(shape, dtype):
if dtype == "float32":
np_dtype = np.float32
elif dtype == "float64":
np_dtype = np.float64
a = np.random.rand(*shape).astype(np_dtype)
m = a.shape[-2]
n = a.shape[-1]
min_mn = min(m, n)
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.program_guard(fluid.Program(), fluid.Program()):
sP, sL, sU = scipy_lu_unpack(a)
x = paddle.fluid.data(
name="input", shape=shape, dtype=dtype)
lu, p = paddle.linalg.lu(x)
pP, pL, pU = paddle.linalg.lu_unpack(lu, p)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": a},
fetch_list=[pP, pL, pU])
self.assertTrue(np.allclose(fetches[0], sP, atol=1e-5))
self.assertTrue(np.allclose(fetches[1], sL, atol=1e-5))
self.assertTrue(np.allclose(fetches[2], sU, atol=1e-5))
tensor_shapes = [
(3, 5),
(5, 5),
(5, 3), # 2-dim Tensors
(2, 3, 5),
(3, 5, 5),
(4, 5, 3), # 3-dim Tensors
(2, 5, 3, 5),
(3, 5, 5, 5),
(4, 5, 5, 3) # 4-dim Tensors
]
dtypes = ["float32", "float64"]
for tensor_shape, dtype in itertools.product(tensor_shapes, dtypes):
run_lu_static(tensor_shape, dtype)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -27,6 +27,8 @@ from .tensor.linalg import matrix_rank # noqa: F401
from .tensor.linalg import svd # noqa: F401
from .tensor.linalg import eigvalsh # noqa: F401
from .tensor.linalg import qr # noqa: F401
from .tensor.linalg import lu # noqa: F401
from .tensor.linalg import lu_unpack # noqa: F401
from .tensor.linalg import eigh # noqa: F401
from .tensor.linalg import det # noqa: F401
from .tensor.linalg import slogdet # noqa: F401
......@@ -46,6 +48,8 @@ __all__ = [
'matrix_rank',
'svd',
'qr',
'lu',
'lu_unpack',
'matrix_power',
'det',
'slogdet',
......
......@@ -63,6 +63,8 @@ from .linalg import eigh # noqa: F401
from .linalg import pinv # noqa: F401
from .linalg import solve # noqa: F401
from .linalg import cholesky_solve # noqa: F401
from .linalg import lu # noqa: F401
from .linalg import lu_unpack # noqa: F401
from .logic import equal # noqa: F401
from .logic import greater_equal # noqa: F401
from .logic import greater_than # noqa: F401
......@@ -459,6 +461,8 @@ tensor_method_func = [ #noqa
'asinh',
'atanh',
'acosh',
'lu',
'lu_unpack',
'as_complex',
'as_real',
'rad2deg',
......
......@@ -1823,6 +1823,205 @@ def qr(x, mode="reduced", name=None):
return q, r
def lu(x, pivot=True, get_infos=False, name=None):
r"""
Computes the LU factorization of an N-D(N>=2) matrix x.
Returns the LU factorization(inplace x) and Pivots. low triangular matrix L and
upper triangular matrix U are combined to a single LU matrix.
Pivoting is done if pivot is set to True.
P mat can be get by pivots:
# ones = eye(rows) #eye matrix of rank rows
# for i in range(cols):
# swap(ones[i], ones[pivots[i]])
# return ones
Args:
X (Tensor): the tensor to factor of N-dimensions(N>=2).
pivot (bool, optional): controls whether pivoting is done. Default: True.
get_infos (bool, optional): if set to True, returns an info IntTensor. Default: False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
factorization (Tensor): LU matrix, the factorization of input X.
pivots (IntTensor): the pivots of size(∗(N-2), min(m,n)). `pivots` stores all the
intermediate transpositions of rows. The final permutation `perm` could be
reconstructed by this, details refer to upper example.
infos (IntTensor, optional): if `get_infos` is `True`, this is a tensor of size (∗(N-2))
where non-zero values indicate whether factorization for the matrix or each minibatch
has succeeded or failed.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]).astype('float64')
lu,p,info = paddle.linalg.lu(x, get_infos=True)
# >>> lu:
# Tensor(shape=[3, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[5. , 6. ],
# [0.20000000, 0.80000000],
# [0.60000000, 0.50000000]])
# >>> p
# Tensor(shape=[2], dtype=int32, place=CUDAPlace(0), stop_gradient=True,
# [3, 3])
# >>> info
# Tensor(shape=[], dtype=int32, place=CUDAPlace(0), stop_gradient=True,
# 0)
P,L,U = paddle.linalg.lu_unpack(lu,p)
# >>> P
# (Tensor(shape=[3, 3], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[0., 1., 0.],
# [0., 0., 1.],
# [1., 0., 0.]]),
# >>> L
# Tensor(shape=[3, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[1. , 0. ],
# [0.20000000, 1. ],
# [0.60000000, 0.50000000]]),
# >>> U
# Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[5. , 6. ],
# [0. , 0.80000000]]))
# one can verify : X = P @ L @ U ;
"""
if in_dygraph_mode():
LU, Piv, Info = _C_ops.lu(x, 'pivots', pivot)
if get_infos:
return LU, Piv, Info
else:
return LU, Piv
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu')
helper = LayerHelper('lu', **locals())
lu = helper.create_variable_for_type_inference(dtype=x.dtype)
p = helper.create_variable_for_type_inference(dtype='int')
info = helper.create_variable_for_type_inference(dtype='int')
attrs = dict()
attrs['pivots'] = pivot
helper.append_op(
type='lu',
inputs={'X': x},
outputs={'Out': lu,
'Pivots': p,
'Infos': info},
attrs=attrs)
if get_infos:
return lu, p, info
else:
return lu, p
def lu_unpack(x, y, unpack_ludata=True, unpack_pivots=True, name=None):
r"""
Unpack L U and P to single matrix tensor .
unpack L and U matrix from LU, unpack permutation matrix P from Pivtos .
P mat can be get by pivots:
# ones = eye(rows) #eye matrix of rank rows
# for i in range(cols):
# swap(ones[i], ones[pivots[i]])
Args:
x (Tensor): The LU tensor get from paddle.linalg.lu, which is combined by L and U.
y (Tensor): Pivots get from paddle.linalg.lu.
unpack_ludata (bool,optional): whether to unpack L and U from x. Default: True.
unpack_pivots (bool, optional): whether to unpack permutation matrix P from Pivtos. Default: True.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
P (Tensor): Permutation matrix P of lu factorization.
L (Tensor): The lower triangular matrix tensor of lu factorization.
U (Tensor): The upper triangular matrix tensor of lu factorization.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]).astype('float64')
lu,p,info = paddle.linalg.lu(x, get_infos=True)
# >>> lu:
# Tensor(shape=[3, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[5. , 6. ],
# [0.20000000, 0.80000000],
# [0.60000000, 0.50000000]])
# >>> p
# Tensor(shape=[2], dtype=int32, place=CUDAPlace(0), stop_gradient=True,
# [3, 3])
# >>> info
# Tensor(shape=[], dtype=int32, place=CUDAPlace(0), stop_gradient=True,
# 0)
P,L,U = paddle.linalg.lu_unpack(lu,p)
# >>> P
# (Tensor(shape=[3, 3], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[0., 1., 0.],
# [0., 0., 1.],
# [1., 0., 0.]]),
# >>> L
# Tensor(shape=[3, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[1. , 0. ],
# [0.20000000, 1. ],
# [0.60000000, 0.50000000]]),
# >>> U
# Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[5. , 6. ],
# [0. , 0.80000000]]))
# one can verify : X = P @ L @ U ;
"""
if in_dygraph_mode():
P, L, U = _C_ops.lu_unpack(x, y, 'unpack_ludata', unpack_ludata,
'unpack_pivots', unpack_pivots)
return P, L, U
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu_unpack')
helper = LayerHelper('lu_unpack', **locals())
p = helper.create_variable_for_type_inference(dtype=x.dtype)
l = helper.create_variable_for_type_inference(dtype=x.dtype)
u = helper.create_variable_for_type_inference(dtype=x.dtype)
attrs = dict()
attrs['unpack_ludata'] = unpack_ludata
attrs['unpack_pivots'] = unpack_pivots
helper.append_op(
type='lu_unpack',
inputs={'X': x,
'Pivots': y},
outputs={'Pmat': p,
'L': l,
'U': u},
attrs=attrs)
return p, l, u
def eig(x, name=None):
"""
This API performs the eigenvalue decomposition of a square matrix or a batch of square matrices.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册