Unverified · Commit aeaf69b3 · authored by Chen Weihang · committed by GitHub

remove determinant deps for svd helper (#40235)

Parent 7ea9235c
@@ -19,11 +19,17 @@
 #include <cmath>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
+#include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/math_kernel.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"

 namespace paddle {
 namespace operators {
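In short, this hunk drops the `svd_helper.h` dependency and pulls in the phi kernels that replace it. The call-site mapping in the hunks below is: `helper.Transpose` → `phi::TransposeLast2Dim`, `helper.Mul` → `phi::Multiply`, `helper.Unsqueeze` → `phi::funcs::Unsqueeze`, `phi::funcs::SetConstant` → `phi::Full`, and the manual `ConjFunctor`/`ForRange` loop → `phi::Conj`.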
@@ -172,7 +178,7 @@ template <typename DeviceContext, typename T>
 class DeterminantGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
     const auto* input = context.Input<framework::Tensor>("Input");
     const auto* det = context.Input<framework::Tensor>("Out");
     const auto* grad =
@@ -200,15 +206,18 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
       // checked in forward, pass
     }

+    auto& dev_ctx = static_cast<
+        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
+        orig_dev_ctx);
+
     // Check Whether the matrix is invertible
     // (matrix A not invertible) == (det(A)=0)
     if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
       // The matrix is not invertible
       VLOG(3) << "The input matrix not invertible!";
       ddet->Resize(input->dims());
-      ddet->mutable_data<T>(context.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> zero;
-      zero(dev_ctx, ddet, static_cast<T>(0.0f));
+      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
+                   ddet);
       return;
     }
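Note the two handles to the same device context: `orig_dev_ctx` keeps the fluid `DeviceContext` type, which `phi::funcs::MatrixInverseFunctor` below still expects, while `dev_ctx` is the same context viewed through `framework::ConvertToPhiContext` so it can be passed to the new phi kernels (`phi::Full`, `phi::Multiply`, `phi::Conj`, `phi::TransposeLast2Dim`).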
@@ -218,8 +227,6 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
     // -1)

-    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
-
     // First: inverse(A)
     framework::Tensor inverse_A;
     // A must be square matrices!
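For reference, the identity in the comment above is Jacobi's formula restricted to invertible square matrices:

$$ d\det(A) = \det(A)\,\operatorname{tr}\!\left(A^{-1}\,dA\right) \quad\Longrightarrow\quad \frac{\partial \det(A)}{\partial A} = \det(A)\,\left(A^{-1}\right)^{\mathsf{T}}, $$

so each batch element's gradient is the scalar upstream gradient times $\det(A)$, broadcast against the transposed inverse; the hunk below computes exactly this with `phi::Multiply` and `phi::TransposeLast2Dim`.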
@@ -227,26 +234,28 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.mutable_data<T>(context.GetPlace());

     phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *input, &inverse_A);
+    mat_inv(orig_dev_ctx, *input, &inverse_A);

     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();

     // Second: inverse(A).transpose(-2, -1)
-    framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A);
+    framework::Tensor transpose_inverse_A =
+        phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
     VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
             << transpose_inverse_A.dims();

     // Third: dA * |A|
-    auto mul_dA_detA = helper.Mul(*grad, *det);
+    auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
     VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();

     // Fourth: unsqueeze(dA * |A|, [-1, -2])
-    auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1);
-    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
+    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
     VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();

     // Finally: unsqueeze(dA * |A|) * inverse(A)
-    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
     VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
@@ -331,7 +340,7 @@ template <typename DeviceContext, typename T>
 class SlogDeterminantGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
     const auto* input = context.Input<framework::Tensor>("Input");
     const auto* slogdet = context.Input<framework::Tensor>("Out");
     const auto* grad =
@@ -353,6 +362,10 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
                           input->dims().size() - grad->dims().size()));
     }

+    auto& dev_ctx = static_cast<
+        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
+        orig_dev_ctx);
+
     // Check Whether the matrix is invertible
     // (matrix A not invertible) == (absslogdet(A)=0)
     auto slogdet_vec = slogdet->Split(1, 0);
@@ -361,9 +374,8 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
       // The matrix is not invertible
       VLOG(3) << "The input matrix not invertible!";
       dslogdet->Resize(input->dims());
-      dslogdet->mutable_data<T>(context.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> zero;
-      zero(dev_ctx, dslogdet, std::numeric_limits<T>::quiet_NaN());
+      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()),
+                   std::numeric_limits<T>::quiet_NaN(), dslogdet);
       return;
     }
@@ -373,8 +385,6 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     // we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
     // inverse(A).conj().transpose(-2, -1)

-    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
-
     // First: inverse(A)
     framework::Tensor inverse_A;
     // A must be square matrices!
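As with the determinant kernel, the identity behind the comment above is

$$ \frac{\partial \log\lvert\det(A)\rvert}{\partial A} = \overline{A^{-1}}^{\,\mathsf{T}} = \left(A^{-1}\right)^{\mathsf{H}}, $$

the conjugate transpose of the inverse, which is why the complex-capable path applies a conjugation (now a single `phi::Conj` call) before the transpose.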
@@ -382,25 +392,18 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.mutable_data<T>(context.GetPlace());

     phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *input, &inverse_A);
+    mat_inv(orig_dev_ctx, *input, &inverse_A);

     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();

     // Second: inverse(A).conj()
-    framework::Tensor conj_inverse_A;
-    conj_inverse_A.Resize(inverse_A.dims());
-    auto numel = input->numel();
-    auto* conj_data = conj_inverse_A.mutable_data<T>(context.GetPlace(),
-                                                     size_t(numel * sizeof(T)));
-
-    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    phi::funcs::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
-    for_range(functor);
+    auto conj_inverse_A = phi::Conj<T>(dev_ctx, inverse_A);

     VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();

     // Third: inverse(A).conj().transpose(-2, -1)
-    framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A);
+    framework::Tensor transpose_inverse_A =
+        phi::TransposeLast2Dim<T>(dev_ctx, conj_inverse_A);
     VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
             << transpose_inverse_A.dims();
@@ -417,12 +420,12 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     det_grad.Resize(det_grad.dims().reshape(det_grad_vec));

     // Fifth: unsqueeze(dslA, [-1, -2])
-    auto unsqueeze1 = helper.Unsqueeze(det_grad, -1);
-    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    auto unsqueeze1 = phi::funcs::Unsqueeze(det_grad, -1);
+    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
     VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();

     // Finally: unsqueeze(dslA) * inverse(A)
-    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
     VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();

     framework::TensorCopy(res, context.GetPlace(), dslogdet);
...
@@ -37,6 +37,18 @@ void FullLikeKernel(const Context& dev_ctx,
                     DataType dtype,
                     DenseTensor* out);

+template <typename T, typename Context>
+void Full(const Context& dev_ctx,
+          const ScalarArray& shape,
+          const Scalar& val,
+          DenseTensor* out) {
+  FullKernel<T, Context>(dev_ctx,
+                         shape,
+                         val,
+                         paddle::experimental::CppTypeToDataType<T>::Type(),
+                         out);
+}
+
 template <typename T, typename Context>
 DenseTensor Full(const Context& dev_ctx,
                  const ScalarArray& shape,
...
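For illustration, a minimal sketch of calling the new out-parameter overload, mirroring its use in the determinant kernel above (the context object, `dims`, and the element type are assumptions for the example, not part of this commit):

    // Sketch: zero-fill a caller-owned DenseTensor via the new overload.
    // `dev_ctx` is assumed to be a phi-typed device context (e.g. phi::CPUContext),
    // and `dims` an existing DDim describing the desired output shape.
    phi::DenseTensor ddet;
    phi::Full<float>(dev_ctx,
                     phi::vectorize(dims),      // ScalarArray shape
                     static_cast<float>(0.0f),  // fill value
                     &ddet);                    // allocates and fills the output

Unlike the existing overload shown below it, which returns a freshly constructed DenseTensor, this variant writes into a caller-owned output, matching the out-pointer style of the operator kernels.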