未验证 提交 930a5136 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Phi] Migrate triangular_solve dependence to phi (#40417)

上级 89a70c76
......@@ -17,9 +17,11 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/lstsq_op.h"
#include "paddle/fluid/operators/qr_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
namespace paddle {
namespace operators {
......@@ -70,6 +72,10 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
Tensor tau = dito.Fill(tau_dims_vec, 0);
auto tau_data = tau.mutable_data<T>(context.GetPlace());
using Context =
typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
auto& phi_dev_ctx = static_cast<const Context&>(dev_ctx);
if (m >= n) {
Tensor tmp_x = dito.Transpose(new_x);
Tensor tmp_y = dito.Transpose(new_y);
......@@ -93,8 +99,9 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn});
// Step 3, solve R X = Y
triangular_solve<DeviceContext, T>(dev_ctx, res_r, slice_y, solution,
true, false, false);
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, res_r, slice_y, true,
false, false, solution);
} else {
auto x_data = new_x.mutable_data<T>(context.GetPlace());
auto y_data = new_y.mutable_data<T>(context.GetPlace());
......@@ -105,8 +112,8 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
// Step 2, solve R^H Z = Y
Tensor trans_r = dito.Transpose(new_x);
triangular_solve<DeviceContext, T>(dev_ctx, trans_r, new_y, solution,
true, true, false);
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, trans_r, new_y, true,
true, false, solution);
// Step 3, X <- Q Z
BatchedOrgqr<DeviceContext, T>(dev_ctx, batch_count, n, n, min_mn, x_data,
......
......@@ -22,7 +22,6 @@
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
......
......@@ -15,12 +15,13 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/set_value_op.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
namespace paddle {
namespace operators {
......@@ -555,6 +556,11 @@ class LUGradKernel : public framework::OpKernel<T> {
framework::Tensor Pmat;
Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, &Pmat, m, k);
using Context =
typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
auto& phi_dev_ctx = static_cast<const Context&>(dev_ctx);
if (m <= n) {
if (k < n) {
framework::Tensor U_complement, U_grad_complement, phi_complement,
......@@ -605,8 +611,9 @@ class LUGradKernel : public framework::OpKernel<T> {
framework::Tensor psi_principal, phi_mH, psi_tmp;
Tensor_Conj<DeviceContext, T>(dev_ctx, phi, &phi_mH);
phi_mH = helper.Transpose(phi_mH);
triangular_solve<DeviceContext, T>(dev_ctx, U_narrow, phi_mH,
&psi_principal, true, false, false);
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, U_narrow, phi_mH, true, false, false, &psi_principal);
Tensor_Conj<DeviceContext, T>(dev_ctx, psi_principal, &psi_principal);
psi_principal = helper.Transpose(psi_principal);
......@@ -620,8 +627,9 @@ class LUGradKernel : public framework::OpKernel<T> {
SetValueCompute_dispatch<DeviceContext, T>(ctx, &psi, &psi_principal,
&psi, axes, &slice_starts,
&slice_ends, valuedims, xrank);
triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, psi, &psi_tmp,
true, false, true);
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, L_narrow_mH, psi,
true, false, true, &psi_tmp);
auto mat_dim_p =
phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false);
......@@ -672,8 +680,10 @@ class LUGradKernel : public framework::OpKernel<T> {
&psi, axes, &slice_starts,
&slice_ends, valuedims, xrank);
framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH;
triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, phi,
&psi_principal, true, false, true);
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, L_narrow_mH, phi,
true, false, true, &psi_principal);
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
......@@ -695,8 +705,8 @@ class LUGradKernel : public framework::OpKernel<T> {
psi_tmp = helper.Transpose(psi_tmp);
Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
triangular_solve<DeviceContext, T>(dev_ctx, U_narrow_mH, psi_tmp, &psi,
true, false, false);
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, U_narrow_mH, psi_tmp,
true, false, false, &psi);
*dx = helper.Transpose(psi);
}
}
......
......@@ -34,45 +34,6 @@ class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;
template <typename T>
class TriangularSolveFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor* a, framework::Tensor* b, bool left,
bool upper, bool transpose, bool unitriangular) {
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
const T* a_data = a->data<T>();
T* b_data = b->mutable_data<T>(context.GetPlace());
int a_dim_size = a->dims().size();
int b_dim_size = b->dims().size();
int M = static_cast<int>(b->dims()[b_dim_size - 2]);
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
auto lda = left ? std::max(1, M) : std::max(1, N);
auto ldb = std::max(1, N);
int batch_size = 1;
auto& a_dim = a->dims();
for (int i = 0; i < a_dim_size - 2; i++) {
batch_size *= a_dim[i];
}
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < batch_size; i++) {
blas.TRSM(side, uplo, transA, diag, M, N, T(1), a_data + i * M * M, lda,
b_data + i * N * M, ldb);
}
}
};
template class TriangularSolveFunctor<platform::CPUDeviceContext, float>;
template class TriangularSolveFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -161,67 +161,6 @@ class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;
template <typename T>
class TriangularSolveFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const Tensor* a,
Tensor* b, bool left, bool upper, bool transpose,
bool unitriangular) {
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
const T* a_data = a->data<T>();
T* b_data = b->mutable_data<T>(context.GetPlace());
int a_dim_size = a->dims().size();
int b_dim_size = b->dims().size();
int M = static_cast<int>(b->dims()[b_dim_size - 2]);
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
auto lda = left ? std::max(1, M) : std::max(1, N);
auto ldb = std::max(1, N);
int batch_size = 1;
auto& a_dim = a->dims();
for (int i = 0; i < a_dim_size - 2; i++) {
batch_size *= a_dim[i];
}
auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context);
if (batch_size <= 8 && M >= 64) {
for (auto i = 0; i < batch_size; i++) {
blas.TRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
a_data + i * M * M, lda, b_data + i * N * M, ldb);
}
} else {
std::vector<const T*> cpu_ptrs(batch_size * 2);
for (int i = 0; i < batch_size; ++i) {
cpu_ptrs[i] = a_data + i * M * M;
cpu_ptrs[i + batch_size] = b_data + i * M * N;
}
// Copy the addresses of A and tmp_b from host to device.
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
memory::Copy(context.GetPlace(), tmp_gpu_ptrs_data->ptr(),
platform::CPUPlace(), static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*), context.stream());
const T** gpu_a_ptrs =
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr());
T** gpu_b_ptrs =
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
blas.BatchedTRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
gpu_a_ptrs, lda, gpu_b_ptrs, ldb, batch_size);
}
}
};
template class TriangularSolveFunctor<platform::CUDADeviceContext, float>;
template class TriangularSolveFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -117,14 +117,6 @@ class MatrixSolveFunctor {
const framework::Tensor& b, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class TriangularSolveFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* a,
framework::Tensor* b, bool left, bool upper, bool transpose,
bool unitriangular);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
static void triangular_solve(const DeviceContext &context, const Tensor &x,
const Tensor &y, Tensor *out, bool upper,
bool transpose, bool unitriangular) {
// Tensor broadcast use eigen library
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y);
Tensor x_bst(x.type());
TensorExpand<T, DeviceContext>(context, x, &x_bst, x_bst_dims_vec);
Tensor y_bst(y.type());
TensorExpand<T, DeviceContext>(context, y, &y_bst, y_bst_dims_vec);
// TriangularSolveFunctor performs calculations in-place
// x_clone should be a copy of 'x' after broadcast
// out should be a copy of 'y' after broadcast
Tensor x_clone(x.type());
x_clone.Resize(phi::make_ddim(x_bst_dims_vec));
x_clone.mutable_data<T>(context.GetPlace());
framework::TensorCopy(x_bst, context.GetPlace(), context, &x_clone);
out->Resize(phi::make_ddim(y_bst_dims_vec));
out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(y_bst, context.GetPlace(), context, out);
math::TriangularSolveFunctor<DeviceContext, T> functor;
functor(context, &x_clone, out, /*left=*/true, upper, transpose,
unitriangular);
}
} // namespace operators
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/triangular_solve_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册