未验证 提交 2cf2e786 编写于 作者: Y Yulong Ao 提交者: GitHub

[Phi] Move QR to Phi (#44742)

* [Phi] Move Qr to the Phi

* [Phi] Regiter the cpu grad kernel for qr

* [Phi] Share the cuda kernels to lstsq

* [Phi] Remove some improper inlcude files

* [Phi] Modify codes based on the reviews

* [Phi] Remove unecessary files and add the cuda_only comment

* [Phi] Remove the unecessary include file

* [Phi] Remove qr_op.cu and lstsq_op.cu
上级 a2980169
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include <string>
#include <vector>
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/lstsq_op.h"
#include "paddle/fluid/operators/qr_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
template <typename DeviceContext, typename T>
class LstsqCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
auto* solution = context.Output<Tensor>("Solution");
auto dito =
math::DeviceIndependenceTensorOperations<phi::GPUContext, T>(context);
auto& dev_ctx = context.template device_context<phi::GPUContext>();
auto x_dims = x.dims();
auto y_dims = y.dims();
int dim_size = x_dims.size();
int m = x_dims[dim_size - 2];
int n = x_dims[dim_size - 1];
int nrhs = y_dims[dim_size - 1];
int min_mn = std::min(m, n);
int max_mn = std::max(m, n);
int k = min_mn;
int x_stride = MatrixStride(x);
int y_stride = MatrixStride(y);
int tau_stride = min_mn;
int batch_count = BatchCount(x);
Tensor new_x, new_y;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
new_y.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * nrhs * sizeof(T)));
framework::TensorCopy(x, context.GetPlace(), &new_x);
framework::TensorCopy(y, context.GetPlace(), &new_y);
// Prepare tau
auto tau_dims_vec = phi::vectorize<int>(x_dims);
tau_dims_vec.pop_back();
tau_dims_vec[tau_dims_vec.size() - 1] = min_mn;
Tensor tau = dito.Fill(tau_dims_vec, 0);
auto tau_data = tau.mutable_data<T>(context.GetPlace());
using Context =
typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
auto& phi_dev_ctx = static_cast<const Context&>(dev_ctx);
if (m >= n) {
Tensor tmp_x = dito.Transpose(new_x);
Tensor tmp_y = dito.Transpose(new_y);
auto x_data = tmp_x.mutable_data<T>(context.GetPlace());
auto y_data = tmp_y.mutable_data<T>(context.GetPlace());
// step 1, compute QR factorization using geqrf
BatchedGeqrf<DeviceContext, T>(dev_ctx,
batch_count,
m,
n,
x_data,
m,
tau_data,
x_stride,
tau_stride);
// Step 2, Y <- Q^H Y
BatchedOrmqr<DeviceContext, T>(dev_ctx,
true,
true,
batch_count,
m,
nrhs,
k,
x_data,
x_stride,
tau_data,
tau_stride,
y_data,
y_stride);
Tensor trans_r = dito.Transpose(tmp_x);
Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
Tensor res_r = dito.TrilTriu(slice_r, 0, false);
Tensor trans_y = dito.Transpose(tmp_y);
Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn});
// Step 3, solve R X = Y
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, res_r, slice_y, true, false, false, solution);
} else {
auto x_data = new_x.mutable_data<T>(context.GetPlace());
auto y_data = new_y.mutable_data<T>(context.GetPlace());
// step 1, compute QR factorization using geqrf
BatchedGeqrf<DeviceContext, T>(dev_ctx,
batch_count,
n,
m,
x_data,
n,
tau_data,
x_stride,
tau_stride);
// Step 2, solve R^H Z = Y
Tensor trans_r = dito.Transpose(new_x);
Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
Tensor res_r = dito.TrilTriu(slice_r, 0, false);
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, res_r, new_y, true, true, false, solution);
// Step 3, X <- Q Z
BatchedOrgqr<DeviceContext, T>(dev_ctx,
batch_count,
n,
m,
min_mn,
x_data,
n,
tau_data,
x_stride,
tau_stride);
Tensor trans_q = dito.Transpose(new_x);
Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m});
Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false);
framework::TensorCopy(solu_tensor, context.GetPlace(), solution);
}
}
};
template <>
void BatchedOrmqr<phi::GPUContext, float>(const phi::GPUContext& dev_ctx,
bool left,
bool transpose,
int batch_size,
int m,
int n,
int k,
float* a,
int a_stride,
float* tau,
int tau_stride,
float* other,
int other_stride) {
int lwork = 0;
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
int lda = std::max<int>(1, left ? m : n);
int ldc = std::max<int>(1, m);
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
float* other_working_ptr = &other[i * other_stride];
handle = dev_ctx.cusolver_dn_handle();
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
// compute ormgr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnSormqr(handle,
side,
trans,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
other_working_ptr,
ldc,
workspace_ptr,
lwork,
info_d));
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver info is not zero but [%d]", i, info_h));
}
}
template <>
void BatchedOrmqr<phi::GPUContext, double>(const phi::GPUContext& dev_ctx,
bool left,
bool transpose,
int batch_size,
int m,
int n,
int k,
double* a,
int a_stride,
double* tau,
int tau_stride,
double* other,
int other_stride) {
int lwork = 0;
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
int lda = std::max<int>(1, left ? m : n);
int ldc = std::max<int>(1, m);
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
double* other_working_ptr = &other[i * other_stride];
handle = dev_ctx.cusolver_dn_handle();
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
// compute ormgr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDormqr(handle,
side,
trans,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
other_working_ptr,
ldc,
workspace_ptr,
lwork,
info_d));
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver info is not zero but [%d]", i, info_h));
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(lstsq,
ops::LstsqCUDAKernel<phi::GPUContext, float>,
ops::LstsqCUDAKernel<phi::GPUContext, double>);
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/eig_op.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#define EPSILON 1e-6
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };
using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim) {
auto x_vec = vectorize(x_dim);
return phi::make_ddim(x_vec);
}
template <typename DeviceContext, typename T>
class LstsqCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using ValueType = phi::dtype::Real<T>;
const Tensor& x = *context.Input<Tensor>("X");
auto y = context.Input<Tensor>("Y");
auto rcond = context.Attr<float>("rcond");
auto driver_string = context.Attr<std::string>("driver");
static auto driver_type = std::unordered_map<std::string, LapackDriverType>(
{{"gels", LapackDriverType::Gels},
{"gelsy", LapackDriverType::Gelsy},
{"gelsd", LapackDriverType::Gelsd},
{"gelss", LapackDriverType::Gelss}});
auto driver = driver_type[driver_string];
auto solution = context.Output<Tensor>("Solution");
auto* rank = context.Output<Tensor>("Rank");
auto* singular_values = context.Output<Tensor>("SingularValues");
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);
auto x_dims = x.dims();
auto y_dims = y->dims();
int dim_size = x_dims.size();
int x_stride = MatrixStride(x);
int y_stride = MatrixStride(*y);
int batch_count = BatchCount(x);
auto solution_dim = solution->dims();
int ori_solu_stride = MatrixStride(*solution);
int max_solu_stride = std::max(y_stride, ori_solu_stride);
int min_solu_stride = std::min(y_stride, ori_solu_stride);
// lapack is a column-major storge, transpose make the input to
// have a continuous memory layout
int info = 0;
int m = x_dims[dim_size - 2];
int n = x_dims[dim_size - 1];
int nrhs = y_dims[dim_size - 1];
int lda = std::max<int>(m, 1);
int ldb = std::max<int>(1, std::max(m, n));
Tensor new_x;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
framework::TensorCopy(x, context.GetPlace(), &new_x);
solution->mutable_data<T>(
context.GetPlace(),
size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));
if (m >= n) {
const Tensor& new_y = *context.Input<Tensor>("Y");
framework::TensorCopy(new_y, context.GetPlace(), solution);
} else {
auto* solu_data = solution->data<T>();
auto* y_data = y->data<T>();
for (auto i = 0; i < batch_count; i++) {
for (auto j = 0; j < min_solu_stride; j++) {
solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j];
}
}
}
Tensor input_x_trans = dito.Transpose(new_x);
Tensor input_y_trans = dito.Transpose(*solution);
framework::TensorCopy(input_x_trans, context.GetPlace(), &new_x);
framework::TensorCopy(input_y_trans, context.GetPlace(), solution);
auto* x_vector = new_x.data<T>();
auto* y_vector = solution->data<T>();
// "gels" divers does not need to compute rank
int rank_32 = 0;
int* rank_data = nullptr;
int* rank_working_ptr = nullptr;
if (driver != LapackDriverType::Gels) {
rank_data = rank->mutable_data<int>(context.GetPlace());
rank_working_ptr = rank_data;
}
// "gelsd" and "gelss" divers need to compute singular values
ValueType* s_data = nullptr;
ValueType* s_working_ptr = nullptr;
int s_stride = 0;
if (driver == LapackDriverType::Gelsd ||
driver == LapackDriverType::Gelss) {
s_data = singular_values->mutable_data<ValueType>(context.GetPlace());
s_working_ptr = s_data;
auto s_dims = singular_values->dims();
s_stride = s_dims[s_dims.size() - 1];
}
// "jpvt" is only used for "gelsy" driver
Tensor jpvt;
int* jpvt_data = nullptr;
if (driver == LapackDriverType::Gelsy) {
jpvt.Resize(phi::make_ddim({std::max<int>(1, n)}));
jpvt_data = jpvt.mutable_data<int>(context.GetPlace());
}
// run once the driver, first to get the optimal workspace size
int lwork = -1;
T wkopt;
ValueType rwkopt;
int iwkopt = 0;
if (driver == LapackDriverType::Gels) {
phi::funcs::lapackGels(
'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info);
} else if (driver == LapackDriverType::Gelsd) {
phi::funcs::lapackGelsd(m,
n,
nrhs,
x_vector,
lda,
y_vector,
ldb,
s_working_ptr,
static_cast<ValueType>(rcond),
&rank_32,
&wkopt,
lwork,
&rwkopt,
&iwkopt,
&info);
} else if (driver == LapackDriverType::Gelsy) {
phi::funcs::lapackGelsy(m,
n,
nrhs,
x_vector,
lda,
y_vector,
ldb,
jpvt_data,
static_cast<ValueType>(rcond),
&rank_32,
&wkopt,
lwork,
&rwkopt,
&info);
} else if (driver == LapackDriverType::Gelss) {
phi::funcs::lapackGelss(m,
n,
nrhs,
x_vector,
lda,
y_vector,
ldb,
s_working_ptr,
static_cast<ValueType>(rcond),
&rank_32,
&wkopt,
lwork,
&rwkopt,
&info);
}
lwork = std::max<int>(1, static_cast<int>(phi::dtype::Real<T>(wkopt)));
Tensor work;
work.Resize(phi::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());
// "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers
Tensor rwork;
ValueType* rwork_data = nullptr;
if (framework::IsComplexType(framework::TransToProtoVarType(x.dtype())) &&
driver != LapackDriverType::Gels) {
int rwork_len = 0;
if (driver == LapackDriverType::Gelsy) {
rwork_len = std::max<int>(1, 2 * n);
} else if (driver == LapackDriverType::Gelss) {
rwork_len = std::max<int>(1, 5 * std::min(m, n));
} else if (driver == LapackDriverType::Gelsd) {
rwork_len = std::max<int>(1, rwkopt);
}
rwork.Resize(phi::make_ddim({rwork_len}));
rwork_data = rwork.mutable_data<ValueType>(context.GetPlace());
}
// "iwork" workspace array is relavant only for "gelsd" driver
Tensor iwork;
int* iwork_data = nullptr;
if (driver == LapackDriverType::Gelsd) {
iwork.Resize(phi::make_ddim({std::max<int>(1, iwkopt)}));
iwork_data = iwork.mutable_data<int>(context.GetPlace());
}
for (auto i = 0; i < batch_count; ++i) {
auto* x_input = &x_vector[i * x_stride];
auto* y_input = &y_vector[i * max_solu_stride];
rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr;
s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr;
if (driver == LapackDriverType::Gels) {
phi::funcs::lapackGels('N',
m,
n,
nrhs,
x_input,
lda,
y_input,
ldb,
work_data,
lwork,
&info);
} else if (driver == LapackDriverType::Gelsd) {
phi::funcs::lapackGelsd(m,
n,
nrhs,
x_input,
lda,
y_input,
ldb,
s_working_ptr,
static_cast<ValueType>(rcond),
&rank_32,
work_data,
lwork,
rwork_data,
iwork_data,
&info);
} else if (driver == LapackDriverType::Gelsy) {
phi::funcs::lapackGelsy(m,
n,
nrhs,
x_input,
lda,
y_input,
ldb,
jpvt_data,
static_cast<ValueType>(rcond),
&rank_32,
work_data,
lwork,
rwork_data,
&info);
} else if (driver == LapackDriverType::Gelss) {
phi::funcs::lapackGelss(m,
n,
nrhs,
x_input,
lda,
y_input,
ldb,
s_working_ptr,
static_cast<ValueType>(rcond),
&rank_32,
work_data,
lwork,
rwork_data,
&info);
}
PADDLE_ENFORCE_EQ(
info,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: Lapack info is not zero but [%d]", i, info));
if (rank_working_ptr) *rank_working_ptr = static_cast<int>(rank_32);
}
Tensor tmp_s = dito.Transpose(*solution);
framework::TensorCopy(tmp_s, context.GetPlace(), solution);
if (m > n) {
auto* solu_data = solution->data<T>();
for (auto i = 1; i < batch_count; i++) {
for (auto j = 0; j < min_solu_stride; j++) {
solu_data[i * min_solu_stride + j] =
solu_data[i * max_solu_stride + j];
}
}
}
solution->Resize(UDDim(solution_dim));
}
};
template <typename DeviceContext, typename T>
void BatchedOrmqr(const DeviceContext& dev_ctx,
bool left,
bool transpose,
int batch_size,
int m,
int n,
int k,
T* a,
int a_stride,
T* tau,
int tau_stride,
T* other,
int other_stride);
} // namespace operators
} // namespace paddle
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/qr_op.h"
#include <memory>
#include <string>
#include <unordered_map>
......@@ -123,7 +121,3 @@ REGISTER_OPERATOR(qr,
QrInferShapeFunctor);
REGISTER_OPERATOR(qr_grad, ops::QrGradOp);
REGISTER_OP_CPU_KERNEL(qr_grad,
ops::QrGradKernel<phi::CPUContext, float>,
ops::QrGradKernel<phi::CPUContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/qr_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
// Reuse some helper functions from svd
#include "paddle/fluid/operators/svd_helper.h"
namespace paddle {
namespace operators {
template <typename T>
class QrGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool compute_q;
bool reduced_mode;
auto& dev_ctx = context.template device_context<phi::GPUContext>();
const Tensor& x = *context.Input<Tensor>("X");
Tensor& q = *context.Output<Tensor>("Q");
Tensor& r = *context.Output<Tensor>("R");
const std::string mode = context.Attr<std::string>("mode");
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel,
0,
platform::errors::PreconditionNotMet("The input of QR is empty."));
auto x_dims = x.dims();
int x_rank = x_dims.size();
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
int k = reduced_mode ? min_mn : m;
int batch_size = numel / (m * n);
int qr_stride = m * n;
int tau_stride = min_mn;
if (compute_q) {
q.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
}
r.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
auto dito =
math::DeviceIndependenceTensorOperations<phi::GPUContext, T>(context);
// Note: allocate temporary tensors because of lacking in-place operatios.
// Prepare qr
Tensor qr;
qr.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * n * sizeof(phi::dtype::Real<T>)));
// BatchedGeqrf performs computation in-place and 'qr' must be a copy of
// input
paddle::framework::TensorCopy(x, context.GetPlace(), &qr);
// Prepare tau
auto tau_dims_vec = phi::vectorize<int>(x_dims);
tau_dims_vec.pop_back();
tau_dims_vec[tau_dims_vec.size() - 1] = min_mn;
Tensor tau = dito.Fill(tau_dims_vec, 0);
// Transpose 'qr' to conform the column-major order
auto tmp_qr = dito.Transpose(qr);
framework::TensorCopy(tmp_qr, qr.place(), &qr);
auto qr_data = qr.mutable_data<T>(context.GetPlace());
auto tau_data = tau.mutable_data<T>(context.GetPlace());
BatchedGeqrf<phi::GPUContext, T>(
dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride);
if (reduced_mode) {
auto trans_qr = dito.Transpose(qr);
auto sliced_qr = dito.Slice(trans_qr, {-2}, {0}, {min_mn});
auto tmp_r = dito.TrilTriu(sliced_qr, 0, false);
// Transpose 'tmp_r' to retore the original row-major order
framework::TensorCopy(tmp_r, r.place(), &r);
} else {
auto trans_qr = dito.Transpose(qr);
auto tmp_r = dito.TrilTriu(trans_qr, 0, false);
// Transpose 'tmp_r' to retore the original row-major order
framework::TensorCopy(tmp_r, r.place(), &r);
}
if (compute_q) {
// Perform QRGQR for Q using the result from GEQRF
// Transpose 'q' to retore the original row-major order
if (reduced_mode) {
BatchedOrgqr<phi::GPUContext, T>(dev_ctx,
batch_size,
m,
min_mn,
min_mn,
qr_data,
m,
tau_data,
qr_stride,
tau_stride);
auto trans_q = dito.Transpose(qr);
auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {min_mn});
framework::TensorCopy(sliced_q, q.place(), &q);
} else {
if (m > n) {
auto new_qr_dims_vec = phi::vectorize<int>(x_dims);
new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m;
Tensor new_qr = dito.Fill(new_qr_dims_vec, 0);
auto new_qr_data = new_qr.mutable_data<T>(context.GetPlace());
auto new_qr_stride = m * m;
for (int i = 0; i < batch_size; ++i) {
memory::Copy(dev_ctx.GetPlace(),
(new_qr_data + i * new_qr_stride),
dev_ctx.GetPlace(),
(qr_data + i * qr_stride),
qr_stride * sizeof(phi::dtype::Real<T>),
dev_ctx.stream());
}
BatchedOrgqr<phi::GPUContext, T>(dev_ctx,
batch_size,
m,
m,
min_mn,
new_qr_data,
m,
tau_data,
new_qr_stride,
tau_stride);
auto trans_q = dito.Transpose(new_qr);
framework::TensorCopy(trans_q, q.place(), &q);
} else {
BatchedOrgqr<phi::GPUContext, T>(dev_ctx,
batch_size,
m,
m,
min_mn,
qr_data,
m,
tau_data,
qr_stride,
tau_stride);
auto trans_q = dito.Transpose(qr);
auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {m});
framework::TensorCopy(sliced_q, q.place(), &q);
}
}
}
}
};
template <>
void BatchedGeqrf<phi::GPUContext, float>(const phi::GPUContext& dev_ctx,
int batch_size,
int m,
int n,
float* a,
int lda,
float* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgeqrf_bufferSize(
handle, m, n, a, lda, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnSgeqrf(handle,
m,
n,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedGeqrf<phi::GPUContext, double>(const phi::GPUContext& dev_ctx,
int batch_size,
int m,
int n,
double* a,
int lda,
double* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgeqrf_bufferSize(
handle, m, n, a, lda, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDgeqrf(handle,
m,
n,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedOrgqr<phi::GPUContext, float>(const phi::GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
float* a,
int lda,
float* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnSorgqr(handle,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedOrgqr<phi::GPUContext, double>(const phi::GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
double* a,
int lda,
double* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDorgqr(handle,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(qr, ops::QrGPUKernel<float>, ops::QrGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(qr_grad,
ops::QrGradKernel<phi::GPUContext, float>,
ops::QrGradKernel<phi::GPUContext, double>);
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Dense>
#include <cstdarg>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
static inline std::tuple<bool, bool> _parse_qr_mode(std::string mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
template <typename DeviceContext, typename T>
class QrGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor& Q = *ctx.Input<framework::Tensor>("Q");
const framework::Tensor& R = *ctx.Input<framework::Tensor>("R");
// Use a different name A instead of X
const framework::Tensor& A = *ctx.Input<framework::Tensor>("X");
const framework::Tensor& dQ =
*ctx.Input<framework::Tensor>(framework::GradVarName("Q"));
const framework::Tensor& dR =
*ctx.Input<framework::Tensor>(framework::GradVarName("R"));
// Use a different name dA instead of dX
framework::Tensor& dA =
*ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dA.mutable_data<phi::dtype::Real<T>>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));
auto dito = math::DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);
std::string mode = ctx.Attr<std::string>("mode");
bool compute_q, reduced;
std::tie(compute_q, reduced) = _parse_qr_mode(mode);
if (!compute_q) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The derivative of qr is not implemented when mode='r'."));
}
auto a_dims = A.dims();
int a_rank = a_dims.size();
int m = a_dims[a_rank - 2];
int n = a_dims[a_rank - 1];
if ((m > n) && (!reduced)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The derivative of qr is not implemented when mode='complete' and "
"nrows > ncols."));
}
// m >= n case
auto m_gt_n_case =
[](const framework::ExecutionContext& ctx,
math::DeviceIndependenceTensorOperations<DeviceContext, T>& dito,
const Tensor& dQ,
const Tensor& dR,
const Tensor& A,
const Tensor& Q,
const Tensor& R) -> framework::Tensor {
// Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable
// Programming Tensor Networks.
// https://arxiv.org/abs/1903.09650 Section 3. QR factorization
// dR^H
framework::Tensor R_term;
if (ctx.HasInput(framework::GradVarName("R"))) {
R_term = dito.Matmul(R, dito.Transpose(dR));
} else {
R_term = dito.Fill(phi::vectorize<int>(R.dims()), 0);
}
// dQ^H * Q
framework::Tensor Q_term;
if (ctx.HasInput(framework::GradVarName("Q"))) {
Q_term = dito.Matmul(dito.Transpose(dQ), Q);
} else {
Q_term = dito.Fill(phi::vectorize<int>(R.dims()), 0);
}
framework::Tensor M_tmp1 = dito.Sub(R_term, Q_term);
// Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity
framework::Tensor M_tril_0 = dito.TrilTriu(M_tmp1, 0, true);
framework::Tensor M_tril_1 = dito.TrilTriu(M_tmp1, -1, true);
framework::Tensor M = dito.Add(M_tril_0, dito.Transpose(M_tril_1));
framework::Tensor rhs_term;
if (ctx.HasInput(framework::GradVarName("Q"))) {
rhs_term = dito.Add(dQ, dito.Matmul(Q, M));
} else {
rhs_term = dito.Matmul(Q, M);
}
// dA * R^H = rhs_term
auto dA =
dito.TriangularSolve(dito.Transpose(dito.Conj(dito.Transpose(R))),
dito.Transpose(rhs_term),
/*upper=*/true,
/*transpose=*/false,
/*unitriangular=*/false);
return dito.Transpose(dA);
};
if (m >= n) {
auto dA_tmp = m_gt_n_case(ctx, dito, dQ, dR, A, Q, R);
framework::TensorCopy(dA_tmp, dA.place(), &dA);
} else {
// If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
// Calculate dX and dY individually and concatenate them to get dA
dA.mutable_data<phi::dtype::Real<T>>(ctx.GetPlace());
auto Y = dito.Slice(A, {-1}, {m}, {n});
auto U = dito.Slice(R, {-1}, {0}, {m});
framework::Tensor dY, dX, dV, dR_tmp, dQ_prime;
if (ctx.HasInput(framework::GradVarName("R"))) {
dV = dito.Slice(dR, {-1}, {m}, {n});
dR_tmp = dito.Slice(dR, {-1}, {0}, {m});
// Y * dV^H
dQ_prime = dito.Matmul(Y, dito.Transpose(dV));
} else {
dV = dito.Fill(phi::vectorize<int>(Y.dims()), 0);
dQ_prime = dito.Fill(phi::vectorize<int>(Q.dims()), 0);
}
if (ctx.HasInput(framework::GradVarName("Q"))) {
dQ_prime = dito.Add(dQ_prime, dQ);
}
dX = m_gt_n_case(ctx, dito, dQ_prime, dR_tmp, A, Q, U);
dY = dito.Matmul(Q, dV);
// Concatenate dX and dY to get dA.
auto dA_tmp = dito.ConcatTwoTensors(dX, dY, -1);
framework::TensorCopy(dA_tmp, dA.place(), &dA);
}
}
};
template <typename DeviceContext, typename T>
void BatchedGeqrf(const DeviceContext& dev_ctx,
int batch_size,
int m,
int n,
T* a,
int lda,
T* tau,
int a_stride,
int tau_stride);
template <typename DeviceContext, typename T>
void BatchedOrgqr(const DeviceContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
T* a,
int lda,
T* tau,
int a_stride,
int tau_stride);
} // namespace operators
} // namespace paddle
......@@ -1857,7 +1857,7 @@
func : QrInferMeta
kernel :
func : qr
# backward : qr_grad
backward : qr_grad
- api : randint
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
......
......@@ -1709,6 +1709,16 @@
kernel :
func : put_along_axis_grad
- backward_api : qr_grad
forward : qr (Tensor x, str mode) -> Tensor(q), Tensor(r)
args : (Tensor x, Tensor q, Tensor r, Tensor q_grad, Tensor r_grad, str mode)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : qr_grad
- backward_api : real_grad
forward : real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/qr_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h"
PD_REGISTER_KERNEL(qr_grad, CPU, ALL_LAYOUT, phi::QrGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/qr_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h"
PD_REGISTER_KERNEL(qr_grad, GPU, ALL_LAYOUT, phi::QrGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP // HIP not support cusolver
#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/impl/qr_kernel_impl.h"
#include "paddle/phi/kernels/qr_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/kernels/tril_triu_kernel.h"
namespace phi {
template <class T, class Context>
static DenseTensor Fill(const Context& ctx,
std::vector<int> shape,
float fill_value) {
DenseTensor ret;
ret.Resize(make_ddim(shape));
ctx.template Alloc<T>(&ret);
funcs::SetConstant<Context, T>()(ctx, &ret, T(fill_value));
return ret;
}
template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
const std::string& mode,
DenseTensor* q,
DenseTensor* r) {
bool compute_q;
bool reduced_mode;
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
auto x_dims = x.dims();
int x_rank = x_dims.size();
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
int k = reduced_mode ? min_mn : m;
int batch_size = numel / (m * n);
int qr_stride = m * n;
int tau_stride = min_mn;
if (compute_q) {
ctx.template Alloc<phi::dtype::Real<T>>(
q, batch_size * m * k * sizeof(phi::dtype::Real<T>));
}
ctx.template Alloc<phi::dtype::Real<T>>(
r, batch_size * k * n * sizeof(phi::dtype::Real<T>));
// Note: allocate temporary tensors because of lacking in-place operatios.
// Prepare qr
DenseTensor qr;
ctx.template Alloc<phi::dtype::Real<T>>(
&qr, size_t(batch_size * m * n * sizeof(phi::dtype::Real<T>)));
// BatchedGeqrf performs computation in-place and 'qr' must be a copy of
// input
phi::Copy(ctx, x, ctx.GetPlace(), false, &qr);
// Prepare tau
auto tau_dims_vec = phi::vectorize<int>(x_dims);
tau_dims_vec.pop_back();
tau_dims_vec[tau_dims_vec.size() - 1] = min_mn;
DenseTensor tau = Fill<T, Context>(ctx, tau_dims_vec, 0);
// Transpose 'qr' to conform the column-major order
auto tmp_qr = TransposeLast2Dim<T, Context>(ctx, qr);
phi::Copy(ctx, tmp_qr, qr.place(), false, &qr);
auto qr_data = ctx.template Alloc<phi::dtype::Real<T>>(&qr);
auto tau_data = ctx.template Alloc<phi::dtype::Real<T>>(&tau);
BatchedGeqrf<Context, T>(
ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride);
if (reduced_mode) {
auto trans_qr = TransposeLast2Dim<T, Context>(ctx, qr);
auto sliced_qr = SliceKernel<T, Context>(
ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}, {1}, {});
auto tmp_r = TrilTriu<T, Context>(ctx, sliced_qr, 0, false);
// Transpose 'tmp_r' to retore the original row-major order
phi::Copy(ctx, tmp_r, r->place(), false, r);
} else {
auto trans_qr = TransposeLast2Dim<T, Context>(ctx, qr);
auto tmp_r = TrilTriu<T, Context>(ctx, trans_qr, 0, false);
// Transpose 'tmp_r' to retore the original row-major order
phi::Copy(ctx, tmp_r, r->place(), false, r);
}
if (compute_q) {
// Perform QRGQR for Q using the result from GEQRF
// Transpose 'q' to retore the original row-major order
if (reduced_mode) {
BatchedOrgqr<Context, T>(ctx,
batch_size,
m,
min_mn,
min_mn,
qr_data,
m,
tau_data,
qr_stride,
tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr);
auto sliced_q = SliceKernel<T, Context>(
ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}, {1}, {});
phi::Copy(ctx, sliced_q, q->place(), false, q);
} else {
if (m > n) {
auto new_qr_dims_vec = phi::vectorize<int>(x_dims);
new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m;
DenseTensor new_qr = Fill<T, Context>(ctx, new_qr_dims_vec, 0);
auto new_qr_data = ctx.template Alloc<phi::dtype::Real<T>>(&new_qr);
auto new_qr_stride = m * m;
for (int i = 0; i < batch_size; ++i) {
paddle::memory::Copy(ctx.GetPlace(),
(new_qr_data + i * new_qr_stride),
ctx.GetPlace(),
(qr_data + i * qr_stride),
qr_stride * sizeof(phi::dtype::Real<T>),
ctx.stream());
}
BatchedOrgqr<Context, T>(ctx,
batch_size,
m,
m,
min_mn,
new_qr_data,
m,
tau_data,
new_qr_stride,
tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, new_qr);
phi::Copy(ctx, trans_q, q->place(), false, q);
} else {
BatchedOrgqr<Context, T>(ctx,
batch_size,
m,
m,
min_mn,
qr_data,
m,
tau_data,
qr_stride,
tau_stride);
auto trans_q = TransposeLast2Dim<T, Context>(ctx, qr);
auto sliced_q = SliceKernel<T, Context>(
ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}, {1}, {});
phi::Copy(ctx, sliced_q, q->place(), false, q);
}
}
}
}
template <>
void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
float* a,
int lda,
float* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
DenseTensor workspace = DenseTensor();
workspace.Resize(make_ddim({lwork}));
float* workspace_ptr = dev_ctx.template Alloc<float>(&workspace);
DenseTensor info = DenseTensor();
info.Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf(handle,
m,
n,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
double* a,
int lda,
double* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
DenseTensor workspace = DenseTensor();
workspace.Resize(make_ddim({lwork}));
double* workspace_ptr = dev_ctx.template Alloc<double>(&workspace);
DenseTensor info = DenseTensor();
info.Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgeqrf(handle,
m,
n,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedOrgqr<GPUContext, float>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
float* a,
int lda,
float* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
DenseTensor workspace = DenseTensor();
workspace.Resize(make_ddim({lwork}));
float* workspace_ptr = dev_ctx.template Alloc<float>(&workspace);
DenseTensor info = DenseTensor();
info.Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr(handle,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
double* a,
int lda,
double* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
DenseTensor workspace = DenseTensor();
workspace.Resize(make_ddim({lwork}));
double* workspace_ptr = dev_ctx.template Alloc<double>(&workspace);
DenseTensor info = DenseTensor();
info.Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr(handle,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
}
}
} // namespace phi
PD_REGISTER_KERNEL(qr, // cuda_only
GPU,
ALL_LAYOUT,
phi::QrKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
#include "paddle/phi/kernels/tril_triu_kernel.h"
namespace phi {
template <class T, class Context>
static DenseTensor Fill(const Context& ctx,
std::vector<int> shape,
float fill_value) {
DenseTensor ret;
ret.Resize(make_ddim(shape));
ctx.template Alloc<T>(&ret);
funcs::SetConstant<Context, T>()(ctx, &ret, T(fill_value));
return ret;
}
template <typename T, typename Context>
void QrGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& q,
const DenseTensor& r,
const DenseTensor& q_grad,
const DenseTensor& r_grad,
const std::string& mode,
DenseTensor* x_grad) {
// Using alias names
const DenseTensor& A = x;
const DenseTensor& Q = q;
const DenseTensor& R = r;
const DenseTensor& dQ = q_grad;
const DenseTensor& dR = r_grad;
DenseTensor& dA = *x_grad;
ctx.template Alloc<phi::dtype::Real<T>>(&dA);
phi::funcs::SetConstant<Context, T>()(ctx, &dA, T(0));
bool compute_q, reduced;
std::tie(compute_q, reduced) = phi::funcs::ParseQrMode(mode);
if (!compute_q) {
PADDLE_THROW(errors::InvalidArgument(
"The derivative of qr is not implemented when mode='%s'.", mode));
}
auto a_dims = A.dims();
int a_rank = a_dims.size();
int m = a_dims[a_rank - 2];
int n = a_dims[a_rank - 1];
if ((m > n) && (!reduced)) {
PADDLE_THROW(errors::InvalidArgument(
"The derivative of qr is not implemented when mode='complete' and "
"%d > %d.",
m,
n));
}
// m >= n case
auto m_gt_n_case = [](const Context& ctx,
const DenseTensor& dQ,
const DenseTensor& dR,
const DenseTensor& A,
const DenseTensor& Q,
const DenseTensor& R) -> DenseTensor {
// Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable
// Programming Tensor Networks.
// https://arxiv.org/abs/1903.09650 Section 3. QR factorization
// dR^H
DenseTensor R_term;
if (dR.initialized()) {
R_term =
Matmul<T, Context>(ctx, R, TransposeLast2Dim<T, Context>(ctx, dR));
} else {
R_term = Fill<T, Context>(ctx, phi::vectorize<int>(R.dims()), 0);
}
// dQ^H * Q
DenseTensor Q_term;
if (dQ.initialized()) {
Q_term =
Matmul<T, Context>(ctx, TransposeLast2Dim<T, Context>(ctx, dQ), Q);
} else {
Q_term = Fill<T, Context>(ctx, phi::vectorize<int>(R.dims()), 0);
}
DenseTensor M_tmp1 = Subtract<T, Context>(ctx, R_term, Q_term);
// Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity
DenseTensor M_tril_0 = TrilTriu<T, Context>(ctx, M_tmp1, 0, true);
DenseTensor M_tril_1 = TrilTriu<T, Context>(ctx, M_tmp1, -1, true);
DenseTensor M = Add<T, Context>(
ctx, M_tril_0, TransposeLast2Dim<T, Context>(ctx, M_tril_1));
DenseTensor rhs_term;
if (dQ.initialized()) {
rhs_term = Add<T, Context>(ctx, dQ, Matmul<T, Context>(ctx, Q, M));
} else {
rhs_term = Matmul<T, Context>(ctx, Q, M);
}
// dA * R^H = rhs_term
auto dA = TriangularSolve<T, Context>(
ctx,
TransposeLast2Dim<T, Context>(
ctx, Conj<T, Context>(ctx, TransposeLast2Dim<T, Context>(ctx, R))),
TransposeLast2Dim<T, Context>(ctx, rhs_term),
/*upper=*/true,
/*transpose=*/false,
/*unitriangular=*/false);
return TransposeLast2Dim<T, Context>(ctx, dA);
};
if (m >= n) {
auto dA_tmp = m_gt_n_case(ctx, dQ, dR, A, Q, R);
phi::Copy(ctx, dA_tmp, dA.place(), false, &dA);
} else {
// If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
// Calculate dX and dY individually and concatenate them to get dA
ctx.template Alloc<phi::dtype::Real<T>>(&dA);
auto Y = SliceKernel<T, Context>(
ctx, A, {A.dims().size() - 1}, {m}, {n}, {1}, {});
auto U = SliceKernel<T, Context>(
ctx, R, {R.dims().size() - 1}, {0}, {m}, {1}, {});
DenseTensor dY, dX, dV, dR_tmp, dQ_prime;
if (dR.initialized()) {
dV = SliceKernel<T, Context>(
ctx, dR, {dR.dims().size() - 1}, {m}, {n}, {1}, {});
dR_tmp = SliceKernel<T, Context>(
ctx, dR, {dR.dims().size() - 1}, {0}, {m}, {1}, {});
// Y * dV^H
dQ_prime =
Matmul<T, Context>(ctx, Y, TransposeLast2Dim<T, Context>(ctx, dV));
} else {
dV = Fill<T, Context>(ctx, phi::vectorize<int>(Y.dims()), 0);
dQ_prime = Fill<T, Context>(ctx, phi::vectorize<int>(Q.dims()), 0);
}
if (dQ.initialized()) {
dQ_prime = Add<T, Context>(ctx, dQ_prime, dQ);
}
dX = m_gt_n_case(ctx, dQ_prime, dR_tmp, A, Q, U);
dY = Matmul<T, Context>(ctx, Q, dV);
// Concatenate dX and dY to get dA.
auto dA_tmp = Concat<T, Context>(ctx, {&dX, &dY}, -1);
phi::Copy(ctx, dA_tmp, dA.place(), false, &dA);
}
}
} // namespace phi
......@@ -50,225 +50,6 @@ void BatchedOrgqr(const DeviceContext& dev_ctx,
int a_stride,
int tau_stride);
template <>
void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
float* a,
int lda,
float* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
DenseTensor* workspace = new DenseTensor();
workspace->Resize(make_ddim({lwork}));
float* workspace_ptr = dev_ctx.template Alloc<float>(workspace);
DenseTensor* info = new DenseTensor();
info->Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(info);
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf(handle,
m,
n,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
double* a,
int lda,
double* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
DenseTensor* workspace = new DenseTensor();
workspace->Resize(make_ddim({lwork}));
double* workspace_ptr = dev_ctx.template Alloc<double>(workspace);
DenseTensor* info = new DenseTensor();
info->Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(info);
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgeqrf(handle,
m,
n,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedOrgqr<GPUContext, float>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
float* a,
int lda,
float* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
DenseTensor* workspace = new DenseTensor();
workspace->Resize(make_ddim({lwork}));
float* workspace_ptr = dev_ctx.template Alloc<float>(workspace);
DenseTensor* info = new DenseTensor();
info->Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(info);
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr(handle,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
}
}
template <>
void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
double* a,
int lda,
double* tau,
int a_stride,
int tau_stride) {
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
DenseTensor* workspace = new DenseTensor();
workspace->Resize(make_ddim({lwork}));
double* workspace_ptr = dev_ctx.template Alloc<double>(workspace);
DenseTensor* info = new DenseTensor();
info->Resize(make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(info);
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr(handle,
m,
n,
k,
a_working_ptr,
lda,
tau_working_ptr,
workspace_ptr,
lwork,
info_d));
// Do we need synchronized here?
// check the error info
int info_h;
paddle::memory::Copy(phi::CPUPlace(),
&info_h,
dev_ctx.GetPlace(),
info_d,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
}
}
#endif
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void QrGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& q,
const DenseTensor& r,
const DenseTensor& q_grad,
const DenseTensor& r_grad,
const std::string& mode,
DenseTensor* x_grad);
} // namespace phi
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
namespace phi {
......@@ -27,4 +28,19 @@ void TriangularSolveKernel(const Context& dev_ctx,
bool unitriangular,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor TriangularSolve(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
bool upper,
bool transpose,
bool unitriangular) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
TriangularSolveInferMeta(x, y, upper, transpose, unitriangular, &meta_out);
TriangularSolveKernel<T, Context>(
ctx, x, y, upper, transpose, unitriangular, &dense_out);
return dense_out;
}
} // namespace phi
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......@@ -25,4 +26,16 @@ void TrilTriuKernel(const Context& ctx,
bool lower,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor TrilTriu(const Context& ctx,
const DenseTensor& x,
int diagonal,
bool lower) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
TrilTriuInferMeta(x, diagonal, lower, &meta_out);
TrilTriuKernel<T, Context>(ctx, x, diagonal, lower, &dense_out);
return dense_out;
}
} // namespace phi
......@@ -20,6 +20,12 @@ KernelSignature QrOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("qr", {"X"}, {"mode"}, {"Q", "R"});
}
KernelSignature QrGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"qr_grad", {"X", "Q", "R", "Q@GRAD", "R@GRAD"}, {"mode"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(qr, phi::QrOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(qr_grad, phi::QrGradOpArgumentMapping);
......@@ -28,6 +28,7 @@ class TestQrOp(OpTest):
def setUp(self):
paddle.enable_static()
self.python_api = paddle.linalg.qr
np.random.seed(7)
self.op_type = "qr"
a, q, r = self.get_input_and_output()
......@@ -72,10 +73,11 @@ class TestQrOp(OpTest):
return a, q, r
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X'], ['Q', 'R'],
check_eager=True,
numeric_grad_delta=1e-5,
max_relative_error=1e-6)
......
......@@ -1998,7 +1998,13 @@ def qr(x, mode="reduced", name=None):
# one can verify : X = Q * R ;
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
q, r = _C_ops.final_state_qr(x, mode)
if mode == "r":
return r
else:
return q, r
if _in_legacy_dygraph():
q, r = _C_ops.qr(x, 'mode', mode)
if mode == "r":
return r
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册