未验证 提交 80dc81c5 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Support unsigned] Tensor::data() supports unsigned int and bfloat16 (#50257)

* support unsigned int and bfloat16

* update unit test

* update DenseTensor datatype

* unsupport more datatype of mutable_data(Place)

* fix unittest
上级 ffa2ec19
......@@ -184,20 +184,25 @@ T *Tensor::mutable_data() {
return nullptr;
}
template PADDLE_API float *Tensor::mutable_data<float>();
template PADDLE_API double *Tensor::mutable_data<double>();
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>();
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>();
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
template PADDLE_API bool *Tensor::mutable_data<bool>();
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>();
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>();
template PADDLE_API bool *Tensor::mutable_data<bool>();
template PADDLE_API uint16_t *Tensor::mutable_data<uint16_t>();
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>();
template PADDLE_API uint32_t *Tensor::mutable_data<uint32_t>();
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>();
template PADDLE_API uint64_t *Tensor::mutable_data<uint64_t>();
template PADDLE_API phi::dtype::bfloat16 *
Tensor::mutable_data<phi::dtype::bfloat16>();
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>();
template PADDLE_API float *Tensor::mutable_data<float>();
template PADDLE_API double *Tensor::mutable_data<double>();
template PADDLE_API phi::dtype::complex<float>
*Tensor::mutable_data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
*Tensor::mutable_data<phi::dtype::complex<double>>();
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>();
template <typename T>
T *Tensor::mutable_data(const Place &place) {
......@@ -217,20 +222,20 @@ T *Tensor::mutable_data(const Place &place) {
return nullptr;
}
template PADDLE_API float *Tensor::mutable_data<float>(const Place &place);
template PADDLE_API double *Tensor::mutable_data<double>(const Place &place);
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>(const Place &place);
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>(const Place &place);
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>(const Place &place);
template PADDLE_API bool *Tensor::mutable_data<bool>(const Place &place);
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>(const Place &place);
template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>(const Place &place);
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>(const Place &place);
template PADDLE_API bool *Tensor::mutable_data<bool>(const Place &place);
template PADDLE_API int32_t *Tensor::mutable_data<int32_t>(const Place &place);
template PADDLE_API int64_t *Tensor::mutable_data<int64_t>(const Place &place);
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>(const Place &place);
template PADDLE_API float *Tensor::mutable_data<float>(const Place &place);
template PADDLE_API double *Tensor::mutable_data<double>(const Place &place);
template PADDLE_API phi::dtype::complex<float>
*Tensor::mutable_data<phi::dtype::complex<float>>(const Place &place);
template PADDLE_API phi::dtype::complex<double>
*Tensor::mutable_data<phi::dtype::complex<double>>(const Place &place);
template PADDLE_API phi::dtype::float16 *
Tensor::mutable_data<phi::dtype::float16>(const Place &place);
template <typename T>
const T *Tensor::data() const {
......@@ -242,22 +247,25 @@ const T *Tensor::data() const {
return nullptr;
}
template PADDLE_API const float *Tensor::data<float>() const;
template PADDLE_API const double *Tensor::data<double>() const;
template PADDLE_API const int64_t *Tensor::data<int64_t>() const;
template PADDLE_API const int32_t *Tensor::data<int32_t>() const;
template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
template PADDLE_API const bool *Tensor::data<bool>() const;
template PADDLE_API const int8_t *Tensor::data<int8_t>() const;
template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
template PADDLE_API const int16_t *Tensor::data<int16_t>() const;
template PADDLE_API const bool *Tensor::data<bool>() const;
template PADDLE_API const uint16_t *Tensor::data<uint16_t>() const;
template PADDLE_API const int32_t *Tensor::data<int32_t>() const;
template PADDLE_API const uint32_t *Tensor::data<uint32_t>() const;
template PADDLE_API const int64_t *Tensor::data<int64_t>() const;
template PADDLE_API const uint64_t *Tensor::data<uint64_t>() const;
template PADDLE_API const phi::dtype::bfloat16 *
Tensor::data<phi::dtype::bfloat16>() const;
template PADDLE_API const phi::dtype::float16 *
Tensor::data<phi::dtype::float16>() const;
template PADDLE_API const float *Tensor::data<float>() const;
template PADDLE_API const double *Tensor::data<double>() const;
template PADDLE_API const phi::dtype::complex<float>
*Tensor::data<phi::dtype::complex<float>>() const;
template PADDLE_API const phi::dtype::complex<double>
*Tensor::data<phi::dtype::complex<double>>() const;
template PADDLE_API const phi::dtype::float16 *
Tensor::data<phi::dtype::float16>() const;
template PADDLE_API const phi::dtype::bfloat16 *
Tensor::data<phi::dtype::bfloat16>() const;
template <typename T>
T *Tensor::data() {
......@@ -271,19 +279,23 @@ T *Tensor::data() {
return nullptr;
}
template PADDLE_API float *Tensor::data<float>();
template PADDLE_API double *Tensor::data<double>();
template PADDLE_API int64_t *Tensor::data<int64_t>();
template PADDLE_API int32_t *Tensor::data<int32_t>();
template PADDLE_API uint8_t *Tensor::data<uint8_t>();
template PADDLE_API bool *Tensor::data<bool>();
template PADDLE_API int8_t *Tensor::data<int8_t>();
template PADDLE_API uint8_t *Tensor::data<uint8_t>();
template PADDLE_API int16_t *Tensor::data<int16_t>();
template PADDLE_API bool *Tensor::data<bool>();
template PADDLE_API uint16_t *Tensor::data<uint16_t>();
template PADDLE_API int32_t *Tensor::data<int32_t>();
template PADDLE_API uint32_t *Tensor::data<uint32_t>();
template PADDLE_API int64_t *Tensor::data<int64_t>();
template PADDLE_API uint64_t *Tensor::data<uint64_t>();
template PADDLE_API phi::dtype::bfloat16 *Tensor::data<phi::dtype::bfloat16>();
template PADDLE_API phi::dtype::float16 *Tensor::data<phi::dtype::float16>();
template PADDLE_API float *Tensor::data<float>();
template PADDLE_API double *Tensor::data<double>();
template PADDLE_API phi::dtype::complex<float>
*Tensor::data<phi::dtype::complex<float>>();
template PADDLE_API phi::dtype::complex<double>
*Tensor::data<phi::dtype::complex<double>>();
template PADDLE_API phi::dtype::float16 *Tensor::data<phi::dtype::float16>();
// TODO(chenweihang): replace slice impl by API
Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
......
......@@ -190,12 +190,15 @@ LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(bool)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int8_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint8_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int16_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint16_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int32_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint32_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int64_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(float)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(double)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(uint64_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::bfloat16)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::float16)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(float)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(double)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::complex<float>)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::phi::dtype::complex<double>)
......
......@@ -180,16 +180,18 @@ void GroupTestCast() {
}
void GroupTestDtype() {
CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
CHECK(TestDtype<bool>() == paddle::DataType::BOOL);
CHECK(TestDtype<int8_t>() == paddle::DataType::INT8);
CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int32_t>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<paddle::float16>() == paddle::DataType::FLOAT16);
CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
CHECK(TestDtype<paddle::complex64>() == paddle::DataType::COMPLEX64);
CHECK(TestDtype<paddle::complex128>() == paddle::DataType::COMPLEX128);
CHECK(TestDtype<paddle::float16>() == paddle::DataType::FLOAT16);
}
void TestInitilized() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册