From 80dc81c5813e3ce2212839c019969da488557db0 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Wed, 8 Feb 2023 14:53:53 +0800 Subject: [PATCH] [Tensor Support unsigned] Tensor::data() supports unsigned int and bfloat16 (#50257) * support unsigned int and bfloat16 * update unit test * update DenseTensor datatype * unsupport more datatype of mutable_data(Place) * fix unittest --- paddle/phi/api/lib/tensor.cc | 78 ++++++++++++++----------- paddle/phi/core/dense_tensor_impl.cc | 7 ++- paddle/phi/tests/api/test_phi_tensor.cc | 14 +++-- 3 files changed, 58 insertions(+), 41 deletions(-) diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index dfc5c427df7..8985383fd2c 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -184,20 +184,25 @@ T *Tensor::mutable_data() { return nullptr; } -template PADDLE_API float *Tensor::mutable_data(); -template PADDLE_API double *Tensor::mutable_data(); -template PADDLE_API int64_t *Tensor::mutable_data(); -template PADDLE_API int32_t *Tensor::mutable_data(); -template PADDLE_API uint8_t *Tensor::mutable_data(); +template PADDLE_API bool *Tensor::mutable_data(); template PADDLE_API int8_t *Tensor::mutable_data(); +template PADDLE_API uint8_t *Tensor::mutable_data(); template PADDLE_API int16_t *Tensor::mutable_data(); -template PADDLE_API bool *Tensor::mutable_data(); +template PADDLE_API uint16_t *Tensor::mutable_data(); +template PADDLE_API int32_t *Tensor::mutable_data(); +template PADDLE_API uint32_t *Tensor::mutable_data(); +template PADDLE_API int64_t *Tensor::mutable_data(); +template PADDLE_API uint64_t *Tensor::mutable_data(); +template PADDLE_API phi::dtype::bfloat16 * +Tensor::mutable_data(); +template PADDLE_API phi::dtype::float16 * +Tensor::mutable_data(); +template PADDLE_API float *Tensor::mutable_data(); +template PADDLE_API double *Tensor::mutable_data(); template PADDLE_API phi::dtype::complex *Tensor::mutable_data>(); template PADDLE_API phi::dtype::complex *Tensor::mutable_data>(); -template PADDLE_API phi::dtype::float16 * -Tensor::mutable_data(); template T *Tensor::mutable_data(const Place &place) { @@ -217,20 +222,20 @@ T *Tensor::mutable_data(const Place &place) { return nullptr; } -template PADDLE_API float *Tensor::mutable_data(const Place &place); -template PADDLE_API double *Tensor::mutable_data(const Place &place); -template PADDLE_API int64_t *Tensor::mutable_data(const Place &place); -template PADDLE_API int32_t *Tensor::mutable_data(const Place &place); -template PADDLE_API uint8_t *Tensor::mutable_data(const Place &place); +template PADDLE_API bool *Tensor::mutable_data(const Place &place); template PADDLE_API int8_t *Tensor::mutable_data(const Place &place); +template PADDLE_API uint8_t *Tensor::mutable_data(const Place &place); template PADDLE_API int16_t *Tensor::mutable_data(const Place &place); -template PADDLE_API bool *Tensor::mutable_data(const Place &place); +template PADDLE_API int32_t *Tensor::mutable_data(const Place &place); +template PADDLE_API int64_t *Tensor::mutable_data(const Place &place); +template PADDLE_API phi::dtype::float16 * +Tensor::mutable_data(const Place &place); +template PADDLE_API float *Tensor::mutable_data(const Place &place); +template PADDLE_API double *Tensor::mutable_data(const Place &place); template PADDLE_API phi::dtype::complex *Tensor::mutable_data>(const Place &place); template PADDLE_API phi::dtype::complex *Tensor::mutable_data>(const Place &place); -template PADDLE_API phi::dtype::float16 * -Tensor::mutable_data(const Place &place); template const T *Tensor::data() const { @@ -242,22 +247,25 @@ const T *Tensor::data() const { return nullptr; } -template PADDLE_API const float *Tensor::data() const; -template PADDLE_API const double *Tensor::data() const; -template PADDLE_API const int64_t *Tensor::data() const; -template PADDLE_API const int32_t *Tensor::data() const; -template PADDLE_API const uint8_t *Tensor::data() const; +template PADDLE_API const bool *Tensor::data() const; template PADDLE_API const int8_t *Tensor::data() const; +template PADDLE_API const uint8_t *Tensor::data() const; template PADDLE_API const int16_t *Tensor::data() const; -template PADDLE_API const bool *Tensor::data() const; +template PADDLE_API const uint16_t *Tensor::data() const; +template PADDLE_API const int32_t *Tensor::data() const; +template PADDLE_API const uint32_t *Tensor::data() const; +template PADDLE_API const int64_t *Tensor::data() const; +template PADDLE_API const uint64_t *Tensor::data() const; +template PADDLE_API const phi::dtype::bfloat16 * +Tensor::data() const; +template PADDLE_API const phi::dtype::float16 * +Tensor::data() const; +template PADDLE_API const float *Tensor::data() const; +template PADDLE_API const double *Tensor::data() const; template PADDLE_API const phi::dtype::complex *Tensor::data>() const; template PADDLE_API const phi::dtype::complex *Tensor::data>() const; -template PADDLE_API const phi::dtype::float16 * -Tensor::data() const; -template PADDLE_API const phi::dtype::bfloat16 * -Tensor::data() const; template T *Tensor::data() { @@ -271,19 +279,23 @@ T *Tensor::data() { return nullptr; } -template PADDLE_API float *Tensor::data(); -template PADDLE_API double *Tensor::data(); -template PADDLE_API int64_t *Tensor::data(); -template PADDLE_API int32_t *Tensor::data(); -template PADDLE_API uint8_t *Tensor::data(); +template PADDLE_API bool *Tensor::data(); template PADDLE_API int8_t *Tensor::data(); +template PADDLE_API uint8_t *Tensor::data(); template PADDLE_API int16_t *Tensor::data(); -template PADDLE_API bool *Tensor::data(); +template PADDLE_API uint16_t *Tensor::data(); +template PADDLE_API int32_t *Tensor::data(); +template PADDLE_API uint32_t *Tensor::data(); +template PADDLE_API int64_t *Tensor::data(); +template PADDLE_API uint64_t *Tensor::data(); +template PADDLE_API phi::dtype::bfloat16 *Tensor::data(); +template PADDLE_API phi::dtype::float16 *Tensor::data(); +template PADDLE_API float *Tensor::data(); +template PADDLE_API double *Tensor::data(); template PADDLE_API phi::dtype::complex *Tensor::data>(); template PADDLE_API phi::dtype::complex *Tensor::data>(); -template PADDLE_API phi::dtype::float16 *Tensor::data(); // TODO(chenweihang): replace slice impl by API Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const { diff --git a/paddle/phi/core/dense_tensor_impl.cc b/paddle/phi/core/dense_tensor_impl.cc index 39eb608f095..bbe6ef1e773 100644 --- a/paddle/phi/core/dense_tensor_impl.cc +++ b/paddle/phi/core/dense_tensor_impl.cc @@ -190,12 +190,15 @@ LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(bool) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int8_t) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint8_t) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int16_t) +LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint16_t) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int32_t) +LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint32_t) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int64_t) -LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(float) -LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(double) +LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint64_t) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::bfloat16) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::float16) +LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(float) +LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(double) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::complex) LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::complex) diff --git a/paddle/phi/tests/api/test_phi_tensor.cc b/paddle/phi/tests/api/test_phi_tensor.cc index 049aa1c355a..bea36bb0eaf 100644 --- a/paddle/phi/tests/api/test_phi_tensor.cc +++ b/paddle/phi/tests/api/test_phi_tensor.cc @@ -180,16 +180,18 @@ void GroupTestCast() { } void GroupTestDtype() { - CHECK(TestDtype() == paddle::DataType::FLOAT32); - CHECK(TestDtype() == paddle::DataType::FLOAT64); - CHECK(TestDtype() == paddle::DataType::INT32); - CHECK(TestDtype() == paddle::DataType::INT64); - CHECK(TestDtype() == paddle::DataType::INT16); + CHECK(TestDtype() == paddle::DataType::BOOL); CHECK(TestDtype() == paddle::DataType::INT8); CHECK(TestDtype() == paddle::DataType::UINT8); + CHECK(TestDtype() == paddle::DataType::INT16); + CHECK(TestDtype() == paddle::DataType::INT32); + CHECK(TestDtype() == paddle::DataType::INT32); + CHECK(TestDtype() == paddle::DataType::INT64); + CHECK(TestDtype() == paddle::DataType::FLOAT16); + CHECK(TestDtype() == paddle::DataType::FLOAT32); + CHECK(TestDtype() == paddle::DataType::FLOAT64); CHECK(TestDtype() == paddle::DataType::COMPLEX64); CHECK(TestDtype() == paddle::DataType::COMPLEX128); - CHECK(TestDtype() == paddle::DataType::FLOAT16); } void TestInitilized() { -- GitLab