Unverified · Commit 2a705b74 authored by denglianbin, committed by GitHub

[Hackathon No.48] Implement float16 data type support for the Paddle determinant operator (#53286)

Parent 9127cc3c
@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(determinant_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::DeterminantGradKernel,
+                   phi::dtype::float16,
                    float,
                    double) {}
@@ -17,5 +17,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    determinant, GPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
+PD_REGISTER_KERNEL(determinant,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DeterminantKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
@@ -15,8 +15,10 @@
 #pragma once
 
 #include "glog/logging.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/determinant_grad_kernel.h"
 #include "paddle/phi/kernels/elementwise_multiply_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
@@ -26,7 +28,6 @@
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
 #include "paddle/phi/kernels/funcs/unsqueeze.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
-
 namespace phi {
 namespace detail {
@@ -113,6 +114,11 @@ void DeterminantGradKernel(const Context& dev_ctx,
     return;
   }
 
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  auto origin_dt = std::is_same<phi::dtype::float16, T>::value
+                       ? DataType::FLOAT16
+                       : DataType::BFLOAT16;
+
   // The matrix is invertible
   // let |A| = Determinant(A)
   // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@@ -123,16 +129,22 @@ void DeterminantGradKernel(const Context& dev_ctx,
   DenseTensor inverse_A;
   // A must be square matrices!
   inverse_A.Resize(x.dims());
-  dev_ctx.template Alloc<T>(&inverse_A);
+  dev_ctx.template Alloc<MPType>(&inverse_A);
 
-  phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
-  mat_inv(dev_ctx, x, &inverse_A);
+  phi::funcs::MatrixInverseFunctor<Context, MPType> mat_inv;
+  if (!std::is_same<MPType, T>::value) {
+    mat_inv(dev_ctx,
+            phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32),
+            &inverse_A);
+  } else {
+    mat_inv(dev_ctx, x, &inverse_A);
+  }
 
   VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
 
   // Second: inverse(A).transpose(-2, -1)
   DenseTensor transpose_inverse_A =
-      phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
+      phi::TransposeLast2Dim<MPType>(dev_ctx, inverse_A);
   VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
           << transpose_inverse_A.dims();
@@ -147,7 +159,15 @@ void DeterminantGradKernel(const Context& dev_ctx,
   VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
 
   // Finally: unsqueeze(dA * |A|) * inverse(A)
-  auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
+  DenseTensor res;
+  if (!std::is_same<MPType, T>::value) {
+    res = phi::Multiply<T>(
+        dev_ctx,
+        unsqueeze2,
+        phi::Cast<MPType, Context>(dev_ctx, transpose_inverse_A, origin_dt));
+  } else {
+    res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
+  }
 
   VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
......
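For intuition, the change above keeps the numerically sensitive steps (the matrix inverse and its transpose) in float32 whenever `T` is a reduced-precision type, and only casts back to the original dtype for the final elementwise multiply. Below is a minimal numpy sketch of the same gradient path, using the identity d|A| = |A| · tr(A⁻¹ dA) from the Giles reference cited in the kernel; all names in it are illustrative, not part of the patch:

```python
import numpy as np

def det_grad_fp16(x, dout):
    """Sketch of DeterminantGradKernel's float16 path (illustrative only)."""
    a = x.astype(np.float32)                     # upcast, mirroring MPType
    inv_a_t = np.linalg.inv(a).swapaxes(-2, -1)  # inverse(A).transpose(-2, -1)
    det = np.linalg.det(a)                       # |A|, one value per matrix
    # unsqueeze(dOut * |A|) so it broadcasts over the last two matrix dims
    coeff = (dout.astype(np.float32) * det)[..., None, None]
    return (coeff * inv_a_t).astype(np.float16)  # cast back to origin dtype
```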
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "glog/logging.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -31,6 +32,13 @@ namespace detail {
 template <typename T>
 class EigenMatrix {};
 
+template <>
+class EigenMatrix<phi::dtype::float16> {
+ public:
+  using MatrixType =
+      Eigen::Matrix<phi::dtype::float16, Eigen::Dynamic, Eigen::Dynamic>;
+};
+
 template <>
 class EigenMatrix<float> {
  public:
@@ -74,6 +82,7 @@ struct DeterminantFunctor {
     std::vector<T> input_vec;
     std::vector<T> output_vec;
     phi::TensorToVector(input, dev_ctx, &input_vec);
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
     for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
@@ -85,7 +94,8 @@ struct DeterminantFunctor {
           matrix(i, j) = sub_vec[rank * i + j];
         }
       }
-      output_vec.push_back(matrix.determinant());
+      output_vec.push_back(
+          static_cast<T>(matrix.template cast<MPType>().determinant()));
     }
     phi::TensorFromVector(output_vec, dev_ctx, output);
   }
......
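The forward change is smaller: each Eigen matrix is cast to the higher-precision `MPType` before calling `determinant()`, and the scalar result is cast back to `T`. Roughly the numpy equivalent (an illustrative sketch, not the kernel's actual code):

```python
import numpy as np

def batched_det_fp16(x):
    """Numpy analogue of DeterminantFunctor's mixed-precision loop."""
    # Compute in float32 (the MPType for float16), then cast the result back.
    return np.linalg.det(x.astype(np.float32)).astype(np.float16)
```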
@@ -50,6 +50,14 @@ class TestDeterminantOpCase1(TestDeterminantOp):
         self.target = np.linalg.det(self.case)
 
 
+class TestDeterminantOpCase1FP16(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(10, 10).astype(np.float16)
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case.astype(np.float32))
+
+
 class TestDeterminantOpCase2(TestDeterminantOp):
     def init_data(self):
         np.random.seed(0)
@@ -59,6 +67,17 @@ class TestDeterminantOpCase2(TestDeterminantOp):
         self.target = np.linalg.det(self.case)
 
 
+class TestDeterminantOpCase2FP16(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        # not invertible matrix
+        self.case = np.ones([4, 2, 4, 4]).astype(np.float16)
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case.astype(np.float32)).astype(
+            np.float16
+        )
+
+
 class TestDeterminantAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
......
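Both FP16 test cases compute their reference value in float32 because numpy's LAPACK-backed `np.linalg.det` does not accept float16 input; the expected result is produced at higher precision and, where needed, cast back down for comparison. For example:

```python
import numpy as np

case = np.ones([4, 2, 4, 4], dtype=np.float16)
# np.linalg.det(case) would raise TypeError for a float16 array,
# so the reference is computed in float32 and cast back afterwards.
target = np.linalg.det(case.astype(np.float32)).astype(np.float16)
```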
@@ -1809,7 +1809,7 @@ def det(x, name=None):
     if in_dygraph_mode():
         return _C_ops.det(x)
     else:
-        check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
+        check_dtype(x.dtype, 'Input', ['float16', 'float32', 'float64'], 'det')
         input_shape = list(x.shape)
         assert len(input_shape) >= 2, (
......
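With the Python-side dtype check relaxed, float16 inputs can reach the newly registered GPU kernels. A hedged usage sketch (assumes a CUDA build of Paddle that contains this commit; the float16 kernels are registered for GPU only):

```python
import numpy as np
import paddle

paddle.set_device('gpu')  # float16 determinant is GPU-only
x = paddle.to_tensor(np.random.rand(4, 3, 3).astype(np.float16))
d = paddle.linalg.det(x)  # computes in float32 internally, returns float16
print(d.dtype)            # paddle.float16
```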