Unverified commit 2a705b74, authored by denglianbin, committed by GitHub

[Hackathon No.48] Add float16 data type support for the Paddle determinant operator (#53286)

Parent 9127cc3c
@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(determinant_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::DeterminantGradKernel,
+                   phi::dtype::float16,
                    float,
                    double) {}
@@ -17,5 +17,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    determinant, GPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
+PD_REGISTER_KERNEL(determinant,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DeterminantKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
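With the float16 registration in place, the GPU forward kernel can be exercised directly. A minimal usage sketch, assuming a CUDA build of Paddle that contains this patch (the tensor shape is illustrative):

```python
import numpy as np
import paddle

# float16 determinant is only registered for the GPU kernel.
paddle.set_device("gpu")

x = paddle.to_tensor(np.random.rand(4, 4).astype(np.float16))
out = paddle.linalg.det(x)  # dispatches to phi::DeterminantKernel<float16>
print(out.dtype)            # paddle.float16
```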
@@ -15,8 +15,10 @@
 #pragma once
 
 #include "glog/logging.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/determinant_grad_kernel.h"
 #include "paddle/phi/kernels/elementwise_multiply_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
@@ -26,7 +28,6 @@
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
 #include "paddle/phi/kernels/funcs/unsqueeze.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace phi {
 namespace detail {
@@ -113,6 +114,11 @@ void DeterminantGradKernel(const Context& dev_ctx,
     return;
   }
 
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  auto origin_dt = std::is_same<phi::dtype::float16, T>::value
+                       ? DataType::FLOAT16
+                       : DataType::BFLOAT16;
+
   // The matrix is invertible
   // let |A| = Determinant(A)
   // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@@ -123,16 +129,22 @@ void DeterminantGradKernel(const Context& dev_ctx,
   DenseTensor inverse_A;
   // A must be square matrices!
   inverse_A.Resize(x.dims());
-  dev_ctx.template Alloc<T>(&inverse_A);
+  dev_ctx.template Alloc<MPType>(&inverse_A);
 
-  phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
-  mat_inv(dev_ctx, x, &inverse_A);
+  phi::funcs::MatrixInverseFunctor<Context, MPType> mat_inv;
+  if (!std::is_same<MPType, T>::value) {
+    mat_inv(dev_ctx,
+            phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32),
+            &inverse_A);
+  } else {
+    mat_inv(dev_ctx, x, &inverse_A);
+  }
 
   VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
 
   // Second: inverse(A).transpose(-2, -1)
   DenseTensor transpose_inverse_A =
-      phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
+      phi::TransposeLast2Dim<MPType>(dev_ctx, inverse_A);
 
   VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
           << transpose_inverse_A.dims();
@@ -147,7 +159,15 @@ void DeterminantGradKernel(const Context& dev_ctx,
   VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
 
   // Finally: unsqueeze(dA * |A|) * inverse(A)
-  auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
+  DenseTensor res;
+  if (!std::is_same<MPType, T>::value) {
+    res = phi::Multiply<T>(
+        dev_ctx,
+        unsqueeze2,
+        phi::Cast<MPType, Context>(dev_ctx, transpose_inverse_A, origin_dt));
+  } else {
+    res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
+  }
 
   VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
......
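The grad kernel above implements d|A|/dA = |A| * inverse(A)^T (see the Giles reference in the comments), doing the inverse and transpose in float32 when T is float16, then casting back to the original dtype for the final multiply. A NumPy sketch of the same mixed-precision flow (the function name and shapes are illustrative):

```python
import numpy as np

def det_grad_fp16(x, out, grad_out):
    """Sketch of DeterminantGradKernel's float16 path:
    grad(A) = unsqueeze(grad|A| * |A|) * inverse(A)^T."""
    inv = np.linalg.inv(x.astype(np.float32))   # MatrixInverseFunctor in MPType
    inv_t = inv.swapaxes(-2, -1)                # TransposeLast2Dim
    coef = (grad_out * out)[..., None, None]    # unsqueeze(dA * |A|) twice
    return coef * inv_t.astype(np.float16)      # Cast back to origin_dt, Multiply<T>
```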
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "glog/logging.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -31,6 +32,13 @@ namespace detail {
 template <typename T>
 class EigenMatrix {};
 
+template <>
+class EigenMatrix<phi::dtype::float16> {
+ public:
+  using MatrixType =
+      Eigen::Matrix<phi::dtype::float16, Eigen::Dynamic, Eigen::Dynamic>;
+};
+
 template <>
 class EigenMatrix<float> {
  public:
@@ -74,6 +82,7 @@ struct DeterminantFunctor {
     std::vector<T> input_vec;
     std::vector<T> output_vec;
     phi::TensorToVector(input, dev_ctx, &input_vec);
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
     for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
@@ -85,7 +94,8 @@ struct DeterminantFunctor {
           matrix(i, j) = sub_vec[rank * i + j];
         }
       }
-      output_vec.push_back(matrix.determinant());
+      output_vec.push_back(
+          static_cast<T>(matrix.template cast<MPType>().determinant()));
     }
     phi::TensorFromVector(output_vec, dev_ctx, output);
   }
......
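For the functor above, the key change is that the Eigen determinant is now accumulated in MPType (float32 for a float16 input) and only the result is cast back to T. The same batched loop in NumPy terms (function and variable names are illustrative):

```python
import numpy as np

def determinant_batched(input_vec, rank, batch_count):
    """Sketch of DeterminantFunctor: slice out each rank x rank matrix,
    take the determinant in float32, then cast the result back to fp16."""
    mats = np.asarray(input_vec, dtype=np.float16).reshape(batch_count, rank, rank)
    dets = [np.linalg.det(m.astype(np.float32)) for m in mats]  # cast<MPType>().determinant()
    return np.asarray(dets, dtype=np.float16)                   # static_cast<T>(...)
```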
@@ -50,6 +50,14 @@ class TestDeterminantOpCase1(TestDeterminantOp):
         self.target = np.linalg.det(self.case)
 
 
+class TestDeterminantOpCase1FP16(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(10, 10).astype(np.float16)
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case.astype(np.float32))
+
+
 class TestDeterminantOpCase2(TestDeterminantOp):
     def init_data(self):
         np.random.seed(0)
@@ -59,6 +67,17 @@ class TestDeterminantOpCase2(TestDeterminantOp):
         self.target = np.linalg.det(self.case)
 
 
+class TestDeterminantOpCase2FP16(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        # not invertible matrix
+        self.case = np.ones([4, 2, 4, 4]).astype(np.float16)
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case.astype(np.float32)).astype(
+            np.float16
+        )
+
+
 class TestDeterminantAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
......
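Both FP16 test cases build the reference with np.linalg.det on a float32 copy, since NumPy has no float16 LAPACK path. A standalone check along the same lines, assuming a CUDA build (the tolerances are illustrative, not the ones used by the OpTest framework):

```python
import numpy as np
import paddle

np.random.seed(0)
case = np.random.rand(10, 10).astype(np.float16)
ref = np.linalg.det(case.astype(np.float32))  # fp32 reference, as in the tests

paddle.set_device("gpu")
got = paddle.linalg.det(paddle.to_tensor(case)).numpy()
np.testing.assert_allclose(got, ref, rtol=1e-3, atol=1e-3)
```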
@@ -1809,7 +1809,7 @@ def det(x, name=None):
     if in_dygraph_mode():
         return _C_ops.det(x)
     else:
-        check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
+        check_dtype(x.dtype, 'Input', ['float16', 'float32', 'float64'], 'det')
 
         input_shape = list(x.shape)
         assert len(input_shape) >= 2, (
......
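In the static-graph branch, check_dtype now accepts 'float16', so a float16 placeholder passes validation at graph-construction time. A minimal sketch (variable names are illustrative):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[3, 3], dtype="float16")
    y = paddle.linalg.det(x)  # check_dtype(['float16', 'float32', 'float64']) passes
```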