Unverified commit a04a6bd5, authored by Chen Weihang, committed by GitHub

[Phi] Move determinant op kernel into phi (#40539)

* add determinant phi kernel

* remove original determinant op kernel

* add determinant grad phi kernel

* fix determinant test failed

* remove original determinant grad op kernel
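
For context, the signature change this migration entails, as a minimal side-by-side sketch (abridged from the diff below):

```cpp
// Old fluid style (removed): the kernel is an OpKernel subclass that pulls
// inputs and outputs out of an ExecutionContext by name.
template <typename DeviceContext, typename T>
class DeterminantKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<framework::Tensor>("Input");
    auto* output = context.Output<framework::Tensor>("Out");
    // ... compute ...
  }
};

// New phi style (added): the kernel is a free function template that takes
// the device context and tensors as explicit arguments, and is registered
// per backend via PD_REGISTER_KERNEL.
template <typename T, typename Context>
void DeterminantKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       DenseTensor* out);
```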
Parent 0c0acbd7
@@ -168,14 +168,6 @@ REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp)
REGISTER_OP_CPU_KERNEL(determinant,
ops::DeterminantKernel<plat::CPUDeviceContext, float>,
ops::DeterminantKernel<plat::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
determinant_grad, ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
ops::SlogDeterminantOpMaker,
ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>,
......
@@ -17,14 +17,6 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>,
ops::DeterminantKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
determinant_grad,
ops::DeterminantGradKernel<plat::CUDADeviceContext, float>,
ops::DeterminantGradKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
slogdeterminant, ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,
......
@@ -23,10 +23,13 @@
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
@@ -40,232 +43,6 @@ T sign(T val) {
return static_cast<T>(T(0) < val) - (val < T(0));
}
template <typename T>
class EigenMatrix {};
template <>
class EigenMatrix<float> {
public:
using MatrixType = Eigen::MatrixXf;
};
template <>
class EigenMatrix<double> {
public:
using MatrixType = Eigen::MatrixXd;
};
inline int64_t GetBatchCount(const framework::DDim dims) {
int64_t batch_count = 1;
auto dim_size = dims.size();
PADDLE_ENFORCE_GE(
dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
// Multiply all dimensions except the last two to get the batch count;
// for example, a tensor with shape [3,3,3,3] has a batch count of 9.
for (int64_t i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
return batch_count;
}
template <typename T>
struct DeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx,
int64_t rank, int64_t batch_count, Tensor* output) {
std::vector<T> input_vec;
std::vector<T> output_vec;
framework::TensorToVector(input, ctx.device_context(), &input_vec);
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
}
}
output_vec.push_back(matrix.determinant());
}
framework::TensorFromVector(output_vec, output);
}
};
template <typename DeviceContext, typename T>
class DeterminantKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
auto input_dim = vectorize(input->dims());
auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out");
auto batch_count = GetBatchCount(input->dims());
VLOG(2) << "input dim:" << input->dims();
PADDLE_ENFORCE_GE(
input_dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
input_dim[input_dim_size - 2],
platform::errors::InvalidArgument(
"the input matrix should be square matrix."));
auto rank = input_dim[input_dim_size - 1]; // square matrix length
DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
auto output_dims = phi::slice_ddim(input->dims(), 0, input_dim_size - 2);
if (input_dim_size > 2) {
output->Resize(output_dims);
} else {
// when the input is a two-dimensional matrix, the det value is a scalar.
output->Resize({1});
}
VLOG(2) << "output dim:" << output->dims();
}
};
template <typename T>
struct FoundZeroFunctor {
FoundZeroFunctor(const T* x, int64_t numel, bool* res)
: x_(x), numel_(numel), res_(res) {}
HOSTDEVICE void operator()(size_t idx) const {
if (*res_ || idx >= static_cast<size_t>(numel_)) {
// a zero entry was already found, or idx is out of range
return;
}
*res_ = (x_[idx] == static_cast<T>(0));
}
const T* x_;
int64_t numel_;
bool* res_;
};
template <typename DeviceContext, typename T>
inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
const framework::Tensor* det) {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto numel = det->numel();
framework::Tensor dev_tensor;
auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());
// set false
phi::funcs::SetConstant<DeviceContext, bool> zero;
zero(dev_ctx, &dev_tensor, false);
// check whether any element is zero
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
for_range(functor);
// copy to host
dev_ctx.Wait();
framework::Tensor cpu_tensor;
framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);
// if a zero was found, the matrix is not invertible;
// otherwise it is invertible
auto* res = cpu_tensor.data<bool>();
return !(*res);
}
template <typename DeviceContext, typename T>
class DeterminantGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& orig_dev_ctx = context.template device_context<DeviceContext>();
const auto* input = context.Input<framework::Tensor>("Input");
const auto* det = context.Input<framework::Tensor>("Out");
const auto* grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ddet =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
auto input_dims_size = input->dims().size();
if (input_dims_size > 2) {
PADDLE_ENFORCE_EQ(
grad->dims().size() + 2, input_dims_size,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else if (input_dims_size == 2) {
// input dims size 2 and grad dims size 1 is possible
PADDLE_ENFORCE_EQ(
grad->dims().size(), 1,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else {
// checked in forward, pass
}
auto& dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);
// Check Whether the matrix is invertible
// (matrix A not invertible) == (det(A)=0)
if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
ddet->Resize(input->dims());
phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
ddet);
return;
}
// The matrix is invertible
// let |A| = Determinant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
// -1)
// First: inverse(A)
framework::Tensor inverse_A;
// A must be square matrices!
inverse_A.Resize(input->dims());
inverse_A.mutable_data<T>(context.GetPlace());
phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(orig_dev_ctx, *input, &inverse_A);
VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
// Second: inverse(A).transpose(-2, -1)
framework::Tensor transpose_inverse_A =
phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();
// Third: dA * |A|
auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
// Fourth: unsqueeze(dA * |A|, [-1, -2])
auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
// Finally: unsqueeze(dA * |A|) * inverse(A)
auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
framework::TensorCopy(res, context.GetPlace(), ddet);
ddet->Resize(input->dims());
VLOG(3) << "d|A| dims: " << ddet->dims();
}
};
template <typename T>
struct SlogDeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx,
@@ -280,7 +57,7 @@ struct SlogDeterminantFunctor {
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
typename phi::detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
@@ -311,7 +88,7 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out");
auto batch_count = GetBatchCount(input->dims());
auto batch_count = phi::detail::GetBatchCount(input->dims());
VLOG(2) << "input dim:" << input->dims();
PADDLE_ENFORCE_GE(
input_dim_size, 2,
@@ -370,7 +147,9 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
// (matrix A not invertible) == (absslogdet(A)=0)
auto slogdet_vec = slogdet->Split(1, 0);
auto absslogdet_val = slogdet_vec[0];
if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
if (!phi::detail::CheckMatrixInvertible<
T, typename framework::ConvertToPhiContext<DeviceContext>::TYPE>(
dev_ctx, &absslogdet_val)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
dslogdet->Resize(input->dims());
......
@@ -27,7 +27,11 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel math_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel triangular_solve_grad_kernel)
set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel math_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(gumbel_softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(gumbel_softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
@@ -46,6 +50,7 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
# 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/determinant_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
PD_REGISTER_KERNEL(determinant_grad,
CPU,
ALL_LAYOUT,
phi::DeterminantGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/determinant_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
PD_REGISTER_KERNEL(
determinant, CPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DeterminantGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DeterminantKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/determinant_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
PD_REGISTER_KERNEL(determinant_grad,
GPU,
ALL_LAYOUT,
phi::DeterminantGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/determinant_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
PD_REGISTER_KERNEL(
determinant, GPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/determinant_grad_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
namespace detail {
template <typename T>
struct FoundZeroFunctor {
FoundZeroFunctor(const T* x, int64_t numel, bool* res)
: x_(x), numel_(numel), res_(res) {}
HOSTDEVICE void operator()(size_t idx) const {
if (*res_ || idx >= static_cast<size_t>(numel_)) {
// a zero entry was already found, or idx is out of range
return;
}
*res_ = (x_[idx] == static_cast<T>(0));
}
const T* x_;
int64_t numel_;
bool* res_;
};
template <typename T, typename Context>
inline bool CheckMatrixInvertible(const Context& dev_ctx,
const DenseTensor* det) {
auto numel = det->numel();
DenseTensor dev_tensor = phi::Empty<bool, Context>(dev_ctx, {1});
// set false
phi::funcs::SetConstant<Context, bool> zero;
zero(dev_ctx, &dev_tensor, false);
// check whether any element is zero
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
FoundZeroFunctor<T> functor(det->data<T>(), numel, dev_tensor.data<bool>());
for_range(functor);
// copy to host
DenseTensor cpu_tensor;
phi::Copy<Context>(dev_ctx, dev_tensor, phi::CPUPlace(), false, &cpu_tensor);
// if a zero was found, the matrix is not invertible;
// otherwise it is invertible
auto* res = cpu_tensor.data<bool>();
return !(*res);
}
} // namespace detail
template <typename T, typename Context>
void DeterminantGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto input_dims_size = x.dims().size();
if (input_dims_size > 2) {
PADDLE_ENFORCE_EQ(
out_grad.dims().size() + 2,
input_dims_size,
phi::errors::InvalidArgument(
"The grad tensor of det dims size should be 2 less than"
" input tensor's, but here differ %d",
input_dims_size - out_grad.dims().size()));
} else if (input_dims_size == 2) {
// input dims size 2 and grad dims size 1 is possible
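// e.g. x with shape [6, 3, 3] gives out and out_grad with shape [6], while
// a 2-D x with shape [3, 3] gives out with shape [1] (the forward kernel
// resizes the output to {1} in that case), so out_grad has dims size 1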
PADDLE_ENFORCE_EQ(
out_grad.dims().size(),
1,
phi::errors::InvalidArgument(
"The grad tensor of det dims size should be 2 less than"
" input tensor's, but here differ %d",
input_dims_size - out_grad.dims().size()));
} else {
// checked in forward, pass
}
// Check Whether the matrix is invertible
// (matrix A not invertible) == (det(A)=0)
if (!detail::CheckMatrixInvertible<T, Context>(dev_ctx, &out)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
x_grad->Resize(x.dims());
phi::Full<T>(
dev_ctx, phi::vectorize(x.dims()), static_cast<T>(0.0f), x_grad);
return;
}
// The matrix is invertible
// let |A| = Determinant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
// -1)
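// A sketch of the derivation: by Jacobi's formula, d|A| = |A| * tr(A^{-1} dA),
// so d|A|/dA = |A| * inverse(A).transpose(-2, -1). Chaining with the upstream
// gradient out_grad = dL/d|A|, broadcast over the last two axes by the
// unsqueezes, yields the expression above, computed step by step below.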
// First: inverse(A)
DenseTensor inverse_A;
// A must be square matrices!
inverse_A.Resize(x.dims());
dev_ctx.template Alloc<T>(&inverse_A);
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
mat_inv(dev_ctx, x, &inverse_A);
VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
// Second: inverse(A).transpose(-2, -1)
DenseTensor transpose_inverse_A =
phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();
// Third: dA * |A|
auto mul_dA_detA = phi::Multiply<T>(dev_ctx, out_grad, out);
VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
// Fourth: unsqueeze(dA * |A|, [-1, -2])
auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
// Finally: unsqueeze(dA * |A|) * inverse(A)
auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
x_grad->Resize(x.dims());
VLOG(3) << "d|A| dims: " << x_grad->dims();
phi::Copy(dev_ctx, res, dev_ctx.GetPlace(), false, x_grad);
}
} // namespace phi
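
As a sanity check on the gradient formula used above, here is a self-contained Eigen sketch, independent of Paddle (the matrix size and epsilon are arbitrary choices), comparing the analytic gradient det(A) * inverse(A).transpose() against central finite differences:

```cpp
#include <cstdio>

#include <Eigen/Dense>

int main() {
  const int n = 3;
  Eigen::MatrixXd A = Eigen::MatrixXd::Random(n, n);
  // Analytic gradient of det(A) w.r.t. A, by Jacobi's formula: det(A) * A^{-T}.
  Eigen::MatrixXd analytic = A.determinant() * A.inverse().transpose();
  // Central finite differences for comparison.
  const double eps = 1e-6;
  Eigen::MatrixXd numeric(n, n);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      Eigen::MatrixXd Ap = A, Am = A;
      Ap(i, j) += eps;
      Am(i, j) -= eps;
      numeric(i, j) = (Ap.determinant() - Am.determinant()) / (2 * eps);
    }
  }
  // Expect a small discrepancy driven by the finite-difference truncation error.
  std::printf("max abs diff: %g\n", (analytic - numeric).cwiseAbs().maxCoeff());
  return 0;
}
```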
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/determinant_kernel.h"
#include <Eigen/Dense>
#include <Eigen/LU>
#include <algorithm>
#include <cmath>
#include <vector>
#include "paddle/phi/core/enforce.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace phi {
namespace detail {
template <typename T>
class EigenMatrix {};
template <>
class EigenMatrix<float> {
public:
using MatrixType = Eigen::MatrixXf;
};
template <>
class EigenMatrix<double> {
public:
using MatrixType = Eigen::MatrixXd;
};
inline int64_t GetBatchCount(const DDim dims) {
int64_t batch_count = 1;
auto dim_size = dims.size();
PADDLE_ENFORCE_GE(
dim_size,
2,
phi::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
// Multiply all dimensions except the last two to get the batch count;
// for example, a tensor with shape [3,3,3,3] has a batch count of 9.
for (int64_t i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
return batch_count;
}
} // namespace detail
template <typename T, typename Context>
struct DeterminantFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& input,
int64_t rank,
int64_t batch_count,
DenseTensor* output) {
std::vector<T> input_vec;
std::vector<T> output_vec;
paddle::framework::TensorToVector(input, dev_ctx, &input_vec);
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
}
}
output_vec.push_back(matrix.determinant());
}
paddle::framework::TensorFromVector(output_vec, output);
}
};
template <typename T, typename Context>
void DeterminantKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto input_dim = vectorize(x.dims());
auto input_dim_size = input_dim.size();
auto batch_count = detail::GetBatchCount(x.dims());
VLOG(10) << "input dim:" << x.dims();
PADDLE_ENFORCE_GE(
input_dim_size,
2,
phi::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
input_dim[input_dim_size - 2],
phi::errors::InvalidArgument(
"the input matrix should be square matrix."));
auto rank = input_dim[input_dim_size - 1]; // square matrix length
DeterminantFunctor<T, Context>()(dev_ctx, x, rank, batch_count, out);
auto output_dims = phi::slice_ddim(x.dims(), 0, input_dim_size - 2);
if (input_dim_size > 2) {
out->Resize(output_dims);
} else {
// when the input is a two-dimensional matrix, the det value is a scalar.
out->Resize({1});
}
VLOG(10) << "output dim:" << out->dims();
}
} // namespace phi
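
To illustrate the batching scheme used by DeterminantFunctor outside of Paddle: a [batch, rank, rank] tensor is flattened row-major, and each rank*rank slice is copied into an Eigen matrix before calling .determinant(). A minimal standalone sketch (the batch size and values are arbitrary):

```cpp
#include <cstdio>
#include <vector>

#include <Eigen/Dense>

int main() {
  const int64_t batch_count = 2, rank = 2;
  // Two 2x2 matrices flattened row-major: [[1, 2], [3, 4]] and [[2, 0], [0, 5]].
  std::vector<double> input_vec = {1, 2, 3, 4, 2, 0, 0, 5};
  std::vector<double> output_vec;
  for (int64_t b = 0; b < batch_count; ++b) {
    Eigen::MatrixXd matrix(rank, rank);
    for (int64_t i = 0; i < rank; ++i) {
      for (int64_t j = 0; j < rank; ++j) {
        matrix(i, j) = input_vec[b * rank * rank + rank * i + j];
      }
    }
    output_vec.push_back(matrix.determinant());
  }
  std::printf("%f %f\n", output_vec[0], output_vec[1]);  // prints -2 and 10
  return 0;
}
```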
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DeterminantGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("determinant_grad",
{"Input", "Out", GradVarName("Out")},
{},
{GradVarName("Input")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(determinant_grad,
phi::DeterminantGradOpArgumentMapping);