Unverified commit e24ca55e authored by zhouweiwei2014, committed by GitHub

[Phi]migrate cholesky_solve op to phi (#40387)

Parent dc773828
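For context: cholesky_solve solves a linear system given a precomputed Cholesky factor. With upper = true the factor U is upper triangular and A = U^H * U (the lower case gives A = L * L^H), and the op computes X = A^{-1} * B for a right-hand side B. The kernels in this diff delegate the actual solve to the LAPACK/cuSOLVER potrs routines rather than forming A^{-1} explicitly.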
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -39,50 +40,6 @@ class CholeskySolveOpMaker : public framework::OpProtoAndCheckerMaker {
class CholeskySolveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "CholeskySolve");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "CholeskySolve");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "CholeskySolve");
auto u_dims = context->GetInputDim("Y");
auto b_dims = context->GetInputDim("X");
int u_rank = u_dims.size();
int b_rank = b_dims.size();
PADDLE_ENFORCE_GE(u_rank, 2,
platform::errors::InvalidArgument(
"the rank of input Y must greater or equal to 2"));
PADDLE_ENFORCE_GE(b_rank, 2,
platform::errors::InvalidArgument(
"the rank of input X must greater or equal to 2"));
PADDLE_ENFORCE_EQ(u_dims[u_rank - 1], u_dims[u_rank - 2],
platform::errors::InvalidArgument(
"input Matrix Y should be square matrix,"
"But Got last shape of %ld x %ld",
u_dims[u_rank - 1], u_dims[u_rank - 2]));
PADDLE_ENFORCE_EQ(
b_dims[b_rank - 2], u_dims[u_rank - 2],
platform::errors::InvalidArgument(
"the first dim of input X must equal to the dim of input Y,"
"But Got %ld and %ld",
b_dims[b_rank - 2], u_dims[u_rank - 2]));
std::vector<int64_t> u_dims_vec = phi::vectorize(u_dims);
std::vector<int64_t> b_dims_vec = phi::vectorize(b_dims);
std::vector<int64_t> u_dims_vec_cut(u_dims_vec.begin(),
u_dims_vec.end() - 2);
std::vector<int64_t> b_dims_vec_cut(b_dims_vec.begin(),
b_dims_vec.end() - 2);
std::vector<int64_t> expand_batch_portion =
get_broadcast_batch_portion(u_dims_vec_cut, b_dims_vec_cut);
std::vector<int64_t> b_broadcast_dims({expand_batch_portion});
b_broadcast_dims.insert(b_broadcast_dims.end(),
{b_dims_vec[b_rank - 2], b_dims_vec[b_rank - 1]});
// dim of 'Out' is the same as 'X' after broadcast
context->SetOutputDim("Out", phi::make_ddim(b_broadcast_dims));
}
protected:
framework::OpKernelType GetExpectedKernelType(
@@ -151,22 +108,15 @@ class CholeskySolveGradOp : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(cholesky_solve, CholeskySolveInferShapeFunctor,
PD_INFER_META(phi::CholeskySolveInferMeta));
REGISTER_OPERATOR(cholesky_solve, ops::CholeskySolveOp,
ops::CholeskySolveOpMaker,
ops::CholeskySolveOpVarTypeInference,
ops::CholeskySolveOpGradMaker<paddle::framework::OpDesc>,
ops::CholeskySolveOpGradMaker<paddle::imperative::OpBase>);
ops::CholeskySolveOpGradMaker<paddle::imperative::OpBase>,
CholeskySolveInferShapeFunctor);
REGISTER_OPERATOR(cholesky_solve_grad, ops::CholeskySolveGradOp);
REGISTER_OP_CPU_KERNEL(
cholesky_solve,
ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, float>,
ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
cholesky_solve_grad,
ops::CholeskySolveGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CholeskySolveGradKernel<paddle::platform::CPUDeviceContext, double>);
// Complex<> is not supported because of TensorExpand, which is used to
// broadcast the input Tensor
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP does not support cusolver
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template <typename T>
void cusolver_potrs(const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo,
int n, int nrhs, T *Adata, int lda, T *Bdata, int ldb,
int *devInfo);
template <>
void cusolver_potrs<float>(const cusolverDnHandle_t &cusolverH,
cublasFillMode_t uplo, int n, int nrhs, float *Adata,
int lda, float *Bdata, int ldb, int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSpotrs(
cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}
template <>
void cusolver_potrs<double>(const cusolverDnHandle_t &cusolverH,
cublasFillMode_t uplo, int n, int nrhs,
double *Adata, int lda, double *Bdata, int ldb,
int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDpotrs(
cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}
template <>
void cusolver_potrs<platform::complex<float>>(
const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs,
platform::complex<float> *Adata, int lda, platform::complex<float> *Bdata,
int ldb, int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnCpotrs(
cusolverH, uplo, n, nrhs, reinterpret_cast<const cuComplex *>(Adata), lda,
reinterpret_cast<cuComplex *>(Bdata), ldb, devInfo));
}
template <>
void cusolver_potrs<platform::complex<double>>(
const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs,
platform::complex<double> *Adata, int lda, platform::complex<double> *Bdata,
int ldb, int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnZpotrs(
cusolverH, uplo, n, nrhs,
reinterpret_cast<const cuDoubleComplex *>(Adata), lda,
reinterpret_cast<cuDoubleComplex *>(Bdata), ldb, devInfo));
}
template <typename T>
class CholeskySolveFunctor<paddle::platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext &dev_ctx, bool upper, int n,
int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) {
cublasFillMode_t uplo =
upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
/* step 1: get cusolver handle*/
auto cusolverH = dev_ctx.cusolver_dn_handle();
/* step 2: solve A0*X0 = B0 */
cusolver_potrs<T>(cusolverH, uplo, n, nrhs, Adata, lda, Bdata, lda,
devInfo);
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}
};
template <typename T>
class MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const Tensor &in, Tensor *out,
const framework::ExecutionContext &ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims());
auto in_size = in_dims.size();
const std::vector<std::int64_t> out_dims = phi::vectorize(out->dims());
auto out_size = out_dims.size();
std::vector<std::int64_t> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(), out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx.cuda_device_context(), in, out, kps::IdentityFunctor<T>(),
out_reduce_dims, stream);
}
};
} // namespace operators
} // namespace paddle
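// To make the example in MatrixReduceSumFunctor concrete, here is a minimal
// standalone sketch (hypothetical helper, not part of this diff) of the
// reduce-dim computation: pad the output dims with leading 1s up to the input
// rank, then mark every batch axis that the input has but the broadcast
// output collapsed.
#include <algorithm>
#include <cstdint>
#include <vector>
std::vector<int> GetMatrixReduceDims(const std::vector<std::int64_t> &in_dims,
                                     const std::vector<std::int64_t> &out_dims) {
  const size_t in_size = in_dims.size();
  const size_t out_size = out_dims.size();
  // Right-align out_dims against in_dims, padding with 1s on the left.
  std::vector<std::int64_t> out_bst(in_size, 1);
  std::copy(out_dims.begin(), out_dims.end(),
            out_bst.begin() + (in_size - out_size));
  // Batch axes (everything except the last two matrix dims) where the input
  // is non-degenerate but the output is 1 must be summed over.
  std::vector<int> reduce_dims;
  for (size_t idx = 0; idx + 2 < in_size; ++idx) {
    if (in_dims[idx] != 1 && out_bst[idx] == 1) {
      reduce_dims.push_back(static_cast<int>(idx));
    }
  }
  return reduce_dims;
}
// GetMatrixReduceDims({5, 3, 2, 7, 3}, {3, 1, 7, 3}) yields {0, 2}, matching
// the comment in the functor above.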
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
cholesky_solve,
ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, float>,
ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
cholesky_solve_grad,
ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, double>);
#endif // not PADDLE_WITH_HIP
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CholeskySolveFunctor {
public:
void operator()(const platform::DeviceContext &dev_ctx, bool upper, int n,
int nrhs, T *Adata, int lda, T *Bdata, int *devInfo);
};
template <typename T>
class CholeskySolveFunctor<paddle::platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext &dev_ctx, bool upper, int n,
int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) {
char uplo = upper ? 'U' : 'L';
phi::funcs::lapackCholeskySolve<T>(uplo, n, nrhs, Adata, lda, Bdata, lda,
devInfo);
}
};
template <typename DeviceContext, typename T>
void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
const framework::Tensor &uin,
const framework::Tensor &bin, framework::Tensor *out,
bool upper) {
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
// broadcast the input Tensors to u_bst and b_bst
std::vector<int64_t> u_bst_dims_vec;
std::vector<int64_t> b_bst_dims_vec;
std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(uin, bin);
framework::Tensor u_bst(uin.type());
TensorExpand<T, DeviceContext>(dev_ctx, uin, &u_bst, u_bst_dims_vec);
framework::Tensor b_bst(bin.type());
TensorExpand<T, DeviceContext>(dev_ctx, bin, &b_bst, b_bst_dims_vec);
auto &phi_dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE &>(
dev_ctx);
// calculate u's conjugate for complex
framework::Tensor u_conj(u_bst.type());
platform::ForRange<DeviceContext> u_for_range(dev_ctx, u_bst.numel());
phi::funcs::ConjFunctor<T> u_functor(
u_bst.data<T>(), u_bst.numel(),
u_conj.mutable_data<T>(u_bst.dims(), dev_ctx.GetPlace()));
u_for_range(u_functor);
u_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, u_conj);
// calculate b's conjugate for complex
framework::Tensor b_conj(b_bst.type());
platform::ForRange<DeviceContext> b_for_range(dev_ctx, b_bst.numel());
phi::funcs::ConjFunctor<T> b_functor(
b_bst.data<T>(), b_bst.numel(),
b_conj.mutable_data<T>(b_bst.dims(), dev_ctx.GetPlace()));
b_for_range(b_functor);
b_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, b_conj);
auto ut_data = u_conj.mutable_data<T>(dev_ctx.GetPlace());
auto uindims = u_bst.dims();
auto bindims = b_bst.dims();
int uinrank = uindims.size();
int binrank = bindims.size();
int n = uindims[uinrank - 2];
int nrhs = bindims[binrank - 1];
int ldab = std::max(1, n);
framework::TensorCopy(b_conj, dev_ctx.GetPlace(), out);
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
auto info_dims = phi::slice_ddim(bindims, 0, binrank - 2);
auto batchsize = product(info_dims);
framework::Tensor tmp;
std::vector<int> tmpdim(1, batchsize);
tmp.Resize(phi::make_ddim(tmpdim));
int *info = tmp.mutable_data<int>(dev_ctx.GetPlace());
CholeskySolveFunctor<DeviceContext, T> functor;
for (int b = 0; b < batchsize; b++) {
auto uin_data_item = &ut_data[b * n * n];
auto out_data_item = &out_data[b * n * nrhs];
auto info_item = &info[b];
functor(dev_ctx, upper, n, nrhs, uin_data_item, ldab, out_data_item,
info_item);
}
// calculate out's conjugate for complex
platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
phi::funcs::ConjFunctor<T> out_functor(
out->data<T>(), out->numel(),
out->mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
out_for_range(out_functor);
*out = phi::TransposeLast2Dim<T>(phi_dev_ctx, *out);
}
template <typename DeviceContext, typename T>
class CholeskySolveKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *uin = ctx.Input<framework::Tensor>("Y");
auto *bin = ctx.Input<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
auto upper = ctx.Attr<bool>("upper");
cholesky_solve_fn<DeviceContext, T>(ctx, *uin, *bin, out, upper);
}
};
template <typename DeviceContext, typename T>
class CholeskySolveGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *bin = ctx.Input<framework::Tensor>("X");
auto *uin = ctx.Input<framework::Tensor>("Y");
auto *out = ctx.Input<framework::Tensor>("Out");
auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *db = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *du = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto upper = ctx.Attr<bool>("upper");
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
auto &phi_dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE &>(
dev_ctx);
std::vector<int64_t> u_bst_dims_vec;
std::vector<int64_t> b_bst_dims_vec;
std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(*uin, *bin);
framework::Tensor u_bst(uin->type());
TensorExpand<T, DeviceContext>(dev_ctx, *uin, &u_bst, u_bst_dims_vec);
framework::Tensor db_bst(bin->type());
TensorExpand<T, DeviceContext>(dev_ctx, *bin, &db_bst, b_bst_dims_vec);
if (dout) {
db->mutable_data<T>(dev_ctx.GetPlace());
cholesky_solve_fn<DeviceContext, T>(ctx, u_bst, *dout, &db_bst, upper);
if (db_bst.dims() == db->dims()) {
framework::TensorCopy(db_bst, dev_ctx.GetPlace(), dev_ctx, db);
} else {
MatrixReduceSumFunctor<DeviceContext, T> functor;
functor(db_bst, db, ctx);
db->Resize(bin->dims());
}
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
// calculate out's conjugate for complex
framework::Tensor out_conj(out->type());
platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
phi::funcs::ConjFunctor<T> out_functor(
out->data<T>(), out->numel(),
out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
out_for_range(out_functor);
out_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, out_conj);
framework::Tensor commonterm(out->type());
auto outdims = out_conj.dims();
auto dbdims = db_bst.dims();
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(outdims, 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(dbdims, 0, false);
auto cmtdim = outdims;
cmtdim[cmtdim.size() - 2] = dbdims[dbdims.size() - 2];
commonterm.Resize(cmtdim);
commonterm.mutable_data<T>(dev_ctx.GetPlace());
blas.MatMul(db_bst, mat_dim_b, out_conj, mat_dim_a, static_cast<T>(1),
&commonterm, static_cast<T>(0));
// calculate commonterm's conjugate for complex
framework::Tensor commonterm_conj(commonterm.type());
platform::ForRange<DeviceContext> commonterm_for_range(
dev_ctx, commonterm.numel());
phi::funcs::ConjFunctor<T> commonterm_functor(
commonterm.data<T>(), commonterm.numel(),
commonterm_conj.mutable_data<T>(commonterm.dims(),
dev_ctx.GetPlace()));
commonterm_for_range(commonterm_functor);
commonterm_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, commonterm_conj);
phi::AddRawKernel<T>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
commonterm, commonterm_conj, -1, &commonterm);
auto mat_dim_u =
phi::funcs::CreateMatrixDescriptor(u_bst.dims(), 0, false);
auto mat_dim_c =
phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false);
Tensor du_bst(uin->type());
// get upper or lower triangular
du_bst.Resize(u_bst.dims());
du_bst.mutable_data<T>(dev_ctx.GetPlace());
if (upper) {
blas.MatMul(u_bst, mat_dim_u, commonterm, mat_dim_c, static_cast<T>(-1),
&du_bst, static_cast<T>(0));
} else {
blas.MatMul(commonterm, mat_dim_c, u_bst, mat_dim_u, static_cast<T>(-1),
&du_bst, static_cast<T>(0));
}
const auto &udims = u_bst.dims();
const auto H = udims[udims.size() - 2];
const auto W = udims[udims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx, u_bst.numel());
TrilTriuCompute<T> tril_triu_computer(du_bst.data<T>(), 0, !upper, H, W,
u_bst.data<T>());
x_for_range(tril_triu_computer);
du->mutable_data<T>(dev_ctx.GetPlace());
if (u_bst.dims() == du->dims()) {
framework::TensorCopy(u_bst, dev_ctx.GetPlace(), dev_ctx, du);
} else {
MatrixReduceSumFunctor<DeviceContext, T> functor;
functor(u_bst, du, ctx);
du->Resize(uin->dims());
}
}
}
};
} // namespace operators
} // namespace paddle
@@ -60,45 +60,5 @@ static void triangular_solve(const DeviceContext &context, const Tensor &x,
unitriangular);
}
template <typename DeviceContext, typename T>
class MatrixReduceSumFunctor {
public:
void operator()(const Tensor &input, Tensor *output,
const framework::ExecutionContext &ctx);
};
template <typename T>
class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const Tensor &in, Tensor *out,
const framework::ExecutionContext &ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims());
auto in_size = in_dims.size();
const std::vector<std::int64_t> out_dims = phi::vectorize(out->dims());
auto out_size = out_dims.size();
std::vector<std::int64_t> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(), out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
out->Resize(phi::make_ddim(out_bst_dims));
std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>(
&in, out, out_reduce_dims, true, false, ctx)
.template apply<T>();
out->Resize(phi::make_ddim(out_dims));
}
};
} // namespace operators
} // namespace paddle
@@ -46,8 +46,6 @@ if (WITH_MKLML)
cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml phi_dynload_mklml)
endif()
cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader phi_dynload_lapack)
add_dependencies(dynload_lapack extern_lapack)
# TODO(TJ): add iomp, mkldnn?
if (MKL_FOUND AND WITH_ONEMKL)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/lapack.h"
namespace paddle {
namespace platform {
namespace dynload {
#define DEFINE_WRAP(__name) DynLoad__##__name __name
LAPACK_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <complex>
#include <mutex>
#include "paddle/phi/backends/dynload/lapack.h"
#include "paddle/phi/common/complex.h"
namespace paddle {
namespace platform {
namespace dynload {
/**
* The following macro definition can generate structs
* (for each function) to dynamic load lapack routine
* via operator overloading.
*/
#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \
using DynLoad__##__name = phi::dynload::DynLoad__##__name; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_LAPACK_WRAP(__name) \
DYNAMIC_LOAD_LAPACK_WRAP(__name)
#define LAPACK_ROUTINE_EACH(__macro) \
__macro(dgetrf_); \
__macro(sgetrf_); \
__macro(zheevd_); \
__macro(cheevd_); \
__macro(dsyevd_); \
__macro(ssyevd_); \
__macro(dgeev_); \
__macro(sgeev_); \
__macro(zgeev_); \
__macro(cgeev_); \
__macro(dgels_); \
__macro(sgels_); \
__macro(dgelsd_); \
__macro(sgelsd_); \
__macro(dgelsy_); \
__macro(sgelsy_); \
__macro(dgelss_); \
__macro(sgelss_); \
__macro(zpotrs_); \
__macro(cpotrs_); \
__macro(dpotrs_); \
__macro(spotrs_);
LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP);
#undef DYNAMIC_LOAD_LAPACK_WRAP
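// As an expansion sketch, the entry DECLARE_DYNAMIC_LOAD_LAPACK_WRAP(dpotrs_)
// above produces roughly:
//   using DynLoad__dpotrs_ = phi::dynload::DynLoad__dpotrs_;
//   extern DynLoad__dpotrs_ dpotrs_;
// so call sites keep writing dynload::dpotrs_(...) while the symbol itself is
// resolved at runtime by phi's dynamic loader.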
} // namespace dynload
} // namespace platform
} // namespace paddle
@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
// Note(zhouwei): because lapack doesn't provide appropriate header file.
// should expose API statement yourself.
// Because lapack doesn't provide an appropriate header file,
// we have to declare the API statements ourselves.
// getrf_ (for example)
extern "C" void dgetrf_(
......
@@ -274,6 +274,60 @@ void HuberLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void CholeskySolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims_n = x_dims.size();
auto y_dims_n = y_dims.size();
PADDLE_ENFORCE_GE(x_dims_n,
2,
phi::errors::InvalidArgument(
"the rank of input X must be greater than or equal to 2"));
PADDLE_ENFORCE_GE(y_dims_n,
2,
phi::errors::InvalidArgument(
"the rank of input Y must be greater than or equal to 2"));
PADDLE_ENFORCE_EQ(
y_dims[y_dims_n - 1],
y_dims[y_dims_n - 2],
phi::errors::InvalidArgument("input Matrix Y should be square matrix,"
"But Got last shape of %ld x %ld",
y_dims[y_dims_n - 1],
y_dims[y_dims_n - 2]));
PADDLE_ENFORCE_EQ(
x_dims[x_dims_n - 2],
y_dims[y_dims_n - 2],
phi::errors::InvalidArgument("the first dim of Matrix X must be equal to "
"the fisrt dim of Matrix Y,"
"But Got %ld and %ld",
x_dims[x_dims_n - 2],
y_dims[y_dims_n - 2]));
std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
std::vector<int64_t> y_dims_vec = phi::vectorize(y_dims);
std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 2);
std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(), y_dims_vec.end() - 2);
std::vector<int64_t> expand_batch_portion =
funcs::MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> x_broadcast_dims({expand_batch_portion});
x_broadcast_dims.insert(x_broadcast_dims.end(),
{x_dims_vec[x_dims_n - 2], x_dims_vec[x_dims_n - 1]});
// dim of 'out' is the same as 'X' after broadcast
out->set_dims(phi::make_ddim(x_broadcast_dims));
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
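// A worked shape example for the InferMeta above (hypothetical shapes):
//   x (right-hand side): [5, 1, 4, 2] -> batch portion [5, 1], matrix 4 x 2
//   y (Cholesky factor): [3, 4, 4]    -> batch portion [3],    matrix 4 x 4
//   broadcast batch portion: [5, 3]
//   out: [5, 3, 4, 2], i.e. the broadcast batch dims plus x's last two dims.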
void TriangularSolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
......
@@ -62,6 +62,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor* residual,
MetaConfig config = MetaConfig());
void CholeskySolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
MetaTensor* out);
void TriangularSolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CholeskySolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
bool upper,
DenseTensor* dx,
DenseTensor* dy);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CholeskySolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool upper,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(cholesky_solve_grad,
CPU,
ALL_LAYOUT,
phi::CholeskySolveGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
namespace phi {
template <typename T>
class CholeskySolveFunctor<T, CPUContext> {
public:
void operator()(const CPUContext &dev_ctx,
bool upper,
int M,
int N,
T *Adata,
int lda,
T *Bdata,
int *devInfo) {
char uplo = upper ? 'U' : 'L';
funcs::lapackCholeskySolve<T>(uplo, M, N, Adata, lda, Bdata, lda, devInfo);
}
};
} // namespace phi
PD_REGISTER_KERNEL(
cholesky_solve, CPU, ALL_LAYOUT, phi::CholeskySolveKernel, float, double) {}
math_library(lapack_function DEPS dynload_lapack)
math_library(lapack_function DEPS phi_dynload_lapack)
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/fluid/platform/dynload/lapack.h"
#include "paddle/phi/backends/dynload/lapack.h"
#include "paddle/phi/common/complex.h"
namespace phi {
@@ -22,12 +22,12 @@ namespace funcs {
// LU (for example)
template <>
void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
paddle::platform::dynload::dgetrf_(&m, &n, a, &lda, ipiv, info);
dynload::dgetrf_(&m, &n, a, &lda, ipiv, info);
}
template <>
void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
paddle::platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
}
// eigh
@@ -47,7 +47,7 @@ void lapackEigh<float>(char jobz,
int *info) {
(void)rwork; // unused
(void)lrwork; // unused
paddle::platform::dynload::ssyevd_(
dynload::ssyevd_(
&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}
@@ -67,7 +67,7 @@ void lapackEigh<double>(char jobz,
int *info) {
(void)rwork; // unused
(void)lrwork; // unused
paddle::platform::dynload::dsyevd_(
dynload::dsyevd_(
&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}
@@ -86,20 +86,19 @@ void lapackEigh<phi::dtype::complex<float>, float>(
int *iwork,
int liwork,
int *info) {
paddle::platform::dynload::cheevd_(
&jobz,
&uplo,
&n,
reinterpret_cast<std::complex<float> *>(a),
&lda,
w,
reinterpret_cast<std::complex<float> *>(work),
&lwork,
rwork,
&lrwork,
iwork,
&liwork,
info);
dynload::cheevd_(&jobz,
&uplo,
&n,
reinterpret_cast<std::complex<float> *>(a),
&lda,
w,
reinterpret_cast<std::complex<float> *>(work),
&lwork,
rwork,
&lrwork,
iwork,
&liwork,
info);
}
template <>
@@ -117,20 +116,19 @@ void lapackEigh<phi::dtype::complex<double>, double>(
int *iwork,
int liwork,
int *info) {
paddle::platform::dynload::zheevd_(
&jobz,
&uplo,
&n,
reinterpret_cast<std::complex<double> *>(a),
&lda,
w,
reinterpret_cast<std::complex<double> *>(work),
&lwork,
rwork,
&lrwork,
iwork,
&liwork,
info);
dynload::zheevd_(&jobz,
&uplo,
&n,
reinterpret_cast<std::complex<double> *>(a),
&lda,
w,
reinterpret_cast<std::complex<double> *>(work),
&lwork,
rwork,
&lrwork,
iwork,
&liwork,
info);
}
// Eig
@@ -152,20 +150,20 @@ void lapackEig<double>(char jobvl,
double *wr = w;
double *wi = w + n;
(void)rwork; // unused
paddle::platform::dynload::dgeev_(&jobvl,
&jobvr,
&n,
a,
&lda,
wr,
wi,
vl,
&ldvl,
vr,
&ldvr,
work,
&lwork,
info);
dynload::dgeev_(&jobvl,
&jobvr,
&n,
a,
&lda,
wr,
wi,
vl,
&ldvl,
vr,
&ldvr,
work,
&lwork,
info);
}
template <>
@@ -186,20 +184,20 @@ void lapackEig<float>(char jobvl,
float *wr = w;
float *wi = w + n;
(void)rwork; // unused
paddle::platform::dynload::sgeev_(&jobvl,
&jobvr,
&n,
a,
&lda,
wr,
wi,
vl,
&ldvl,
vr,
&ldvr,
work,
&lwork,
info);
dynload::sgeev_(&jobvl,
&jobvr,
&n,
a,
&lda,
wr,
wi,
vl,
&ldvl,
vr,
&ldvr,
work,
&lwork,
info);
}
template <>
@@ -218,21 +216,20 @@ void lapackEig<phi::dtype::complex<double>, double>(
int lwork,
double *rwork,
int *info) {
paddle::platform::dynload::zgeev_(
&jobvl,
&jobvr,
&n,
reinterpret_cast<std::complex<double> *>(a),
&lda,
reinterpret_cast<std::complex<double> *>(w),
reinterpret_cast<std::complex<double> *>(vl),
&ldvl,
reinterpret_cast<std::complex<double> *>(vr),
&ldvr,
reinterpret_cast<std::complex<double> *>(work),
&lwork,
rwork,
info);
dynload::zgeev_(&jobvl,
&jobvr,
&n,
reinterpret_cast<std::complex<double> *>(a),
&lda,
reinterpret_cast<std::complex<double> *>(w),
reinterpret_cast<std::complex<double> *>(vl),
&ldvl,
reinterpret_cast<std::complex<double> *>(vr),
&ldvr,
reinterpret_cast<std::complex<double> *>(work),
&lwork,
rwork,
info);
}
template <>
@@ -251,21 +248,20 @@ void lapackEig<phi::dtype::complex<float>, float>(
int lwork,
float *rwork,
int *info) {
paddle::platform::dynload::cgeev_(
&jobvl,
&jobvr,
&n,
reinterpret_cast<std::complex<float> *>(a),
&lda,
reinterpret_cast<std::complex<float> *>(w),
reinterpret_cast<std::complex<float> *>(vl),
&ldvl,
reinterpret_cast<std::complex<float> *>(vr),
&ldvr,
reinterpret_cast<std::complex<float> *>(work),
&lwork,
rwork,
info);
dynload::cgeev_(&jobvl,
&jobvr,
&n,
reinterpret_cast<std::complex<float> *>(a),
&lda,
reinterpret_cast<std::complex<float> *>(w),
reinterpret_cast<std::complex<float> *>(vl),
&ldvl,
reinterpret_cast<std::complex<float> *>(vr),
&ldvr,
reinterpret_cast<std::complex<float> *>(work),
&lwork,
rwork,
info);
}
template <>
@@ -280,8 +276,7 @@ void lapackGels<double>(char trans,
double *work,
int lwork,
int *info) {
paddle::platform::dynload::dgels_(
&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
dynload::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
}
template <>
@@ -296,8 +291,7 @@ void lapackGels<float>(char trans,
float *work,
int lwork,
int *info) {
paddle::platform::dynload::sgels_(
&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
dynload::sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
}
template <>
@@ -316,20 +310,20 @@ void lapackGelsd<double>(int m,
double *rwork,
int *iwork,
int *info) {
paddle::platform::dynload::dgelsd_(&m,
&n,
&nrhs,
a,
&lda,
b,
&ldb,
s,
&rcond,
rank,
work,
&lwork,
iwork,
info);
dynload::dgelsd_(&m,
&n,
&nrhs,
a,
&lda,
b,
&ldb,
s,
&rcond,
rank,
work,
&lwork,
iwork,
info);
}
template <>
@@ -348,20 +342,20 @@ void lapackGelsd<float>(int m,
float *rwork,
int *iwork,
int *info) {
paddle::platform::dynload::sgelsd_(&m,
&n,
&nrhs,
a,
&lda,
b,
&ldb,
s,
&rcond,
rank,
work,
&lwork,
iwork,
info);
dynload::sgelsd_(&m,
&n,
&nrhs,
a,
&lda,
b,
&ldb,
s,
&rcond,
rank,
work,
&lwork,
iwork,
info);
}
template <>
@@ -379,7 +373,7 @@ void lapackGelsy<double>(int m,
int lwork,
double *rwork,
int *info) {
paddle::platform::dynload::dgelsy_(
dynload::dgelsy_(
&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info);
}
@@ -398,7 +392,7 @@ void lapackGelsy<float>(int m,
int lwork,
float *rwork,
int *info) {
paddle::platform::dynload::sgelsy_(
dynload::sgelsy_(
&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info);
}
@@ -417,7 +411,7 @@ void lapackGelss<double>(int m,
int lwork,
double *rwork,
int *info) {
paddle::platform::dynload::dgelss_(
dynload::dgelss_(
&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info);
}
@@ -436,7 +430,7 @@ void lapackGelss<float>(int m,
int lwork,
float *rwork,
int *info) {
paddle::platform::dynload::sgelss_(
dynload::sgelss_(
&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info);
}
@@ -450,15 +444,14 @@ void lapackCholeskySolve<phi::dtype::complex<double>>(
phi::dtype::complex<double> *b,
int ldb,
int *info) {
paddle::platform::dynload::zpotrs_(
&uplo,
&n,
&nrhs,
reinterpret_cast<std::complex<double> *>(a),
&lda,
reinterpret_cast<std::complex<double> *>(b),
&ldb,
info);
dynload::zpotrs_(&uplo,
&n,
&nrhs,
reinterpret_cast<std::complex<double> *>(a),
&lda,
reinterpret_cast<std::complex<double> *>(b),
&ldb,
info);
}
template <>
@@ -471,14 +464,14 @@ void lapackCholeskySolve<phi::dtype::complex<float>>(
phi::dtype::complex<float> *b,
int ldb,
int *info) {
paddle::platform::dynload::cpotrs_(&uplo,
&n,
&nrhs,
reinterpret_cast<std::complex<float> *>(a),
&lda,
reinterpret_cast<std::complex<float> *>(b),
&ldb,
info);
dynload::cpotrs_(&uplo,
&n,
&nrhs,
reinterpret_cast<std::complex<float> *>(a),
&lda,
reinterpret_cast<std::complex<float> *>(b),
&ldb,
info);
}
template <>
@@ -490,7 +483,7 @@ void lapackCholeskySolve<double>(char uplo,
double *b,
int ldb,
int *info) {
paddle::platform::dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}
template <>
@@ -502,7 +495,7 @@ void lapackCholeskySolve<float>(char uplo,
float *b,
int ldb,
int *info) {
paddle::platform::dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}
} // namespace funcs
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// backward reuses forward, and HIP does not support forward
#include "paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(cholesky_solve_grad, // cuda_only
GPU,
ALL_LAYOUT,
phi::CholeskySolveGradKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP does not support cusolver
#include "paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
namespace phi {
template <typename T>
void cusolver_potrs(const solverHandle_t &handle,
cublasFillMode_t uplo,
int M,
int N,
T *Adata,
int lda,
T *Bdata,
int ldb,
int *devInfo);
template <>
void cusolver_potrs<float>(const solverHandle_t &handle,
cublasFillMode_t uplo,
int M,
int N,
float *Adata,
int lda,
float *Bdata,
int ldb,
int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSpotrs(
handle, uplo, M, N, Adata, lda, Bdata, ldb, devInfo));
}
template <>
void cusolver_potrs<double>(const solverHandle_t &handle,
cublasFillMode_t uplo,
int M,
int N,
double *Adata,
int lda,
double *Bdata,
int ldb,
int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDpotrs(
handle, uplo, M, N, Adata, lda, Bdata, ldb, devInfo));
}
template <>
void cusolver_potrs<phi::dtype::complex<float>>(
const solverHandle_t &handle,
cublasFillMode_t uplo,
int M,
int N,
phi::dtype::complex<float> *Adata,
int lda,
phi::dtype::complex<float> *Bdata,
int ldb,
int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCpotrs(handle,
uplo,
M,
N,
reinterpret_cast<const cuComplex *>(Adata),
lda,
reinterpret_cast<cuComplex *>(Bdata),
ldb,
devInfo));
}
template <>
void cusolver_potrs<phi::dtype::complex<double>>(
const solverHandle_t &handle,
cublasFillMode_t uplo,
int M,
int N,
phi::dtype::complex<double> *Adata,
int lda,
phi::dtype::complex<double> *Bdata,
int ldb,
int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZpotrs(
handle,
uplo,
M,
N,
reinterpret_cast<const cuDoubleComplex *>(Adata),
lda,
reinterpret_cast<cuDoubleComplex *>(Bdata),
ldb,
devInfo));
}
template <typename T>
class CholeskySolveFunctor<T, GPUContext> {
public:
void operator()(const GPUContext &dev_ctx,
bool upper,
int M,
int N,
T *Adata,
int lda,
T *Bdata,
int *devInfo) {
cublasFillMode_t uplo =
upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
auto handle = dev_ctx.cusolver_dn_handle();
cusolver_potrs<T>(handle, uplo, M, N, Adata, lda, Bdata, lda, devInfo);
}
};
} // namespace phi
PD_REGISTER_KERNEL(cholesky_solve, // cuda_only
GPU,
ALL_LAYOUT,
phi::CholeskySolveKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/cholesky_solve_grad_kernel.h"
#include "paddle/phi/kernels/cholesky_solve_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/tril_triu_op.h"
namespace phi {
template <typename T, typename Context>
void CholeskySolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
bool upper,
DenseTensor* dx,
DenseTensor* dy) {
// get broadcast dim
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) =
funcs::MatrixGetBroadcastDims(x, y);
ScalarArray x_bst_dims(x_bst_dims_vec);
ScalarArray y_bst_dims(y_bst_dims_vec);
// Tensor broadcast to temp 'y_bst'
DenseTensor y_bst = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
ExpandKernel<T, Context>(dev_ctx, y, y_bst_dims, &y_bst);
// reuse the forward pass to calculate dx_bst, which is the broadcast of dx
DenseTensor dx_bst = phi::Empty<T, Context>(dev_ctx, x_bst_dims);
CholeskySolveKernel<T, Context>(dev_ctx, dout, y_bst, upper, &dx_bst);
// get 'dx' according to 'dx_bst'
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
if (dx_bst.dims() == x.dims()) {
Copy<Context>(dev_ctx, dx_bst, dev_ctx.GetPlace(), false, dx);
} else {
funcs::MatrixReduceSumFunctor<T, Context> functor;
functor(dev_ctx, dx_bst, dx);
dx->Resize(x.dims());
}
// calculate out's conjugate for complex
DenseTensor out_conj = Conj<T, Context>(dev_ctx, out);
out_conj = phi::TransposeLast2Dim<T>(dev_ctx, out_conj);
DenseTensor commonterm = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
blas.MatMul(dx_bst,
phi::funcs::CreateMatrixDescriptor(dx_bst.dims(), 0, false),
out_conj,
phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false),
static_cast<T>(1),
&commonterm,
static_cast<T>(0));
// calculate commonterm's conjugate for complex
DenseTensor commonterm_conj = Conj<T, Context>(dev_ctx, commonterm);
commonterm_conj = phi::TransposeLast2Dim<T>(dev_ctx, commonterm_conj);
phi::AddRawKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm);
DenseTensor dy_bst = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
if (upper) {
blas.MatMul(y_bst,
phi::funcs::CreateMatrixDescriptor(y_bst.dims(), 0, false),
commonterm,
phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false),
static_cast<T>(-1),
&dy_bst,
static_cast<T>(0));
} else {
blas.MatMul(commonterm,
phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false),
y_bst,
phi::funcs::CreateMatrixDescriptor(y_bst.dims(), 0, false),
static_cast<T>(-1),
&dy_bst,
static_cast<T>(0));
}
// get upper or lower of 'dy_bst'
DenseTensor dy_bst_upper = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
int y_bst_ndim = y_bst_dims_vec.size();
const auto H = y_bst_dims_vec[y_bst_ndim - 2];
const auto W = y_bst_dims_vec[y_bst_ndim - 1];
phi::funcs::ForRange<Context> y_for_range(dev_ctx, dy_bst.numel());
paddle::operators::TrilTriuCompute<T> tril_triu_functor(
dy_bst.data<T>(), 0, !upper, H, W, dy_bst_upper.data<T>());
y_for_range(tril_triu_functor);
// get 'dy' according to 'dy_bst'
dy->Resize(y.dims());
dev_ctx.template Alloc<T>(dy);
if (dy_bst_upper.dims() == y.dims()) {
Copy<Context>(dev_ctx, dy_bst_upper, dev_ctx.GetPlace(), false, dy);
} else {
funcs::MatrixReduceSumFunctor<T, Context> functor;
functor(dev_ctx, dy_bst_upper, dy);
dy->Resize(y.dims());
}
}
} // namespace phi
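// A sketch of the formulas the backward pass above implements (inferred from
// the code; assuming upper = true so A = U^H * U, writing Z for 'out'):
//   dB = A^{-1} * dZ                    (reuses CholeskySolveKernel on dout)
//   S  = dB * Z^H                       ('commonterm' before symmetrization)
//   dU = -triu(U * (S + S^H))           (tril instead of triu when !upper)
// Broadcast batch dims are then summed back out by MatrixReduceSumFunctor so
// dx and dy match the original x and y shapes.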
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/cholesky_solve_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
template <typename T, typename Context>
class CholeskySolveFunctor {
public:
void operator()(const Context& dev_ctx,
bool upper,
int M,
int N,
T* Adata,
int lda,
T* Bdata,
int* devInfo);
};
template <typename T, typename Context>
void CholeskySolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool upper,
DenseTensor* out) {
// get broadcast dim
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) =
funcs::MatrixGetBroadcastDims(x, y);
ScalarArray x_bst_dims(x_bst_dims_vec);
ScalarArray y_bst_dims(y_bst_dims_vec);
DenseTensor y_bst = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
ExpandKernel<T, Context>(dev_ctx, y, y_bst_dims, &y_bst);
// Tensor broadcast to temp 'x_bst' and 'y_bst'
DenseTensor x_bst = phi::Empty<T, Context>(dev_ctx, x_bst_dims);
ExpandKernel<T, Context>(dev_ctx, x, x_bst_dims, &x_bst);
// calculate y_bst's conjugate for complex
DenseTensor y_bst_conj = Conj<T, Context>(dev_ctx, y_bst);
y_bst_conj = phi::TransposeLast2Dim<T>(dev_ctx, y_bst_conj);
T* y_bst_conj_data = y_bst_conj.data<T>();
// calculate x_bst's conjugate for complex
DenseTensor x_bst_conj = Conj<T, Context>(dev_ctx, x_bst);
x_bst_conj = phi::TransposeLast2Dim<T>(dev_ctx, x_bst_conj);
// copy x_bst's conjugate to 'result'
DenseTensor result;
Copy<Context>(dev_ctx, x_bst_conj, dev_ctx.GetPlace(), false, &result);
T* res_data = result.data<T>();
// CPU uses LAPACK, GPU uses cuSOLVER
int x_bst_ndim = x_bst_dims_vec.size();
int M = static_cast<int>(x_bst_dims_vec[x_bst_ndim - 2]);
int N = static_cast<int>(x_bst_dims_vec[x_bst_ndim - 1]);
int batchsize = product(phi::slice_ddim(x_bst.dims(), 0, x_bst_ndim - 2));
DenseTensor info =
phi::Empty<int, Context>(dev_ctx, ScalarArray({batchsize}));
int* info_data = info.data<int>();
CholeskySolveFunctor<T, Context> functor;
for (int i = 0; i < batchsize; ++i) {
functor(dev_ctx,
upper,
M,
N,
y_bst_conj_data + i * M * M,
std::max(1, M),
res_data + i * M * N,
info_data + i);
}
// calculate out's conjugate for complex
result = phi::TransposeLast2Dim<T>(dev_ctx, result);
out->Resize(phi::make_ddim(x_bst_dims_vec));
ConjKernel<T, Context>(dev_ctx, result, out);
}
} // namespace phi
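// A note on the Conj + TransposeLast2Dim pairs above (an inference from the
// code, not stated in the diff): LAPACK and cuSOLVER expect column-major
// matrices while DenseTensor is row-major, so each matrix is transposed over
// its last two dims before and after the batched potrs loop, and pairing the
// transpose with a conjugation turns it into the conjugate transpose so
// complex inputs stay correct. Per batch item the solve is
// U^H * U * X_i = B_i (upper) or L * L^H * X_i = B_i (lower).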
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CholeskySolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_solve_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"upper"},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
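// That is, the fluid op's inputs X, Y, Out and Out@GRAD plus the attribute
// 'upper' are routed to the phi kernel registered as "cholesky_solve_grad",
// and its outputs map back to X@GRAD and Y@GRAD.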
PD_REGISTER_ARG_MAPPING_FN(cholesky_solve_grad,
phi::CholeskySolveGradOpArgumentMapping);