Unverified Commit aeaf69b3 authored by Chen Weihang, committed by GitHub

remove determinant deps for svd helper (#40235)

Parent 7ea9235c
paddle/fluid/operators/determinant_op.h
@@ -19,11 +19,17 @@
 #include <cmath>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
+#include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/math_kernel.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace paddle {
 namespace operators {
@@ -172,7 +178,7 @@ template <typename DeviceContext, typename T>
 class DeterminantGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
     const auto* input = context.Input<framework::Tensor>("Input");
     const auto* det = context.Input<framework::Tensor>("Out");
     const auto* grad =
@@ -200,15 +206,18 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
       // checked in forward, pass
     }
 
+    auto& dev_ctx = static_cast<
+        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
+        orig_dev_ctx);
+
     // Check Whether the matrix is invertible
     // (matrix A not invertible) == (det(A)=0)
     if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
       // The matrix is not invertible
       VLOG(3) << "The input matrix not invertible!";
       ddet->Resize(input->dims());
-      ddet->mutable_data<T>(context.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> zero;
-      zero(dev_ctx, ddet, static_cast<T>(0.0f));
+      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
+                   ddet);
       return;
     }
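Note (reviewer sketch, not part of the commit): the hunk above introduces the pattern this change uses throughout: fetch the fluid DeviceContext, static_cast it through framework::ConvertToPhiContext to the matching phi context, then call phi kernels directly instead of going through svd_helper. A minimal sketch of the pattern; the helper name ZeroFillLikeInput is hypothetical:

    // Sketch only; assumes Paddle's fluid/phi headers as of this commit.
    #include "paddle/fluid/framework/op_registry.h"
    #include "paddle/phi/kernels/full_kernel.h"

    namespace paddle {
    namespace operators {

    // Hypothetical helper: fill `out` with zeros, shaped like `input`.
    template <typename DeviceContext, typename T>
    void ZeroFillLikeInput(const framework::ExecutionContext& context,
                           const framework::Tensor& input,
                           framework::Tensor* out) {
      // Fluid-side context registered for this op kernel.
      auto& orig_dev_ctx = context.template device_context<DeviceContext>();
      // ConvertToPhiContext maps CPUDeviceContext to phi::CPUContext and
      // CUDADeviceContext to phi::GPUContext, so phi kernels accept it.
      auto& dev_ctx = static_cast<
          const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
          orig_dev_ctx);
      out->Resize(input.dims());
      // phi::Full writes the given value into every element of `out`.
      phi::Full<T>(dev_ctx, phi::vectorize(input.dims()),
                   static_cast<T>(0.0f), out);
    }

    }  // namespace operators
    }  // namespace paddle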
@@ -218,8 +227,6 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
     // -1)
 
-    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
-
     // First: inverse(A)
     framework::Tensor inverse_A;
     // A must be square matrices!
@@ -227,26 +234,28 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.mutable_data<T>(context.GetPlace());
 
     phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *input, &inverse_A);
+    mat_inv(orig_dev_ctx, *input, &inverse_A);
     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
 
     // Second: inverse(A).transpose(-2, -1)
-    framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A);
+    framework::Tensor transpose_inverse_A =
+        phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
     VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
             << transpose_inverse_A.dims();
 
     // Third: dA * |A|
-    auto mul_dA_detA = helper.Mul(*grad, *det);
+    auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
     VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
 
     // Fourth: unsqueeze(dA * |A|, [-1, -2])
-    auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1);
-    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
+    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
     VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
 
     // Finally: unsqueeze(dA * |A|) * inverse(A)
-    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
     VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
@@ -331,7 +340,7 @@ template <typename DeviceContext, typename T>
 class SlogDeterminantGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
     const auto* input = context.Input<framework::Tensor>("Input");
     const auto* slogdet = context.Input<framework::Tensor>("Out");
     const auto* grad =
@@ -353,6 +362,10 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
                           input->dims().size() - grad->dims().size()));
     }
 
+    auto& dev_ctx = static_cast<
+        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
+        orig_dev_ctx);
+
     // Check Whether the matrix is invertible
     // (matrix A not invertible) == (absslogdet(A)=0)
     auto slogdet_vec = slogdet->Split(1, 0);
@@ -361,9 +374,8 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
       // The matrix is not invertible
       VLOG(3) << "The input matrix not invertible!";
       dslogdet->Resize(input->dims());
-      dslogdet->mutable_data<T>(context.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> zero;
-      zero(dev_ctx, dslogdet, std::numeric_limits<T>::quiet_NaN());
+      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()),
+                   std::numeric_limits<T>::quiet_NaN(), dslogdet);
       return;
     }
@@ -373,8 +385,6 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     // we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
     // inverse(A).conj().transpose(-2, -1)
 
-    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
-
     // First: inverse(A)
     framework::Tensor inverse_A;
     // A must be square matrices!
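Note (reviewer annotation, not part of the commit): the comment above encodes the gradient of the log-absolute-determinant. For invertible A,

    \frac{\partial \log\lvert\det(A)\rvert}{\partial A} = \overline{(A^{-1})}^{\,T} = A^{-H},

i.e. the conjugate transpose of the inverse; for real matrices the conj() is a no-op. This is why the steps below compute inverse(A), conjugate it, and transpose the last two dimensions before multiplying by the unsqueezed upstream gradient.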
@@ -382,25 +392,18 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.mutable_data<T>(context.GetPlace());
 
     phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *input, &inverse_A);
+    mat_inv(orig_dev_ctx, *input, &inverse_A);
     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
 
     // Second: inverse(A).conj()
-    framework::Tensor conj_inverse_A;
-    conj_inverse_A.Resize(inverse_A.dims());
-    auto numel = input->numel();
-    auto* conj_data = conj_inverse_A.mutable_data<T>(context.GetPlace(),
-                                                     size_t(numel * sizeof(T)));
-    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    phi::funcs::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
-    for_range(functor);
-
+    auto conj_inverse_A = phi::Conj<T>(dev_ctx, inverse_A);
     VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
 
     // Third: inverse(A).conj().transpose(-2, -1)
-    framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A);
+    framework::Tensor transpose_inverse_A =
+        phi::TransposeLast2Dim<T>(dev_ctx, conj_inverse_A);
     VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
             << transpose_inverse_A.dims();
@@ -417,12 +420,12 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
 
     // Fifth: unsqueeze(dslA, [-1, -2])
-    auto unsqueeze1 = helper.Unsqueeze(det_grad, -1);
-    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    auto unsqueeze1 = phi::funcs::Unsqueeze(det_grad, -1);
+    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
     VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
 
     // Finally: unsqueeze(dslA) * inverse(A)
-    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
     VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
 
     framework::TensorCopy(res, context.GetPlace(), dslogdet);
paddle/phi/kernels/full_kernel.h
@@ -37,6 +37,18 @@ void FullLikeKernel(const Context& dev_ctx,
                     DataType dtype,
                     DenseTensor* out);
 
+template <typename T, typename Context>
+void Full(const Context& dev_ctx,
+          const ScalarArray& shape,
+          const Scalar& val,
+          DenseTensor* out) {
+  FullKernel<T, Context>(dev_ctx,
+                         shape,
+                         val,
+                         paddle::experimental::CppTypeToDataType<T>::Type(),
+                         out);
+}
+
 template <typename T, typename Context>
 DenseTensor Full(const Context& dev_ctx,
                  const ScalarArray& shape,
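Note (reviewer sketch, not part of the commit): the added void overload fills a caller-provided DenseTensor in place, which is what the determinant kernels above use to write zeros or NaNs straight into the gradient output, rather than allocating a new tensor via the existing value-returning Full. A minimal usage sketch; the call site and the name FillWithNaN are hypothetical, and an initialized phi::CPUContext is assumed:

    #include <limits>
    #include <vector>
    #include "paddle/phi/kernels/full_kernel.h"

    // Hypothetical call site: fill a 4x4 tensor with quiet NaN.
    void FillWithNaN(const phi::CPUContext& dev_ctx, phi::DenseTensor* out) {
      // The shape parameter is a ScalarArray; std::vector<int64_t> converts.
      std::vector<int64_t> shape = {4, 4};
      phi::Full<float>(dev_ctx, shape,
                       std::numeric_limits<float>::quiet_NaN(), out);
    }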