diff --git a/paddle/phi/kernels/gpu/determinant_grad_kernel.cu b/paddle/phi/kernels/gpu/determinant_grad_kernel.cu index cce12a87fac72f5ac6edbbeb74de9fe3ae9ede09..f3187d5fefb5191a3d42cfa6f94a4fdeb035a7e1 100644 --- a/paddle/phi/kernels/gpu/determinant_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/determinant_grad_kernel.cu @@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(determinant_grad, GPU, ALL_LAYOUT, phi::DeterminantGradKernel, + phi::dtype::float16, float, double) {} diff --git a/paddle/phi/kernels/gpu/determinant_kernel.cu b/paddle/phi/kernels/gpu/determinant_kernel.cu index 25184083873952638a1f84d8d4b66262363ca9c6..58e27e3ce4abdaebd2411c97daf85f92638a8c69 100644 --- a/paddle/phi/kernels/gpu/determinant_kernel.cu +++ b/paddle/phi/kernels/gpu/determinant_kernel.cu @@ -17,5 +17,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/determinant_kernel_impl.h" -PD_REGISTER_KERNEL( - determinant, GPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {} +PD_REGISTER_KERNEL(determinant, + GPU, + ALL_LAYOUT, + phi::DeterminantKernel, + phi::dtype::float16, + float, + double) {} diff --git a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h index 3f463e1d9e0644f62a0b3e8508f333731c466922..4d58698c64d228d34d35b867f58134cee4ea7db9 100644 --- a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h @@ -15,8 +15,10 @@ #pragma once #include "glog/logging.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/determinant_grad_kernel.h" #include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" @@ -26,7 +28,6 @@ #include "paddle/phi/kernels/funcs/matrix_inverse.h" #include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/transpose_kernel.h" - namespace phi { 
namespace detail { @@ -113,6 +114,11 @@ void DeterminantGradKernel(const Context& dev_ctx, return; } + using MPType = typename phi::dtype::MPTypeTrait<T>::Type; + auto origin_dt = std::is_same<phi::dtype::float16, T>::value + ? DataType::FLOAT16 + : DataType::BFLOAT16; + // The matrix is invertible // let |A| = Determinant(A) // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf @@ -123,16 +129,22 @@ void DeterminantGradKernel(const Context& dev_ctx, DenseTensor inverse_A; // A must be square matrices! inverse_A.Resize(x.dims()); - dev_ctx.template Alloc<T>(&inverse_A); + dev_ctx.template Alloc<MPType>(&inverse_A); - phi::funcs::MatrixInverseFunctor<Context, T> mat_inv; - mat_inv(dev_ctx, x, &inverse_A); + phi::funcs::MatrixInverseFunctor<Context, MPType> mat_inv; + if (!std::is_same<MPType, T>::value) { + mat_inv(dev_ctx, + phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32), + &inverse_A); + } else { + mat_inv(dev_ctx, x, &inverse_A); + } VLOG(3) << "inverse(A) dims: " << inverse_A.dims(); // Second: inverse(A).transpose(-2, -1) DenseTensor transpose_inverse_A = - phi::TransposeLast2Dim<T>(dev_ctx, inverse_A); + phi::TransposeLast2Dim<MPType>(dev_ctx, inverse_A); VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: " << transpose_inverse_A.dims(); @@ -147,7 +159,15 @@ void DeterminantGradKernel(const Context& dev_ctx, VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims(); // Finally: unsqueeze(dA * |A|) * inverse(A) - auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A); + DenseTensor res; + if (!std::is_same<MPType, T>::value) { + res = phi::Multiply<T>( + dev_ctx, + unsqueeze2, + phi::Cast<MPType, Context>(dev_ctx, transpose_inverse_A, origin_dt)); + } else { + res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A); + } VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims(); diff --git a/paddle/phi/kernels/impl/determinant_kernel_impl.h b/paddle/phi/kernels/impl/determinant_kernel_impl.h index 36e47c78c832c10fc8613604c47239508f44e725..3c437ad659c43ac3c7556b149e0f13cbcffe65d5 100644 --- a/paddle/phi/kernels/impl/determinant_kernel_impl.h +++ 
b/paddle/phi/kernels/impl/determinant_kernel_impl.h @@ -21,6 +21,7 @@ #include <vector> #include "glog/logging.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/tensor_utils.h" @@ -31,6 +32,13 @@ namespace detail { template <typename T> class EigenMatrix {}; +template <> +class EigenMatrix<phi::dtype::float16> { + public: + using MatrixType = + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; +}; + template <> class EigenMatrix<float> { public: @@ -74,6 +82,7 @@ struct DeterminantFunctor { std::vector<T> input_vec; std::vector<T> output_vec; phi::TensorToVector(input, dev_ctx, &input_vec); + using MPType = typename phi::dtype::MPTypeTrait<T>::Type; for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel auto begin_iter = input_vec.begin() + i * rank * rank; auto end_iter = input_vec.begin() + (i + 1) * rank * rank; @@ -85,7 +94,8 @@ matrix(i, j) = sub_vec[rank * i + j]; } } - output_vec.push_back(matrix.determinant()); + output_vec.push_back( + static_cast<T>(matrix.template cast<MPType>().determinant())); } phi::TensorFromVector(output_vec, dev_ctx, output); } diff --git a/python/paddle/fluid/tests/unittests/test_determinant_op.py b/python/paddle/fluid/tests/unittests/test_determinant_op.py index ade000cda8712e5ab99e9fd984be4f9eb358a64f..8e50f0c5552ec35124840e984a5ed7b4770817f8 100644 --- a/python/paddle/fluid/tests/unittests/test_determinant_op.py +++ b/python/paddle/fluid/tests/unittests/test_determinant_op.py @@ -50,6 +50,14 @@ class TestDeterminantOpCase1(TestDeterminantOp): self.target = np.linalg.det(self.case) +class TestDeterminantOpCase1FP16(TestDeterminantOp): + def init_data(self): + np.random.seed(0) + self.case = np.random.rand(10, 10).astype(np.float16) + self.inputs = {'Input': self.case} + self.target = np.linalg.det(self.case.astype(np.float32)) + + class TestDeterminantOpCase2(TestDeterminantOp): def init_data(self): np.random.seed(0) @@ -59,6 +67,17 @@ class TestDeterminantOpCase2(TestDeterminantOp): self.target = np.linalg.det(self.case) 
+class TestDeterminantOpCase2FP16(TestDeterminantOp): + def init_data(self): + np.random.seed(0) + # not invertible matrix + self.case = np.ones([4, 2, 4, 4]).astype(np.float16) + self.inputs = {'Input': self.case} + self.target = np.linalg.det(self.case.astype(np.float32)).astype( + np.float16 + ) + + class TestDeterminantAPI(unittest.TestCase): def setUp(self): np.random.seed(0) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 3dcbc7c6ac63b26d6d63f09e13a2dff24f29bfe5..2235cf93cfb60d9693f35a19aeb7dc98db71763e 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1809,7 +1809,7 @@ def det(x, name=None): if in_dygraph_mode(): return _C_ops.det(x) else: - check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det') + check_dtype(x.dtype, 'Input', ['float16', 'float32', 'float64'], 'det') input_shape = list(x.shape) assert len(input_shape) >= 2, (