Unverified — commit c05cd7ed authored by Chen Weihang, committed by GitHub

[PTen] Clean useless header in pten core (#39560)

* clean useless header in pten core

* fix compile failure

* fix cmake target

* fix typo

* resolve conflict
Parent 9f99b591
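The changes below are almost entirely a mechanical namespace migration. As a reading aid, a rough summary of the renames this commit applies throughout (the target aliases live in pten's own common headers; exact header paths are not shown in this diff):

```cpp
// Reading aid only; a hypothetical side-by-side of the renames in this diff.
// paddle::platform::float16              ->  pten::dtype::float16
// paddle::platform::bfloat16             ->  pten::dtype::bfloat16
// paddle::platform::complex<T>           ->  pten::dtype::complex<T>
// paddle::experimental::complex64/128    ->  pten::dtype::complex<float>/<double>
// paddle::platform::errors::*            ->  pten::errors::*
// paddle::platform::Place                ->  pten::Place (spelled Place inside pten)
// paddle::platform::is_gpu_place(place)  ->  place.GetType() == pten::AllocationType::GPU
```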
@@ -176,12 +176,12 @@ template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
 template PADDLE_API int8_t *Tensor::mutable_data<int8_t>();
 template PADDLE_API int16_t *Tensor::mutable_data<int16_t>();
 template PADDLE_API bool *Tensor::mutable_data<bool>();
-template PADDLE_API paddle::platform::complex<float>
-    *Tensor::mutable_data<paddle::platform::complex<float>>();
-template PADDLE_API paddle::platform::complex<double>
-    *Tensor::mutable_data<paddle::platform::complex<double>>();
-template PADDLE_API paddle::platform::float16 *
-Tensor::mutable_data<paddle::platform::float16>();
+template PADDLE_API pten::dtype::complex<float>
+    *Tensor::mutable_data<pten::dtype::complex<float>>();
+template PADDLE_API pten::dtype::complex<double>
+    *Tensor::mutable_data<pten::dtype::complex<double>>();
+template PADDLE_API pten::dtype::float16 *
+Tensor::mutable_data<pten::dtype::float16>();
 template <typename T>
 T *Tensor::mutable_data(const PlaceType &place) {
@@ -214,12 +214,12 @@ template PADDLE_API int8_t *Tensor::mutable_data<int8_t>(
 template PADDLE_API int16_t *Tensor::mutable_data<int16_t>(
     const PlaceType &place);
 template PADDLE_API bool *Tensor::mutable_data<bool>(const PlaceType &place);
-template PADDLE_API paddle::platform::complex<float> *
-Tensor::mutable_data<paddle::platform::complex<float>>(const PlaceType &place);
-template PADDLE_API paddle::platform::complex<double> *
-Tensor::mutable_data<paddle::platform::complex<double>>(const PlaceType &place);
-template PADDLE_API paddle::platform::float16 *
-Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
+template PADDLE_API pten::dtype::complex<float>
+    *Tensor::mutable_data<pten::dtype::complex<float>>(const PlaceType &place);
+template PADDLE_API pten::dtype::complex<double>
+    *Tensor::mutable_data<pten::dtype::complex<double>>(const PlaceType &place);
+template PADDLE_API pten::dtype::float16 *
+Tensor::mutable_data<pten::dtype::float16>(const PlaceType &place);
 template <typename T>
 const T *Tensor::data() const {
@@ -241,14 +241,14 @@ template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
 template PADDLE_API const int8_t *Tensor::data<int8_t>() const;
 template PADDLE_API const int16_t *Tensor::data<int16_t>() const;
 template PADDLE_API const bool *Tensor::data<bool>() const;
-template PADDLE_API const paddle::platform::complex<float>
-    *Tensor::data<paddle::platform::complex<float>>() const;
-template PADDLE_API const paddle::platform::complex<double>
-    *Tensor::data<paddle::platform::complex<double>>() const;
-template PADDLE_API const paddle::platform::float16 *
-Tensor::data<paddle::platform::float16>() const;
-template PADDLE_API const paddle::platform::bfloat16 *
-Tensor::data<paddle::platform::bfloat16>() const;
+template PADDLE_API const pten::dtype::complex<float>
+    *Tensor::data<pten::dtype::complex<float>>() const;
+template PADDLE_API const pten::dtype::complex<double>
+    *Tensor::data<pten::dtype::complex<double>>() const;
+template PADDLE_API const pten::dtype::float16 *
+Tensor::data<pten::dtype::float16>() const;
+template PADDLE_API const pten::dtype::bfloat16 *
+Tensor::data<pten::dtype::bfloat16>() const;
 template <typename T>
 T *Tensor::data() {
@@ -267,12 +267,11 @@ template PADDLE_API uint8_t *Tensor::data<uint8_t>();
 template PADDLE_API int8_t *Tensor::data<int8_t>();
 template PADDLE_API int16_t *Tensor::data<int16_t>();
 template PADDLE_API bool *Tensor::data<bool>();
-template PADDLE_API paddle::platform::complex<float>
-    *Tensor::data<paddle::platform::complex<float>>();
-template PADDLE_API paddle::platform::complex<double>
-    *Tensor::data<paddle::platform::complex<double>>();
-template PADDLE_API paddle::platform::float16 *
-Tensor::data<paddle::platform::float16>();
+template PADDLE_API pten::dtype::complex<float>
+    *Tensor::data<pten::dtype::complex<float>>();
+template PADDLE_API pten::dtype::complex<double>
+    *Tensor::data<pten::dtype::complex<double>>();
+template PADDLE_API pten::dtype::float16 *Tensor::data<pten::dtype::float16>();
 // TODO(chenweihang): replace slice impl by API
 Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
@@ -328,12 +327,12 @@ template PADDLE_API Tensor
 Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
 template PADDLE_API Tensor
 Tensor::copy_to<bool>(const PlaceType &target_place) const;
-template PADDLE_API Tensor Tensor::copy_to<paddle::platform::complex<float>>(
-    const PlaceType &target_place) const;
-template PADDLE_API Tensor Tensor::copy_to<paddle::platform::complex<double>>(
-    const PlaceType &target_place) const;
-template PADDLE_API Tensor
-Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
+template PADDLE_API Tensor Tensor::copy_to<pten::dtype::complex<float>>(
+    const PlaceType &target_place) const;
+template PADDLE_API Tensor Tensor::copy_to<pten::dtype::complex<double>>(
+    const PlaceType &target_place) const;
+template PADDLE_API Tensor
+Tensor::copy_to<pten::dtype::float16>(const PlaceType &target_place) const;
 Tensor Tensor::copy_to(Backend backend, bool blocking) const {
   return experimental::copy_to(*this, backend, blocking);
...
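Note the pattern in the hunks above: `Tensor::data<T>` and `Tensor::mutable_data<T>` are member templates whose definitions live in this .cc file, so the file must explicitly instantiate them for every supported element type and export the instantiations with `PADDLE_API` — which is why the dtype rename has to touch every instantiation line. A minimal sketch of that technique (names here are hypothetical, not Paddle's):

```cpp
#include <cstdint>

#define MY_API  // stand-in for an export macro such as PADDLE_API

struct Buffer {
  template <typename T>
  T* data();  // declared here, defined out of line below
  void* raw_ = nullptr;
};

template <typename T>
T* Buffer::data() {
  return static_cast<T*>(raw_);
}

// Explicit instantiations: because the definition is not visible to other
// translation units, each supported T must be instantiated (and exported)
// here, exactly as tensor.cc does for float16, complex<float>, etc.
template MY_API float* Buffer::data<float>();
template MY_API int32_t* Buffer::data<int32_t>();
```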
@@ -20,11 +20,6 @@ limitations under the License. */
 #include "paddle/pten/common/place.h"
 #include "paddle/pten/core/tensor_meta.h"
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/framework/data_type.h"
-// TODO(chenweihang): this file may need to be removed
 namespace pten {
 std::string TransToPtenKernelName(const std::string& fluid_op_name);
...
@@ -202,12 +202,12 @@ DATA_MEMBER_FUNC_INSTANTIATION(int32_t);
 DATA_MEMBER_FUNC_INSTANTIATION(uint32_t);
 DATA_MEMBER_FUNC_INSTANTIATION(int64_t);
 DATA_MEMBER_FUNC_INSTANTIATION(uint64_t);
-DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::bfloat16);
-DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::float16);
+DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::bfloat16);
+DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::float16);
 DATA_MEMBER_FUNC_INSTANTIATION(float);
 DATA_MEMBER_FUNC_INSTANTIATION(double);
-DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64);
-DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128);
+DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<float>);
+DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<double>);
 #undef DATA_MEMBER_FUNC_INSTANTIATION
...
@@ -20,9 +20,6 @@ limitations under the License. */
 #include "paddle/pten/core/tensor_base.h"
 #include "paddle/pten/core/tensor_meta.h"
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/framework/data_type.h"
 /* @jim19930609: Move to MKLDNN_Tensor in the future
  */
 #ifdef PADDLE_WITH_MKLDNN
...
@@ -40,14 +40,14 @@ size_t DenseTensor::memory_size() const {
 }
 void DenseTensor::check_memory_size() const {
-  PADDLE_ENFORCE_NOT_NULL(holder_,
-                          paddle::platform::errors::PreconditionNotMet(
-                              "Tensor holds no memory. "
-                              "Call Tensor::mutable_data firstly."));
+  PADDLE_ENFORCE_NOT_NULL(
+      holder_,
+      pten::errors::PreconditionNotMet("Tensor holds no memory. "
+                                       "Call Tensor::mutable_data firstly."));
   PADDLE_ENFORCE_LE(
       numel() * SizeOf(dtype()),
       memory_size(),
-      paddle::platform::errors::PreconditionNotMet(
+      pten::errors::PreconditionNotMet(
           "Tensor's dimension is out of bound."
           "Tensor's dimension must be equal or less than the size of its "
          "memory."
@@ -56,10 +56,10 @@ void DenseTensor::check_memory_size() const {
           memory_size()));
 }
-const paddle::platform::Place& DenseTensor::place() const {
+const Place& DenseTensor::place() const {
   PADDLE_ENFORCE_NOT_NULL(
       holder_,
-      paddle::platform::errors::PreconditionNotMet(
+      pten::errors::PreconditionNotMet(
           "Tensor not initialized yet when DenseTensor::place() is called."));
   return holder_->place();
 }
@@ -82,7 +82,7 @@ void DenseTensor::ResetHolder(const std::shared_ptr<pten::Allocation>& holder) {
         numel() * static_cast<int64_t>(SizeOf(dtype())) +
             static_cast<int64_t>(meta_.offset),
         static_cast<int64_t>(holder->size()),
-        paddle::platform::errors::InvalidArgument(
+        pten::errors::InvalidArgument(
             "The size of Holder is not enough to store the Tensor."));
   }
   holder_ = holder;
@@ -99,14 +99,14 @@ void DenseTensor::set_type(paddle::experimental::DataType type) {
   meta_.dtype = type;
 }
-void* DenseTensor::mutable_data(const paddle::platform::Place& place,
+void* DenseTensor::mutable_data(const Place& place,
                                 paddle::experimental::DataType type,
                                 size_t requested_size) {
   set_type(type);
   PADDLE_ENFORCE_GE(
       numel(),
       0,
-      paddle::platform::errors::PreconditionNotMet(
+      pten::errors::PreconditionNotMet(
           "The Tensor's element number must be equal or greater than zero. "
           "The Tensor's shape is [",
           dims(),
@@ -127,19 +127,18 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
                                          meta_.offset);
 }
-void* DenseTensor::mutable_data(const paddle::platform::Place& place,
-                                size_t requested_size) {
+void* DenseTensor::mutable_data(const Place& place, size_t requested_size) {
   return mutable_data(place, type(), requested_size);
 }
-void* DenseTensor::mutable_data(const paddle::platform::Place& place,
+void* DenseTensor::mutable_data(const Place& place,
                                 paddle::experimental::DataType type,
                                 const pten::Stream& stream) {
   set_type(type);
   PADDLE_ENFORCE_GE(
       numel(),
       0,
-      paddle::platform::errors::PreconditionNotMet(
+      pten::errors::PreconditionNotMet(
          "The Tensor's element number must be equal or greater than zero. "
          "The Tensor's shape is [",
          dims(),
@@ -149,7 +148,7 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
   /* some versions of boost::variant don't have operator!= */
   if (holder_ == nullptr || !(holder_->place() == place) ||
       holder_->size() < size + meta_.offset ||
-      !(paddle::platform::is_gpu_place(place) &&
+      !(place.GetType() == pten::AllocationType::GPU &&
         paddle::memory::InSameStream(holder_, stream))) {
     holder_.reset();
     holder_ = paddle::memory::AllocShared(place, size, stream);
@@ -166,7 +165,7 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
  */
 template <typename T>
 inline T* DenseTensor::mutable_data(const DDim& dims,
-                                    const paddle::platform::Place& place,
+                                    const Place& place,
                                     size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
   meta_.dims = dims;
@@ -174,8 +173,7 @@ inline T* DenseTensor::mutable_data(const DDim& dims,
 }
 template <typename T>
-inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
-                                    size_t requested_size) {
+inline T* DenseTensor::mutable_data(const Place& place, size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
   return reinterpret_cast<T*>(
       mutable_data(place,
@@ -189,13 +187,11 @@ void DenseTensor::ShareBufferWith(const DenseTensor& tensor) {
   meta_.dtype = tensor.dtype();
 }
-#define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
-  template dtype* DenseTensor::mutable_data( \
-      const DDim& dims, \
-      const paddle::platform::Place& place, \
-      size_t requested_size); \
-  template dtype* DenseTensor::mutable_data( \
-      const paddle::platform::Place& place, size_t requested_size);
+#define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
+  template dtype* DenseTensor::mutable_data( \
+      const DDim& dims, const Place& place, size_t requested_size); \
+  template dtype* DenseTensor::mutable_data(const Place& place, \
+                                            size_t requested_size);
 LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(bool)
 LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int8_t)
@@ -205,10 +201,10 @@ LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int32_t)
 LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int64_t)
 LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(float)
 LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(double)
-LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::bfloat16)
-LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::float16)
-LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
-LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
+LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::bfloat16)
+LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::float16)
+LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<float>)
+LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<double>)
 #undef LEGACY_DATA_MEMBER_FUNC_INSTANTIATION
@@ -234,7 +230,7 @@ std::pair<size_t, size_t> DenseTensor::lod_element(size_t level,
   PADDLE_ENFORCE_LT(
       level,
       NumLevels(),
-      paddle::platform::errors::InvalidArgument(
+      pten::errors::InvalidArgument(
          "The input level of LoD is invalid, it should be less than LoD "
          "size. The input level is %zu, the LoD size is %zu.",
          level,
@@ -242,7 +238,7 @@ std::pair<size_t, size_t> DenseTensor::lod_element(size_t level,
   PADDLE_ENFORCE_LT(elem,
                     NumElements(level),
-                    paddle::platform::errors::InvalidArgument(
+                    pten::errors::InvalidArgument(
                        "The input element of LoD is invalid, it should be "
                        "less than the number of elements in its level."
                        "The input element is %zu, the number of elements in "
@@ -259,7 +255,7 @@ size_t DenseTensor::NumElements(size_t level) const {
   PADDLE_ENFORCE_LT(
       level,
       NumLevels(),
-      paddle::platform::errors::InvalidArgument(
+      pten::errors::InvalidArgument(
          "The input level of LoD is invalid, it should be less than LoD "
          "size. The input level is %zu, the LoD size is %zu.",
          level,
@@ -276,20 +272,20 @@ DenseTensor& DenseTensor::Resize(const DDim& dims) {
 DenseTensor DenseTensor::Slice(int64_t begin_idx, int64_t end_idx) const {
   check_memory_size();
-  PADDLE_ENFORCE_GE(begin_idx,
-                    0,
-                    paddle::platform::errors::OutOfRange(
-                        "The start row index must be greater than 0."
-                        "But received the start index is d%.",
-                        begin_idx));
-  PADDLE_ENFORCE_LE(end_idx,
-                    meta_.dims[0],
-                    paddle::platform::errors::OutOfRange(
-                        "The end row index is out of bound."));
+  PADDLE_ENFORCE_GE(
+      begin_idx,
+      0,
+      pten::errors::OutOfRange("The start row index must be greater than 0."
+                               "But received the start index is d%.",
+                               begin_idx));
+  PADDLE_ENFORCE_LE(
+      end_idx,
+      meta_.dims[0],
+      pten::errors::OutOfRange("The end row index is out of bound."));
   PADDLE_ENFORCE_LT(
       begin_idx,
       end_idx,
-      paddle::platform::errors::InvalidArgument(
+      pten::errors::InvalidArgument(
          "The start row index must be less than the end row index."
          "But received the start index = %d, the end index = %d.",
          begin_idx,
@@ -317,13 +313,13 @@ std::vector<DenseTensor> DenseTensor::Split(int64_t split_size,
   PADDLE_ENFORCE_GE(meta_.dims.size(),
                     0,
-                    paddle::platform::errors::OutOfRange(
+                    pten::errors::OutOfRange(
                         "split expects at least a 1-dimensional tensor"));
   PADDLE_ENFORCE_GE(
       split_size,
       0,
-      paddle::platform::errors::OutOfRange(
+      pten::errors::OutOfRange(
          "split expects split_size be non-negative, but got split_size is %d",
          split_size));
@@ -350,12 +346,12 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
   check_memory_size();
   PADDLE_ENFORCE_GE(meta_.dims.size(),
                     0,
-                    paddle::platform::errors::OutOfRange(
+                    pten::errors::OutOfRange(
                         "split expects at least a 1-dimensional tensor"));
   PADDLE_ENFORCE_GE(
       chunks,
       0,
-      paddle::platform::errors::OutOfRange(
+      pten::errors::OutOfRange(
          "chunks expects to be greater than 0, but got chunks is %d", chunks));
   int64_t numel_size = meta_.dims[axis];
@@ -376,7 +372,7 @@ DenseTensor& DenseTensor::ShareInplaceVersionCounterWith(
     const DenseTensor& src) {
   PADDLE_ENFORCE_NOT_NULL(
       inplace_version_counter_,
-      paddle::platform::errors::PreconditionNotMet(
+      pten::errors::PreconditionNotMet(
           "Tensor does not hold inplace_version_counter_."));
   inplace_version_counter_ = src.inplace_version_counter_;
...
@@ -233,7 +233,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
-  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(pten::dtype::float16);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
...
@@ -26,23 +26,23 @@ namespace pten {
 #define _PtenForEachDataTypeHelper_(callback, cpp_type, data_type) \
   callback(cpp_type, data_type);
 #define _PtenForEachDataType_(callback) \
   _PtenForEachDataTypeHelper_(callback, float, DataType::FLOAT32); \
   _PtenForEachDataTypeHelper_( \
-      callback, ::paddle::platform::float16, DataType::FLOAT16); \
+      callback, ::pten::dtype::float16, DataType::FLOAT16); \
   _PtenForEachDataTypeHelper_( \
-      callback, ::paddle::platform::bfloat16, DataType::BFLOAT16); \
+      callback, ::pten::dtype::bfloat16, DataType::BFLOAT16); \
   _PtenForEachDataTypeHelper_(callback, double, DataType::FLOAT64); \
   _PtenForEachDataTypeHelper_(callback, int, DataType::INT32); \
   _PtenForEachDataTypeHelper_(callback, int64_t, DataType::INT64); \
   _PtenForEachDataTypeHelper_(callback, bool, DataType::BOOL); \
   _PtenForEachDataTypeHelper_(callback, uint8_t, DataType::UINT8); \
   _PtenForEachDataTypeHelper_(callback, int16_t, DataType::INT16); \
   _PtenForEachDataTypeHelper_(callback, int8_t, DataType::INT8); \
   _PtenForEachDataTypeHelper_( \
-      callback, ::paddle::platform::complex<float>, DataType::COMPLEX64); \
+      callback, ::pten::dtype::complex<float>, DataType::COMPLEX64); \
   _PtenForEachDataTypeHelper_( \
-      callback, ::paddle::platform::complex<double>, DataType::COMPLEX128);
+      callback, ::pten::dtype::complex<double>, DataType::COMPLEX128);
 template <typename Visitor>
 inline void VisitDataType(pten::DataType type, Visitor visitor) {
...
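`_PtenForEachDataType_` is an X-macro: `VisitDataType` passes the per-type action in as `callback`, and the macro stamps it out once per supported type, so the type list is maintained in a single place. The same technique in a reduced, self-contained form (types and names below are illustrative, not pten's real list):

```cpp
#include <cstdio>

// Reduced sketch of the X-macro pattern behind _PtenForEachDataType_.
#define FOR_EACH_TYPE(callback)  \
  callback(float, "FLOAT32");    \
  callback(double, "FLOAT64");   \
  callback(int, "INT32");

#define PRINT_SIZE(cpp_type, name) \
  std::printf("%s -> %zu bytes\n", name, sizeof(cpp_type))

int main() {
  FOR_EACH_TYPE(PRINT_SIZE);  // expands to one printf per listed type
  return 0;
}
```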
@@ -15,14 +15,10 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function)
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
-set(MATH_KERNEL_DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel pten_transpose_cpu)
-if(WITH_GPU OR WITH_ROCM)
-  set(MATH_KERNEL_DEPS ${MATH_KERNEL_DEPS} pten_transpose_gpu)
-endif()
 # auto build kernel targets by cmake
-register_kernels(EXCLUDES math_kernel DEPS ${COMMON_KERNEL_DEPS})
-kernel_library(math_kernel DEPS ${MATH_KERNEL_DEPS})
-# pten sparse kernels
+register_kernels(DEPS ${COMMON_KERNEL_DEPS})
 add_subdirectory(sparse)
 copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
...
@@ -25,12 +25,12 @@ template <typename T, typename Context>
 void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
 // If T is complex
-template <typename T,
-          typename Context,
-          std::enable_if_t<
-              std::is_same<T, paddle::platform::complex<float>>::value ||
-                  std::is_same<T, paddle::platform::complex<double>>::value,
-              bool> = true>
+template <
+    typename T,
+    typename Context,
+    std::enable_if_t<std::is_same<T, pten::dtype::complex<float>>::value ||
+                         std::is_same<T, pten::dtype::complex<double>>::value,
+                     bool> = true>
 DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
   auto dense_out = pten::Empty<T, Context>(dev_ctx);
   MetaTensor meta_out(&dense_out);
@@ -40,12 +40,12 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
 }
 // If T is not complex
-template <typename T,
-          typename Context,
-          std::enable_if_t<
-              !std::is_same<T, paddle::platform::complex<float>>::value &&
-                  !std::is_same<T, paddle::platform::complex<double>>::value,
-              bool> = true>
+template <
+    typename T,
+    typename Context,
+    std::enable_if_t<!std::is_same<T, pten::dtype::complex<float>>::value &&
+                         !std::is_same<T, pten::dtype::complex<double>>::value,
+                     bool> = true>
 DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
   return x;
 }
...
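The two `Conj` overloads above use `std::enable_if_t` so that the kernel is only invoked for complex element types, while real types pass through unchanged. The same dispatch in miniature (a sketch with `std::complex` standing in for `pten::dtype::complex`):

```cpp
#include <complex>
#include <iostream>
#include <type_traits>

// Sketch of the enable_if dispatch used by pten::Conj. Exactly one overload
// is viable for any given T, so the calls below are unambiguous.
template <typename T,
          std::enable_if_t<std::is_same<T, std::complex<float>>::value ||
                               std::is_same<T, std::complex<double>>::value,
                           bool> = true>
T Conj(const T& x) {
  return std::conj(x);  // complex types: actually conjugate
}

template <typename T,
          std::enable_if_t<!std::is_same<T, std::complex<float>>::value &&
                               !std::is_same<T, std::complex<double>>::value,
                           bool> = true>
T Conj(const T& x) {
  return x;  // real types: no-op, mirroring the second pten overload
}

int main() {
  std::cout << Conj(std::complex<float>(1, 2)) << "\n";  // prints (1,-2)
  std::cout << Conj(3.5) << "\n";                        // prints 3.5
}
```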
@@ -69,9 +69,9 @@ PT_REGISTER_KERNEL(cast,
                    int16_t,
                    bool,
                    uint8_t,
-                   paddle::platform::float16,
-                   paddle::platform::bfloat16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {
+                   pten::dtype::float16,
+                   pten::dtype::bfloat16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
...
@@ -25,8 +25,8 @@ PT_REGISTER_KERNEL(conj,
                    CPU,
                    ALL_LAYOUT,
                    pten::ConjKernel,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>,
                    float,
                    double,
                    int,
...
@@ -120,6 +120,6 @@ PT_REGISTER_KERNEL(concat,
                    int64_t,
                    int,
                    uint8_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
...
@@ -28,5 +28,5 @@ PT_REGISTER_KERNEL(dot_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
...
@@ -46,8 +46,8 @@ void DotKernel(const Context& dev_ctx,
 }  // namespace pten
-using complex64 = ::paddle::platform::complex<float>;
-using complex128 = ::paddle::platform::complex<double>;
+using complex64 = ::pten::dtype::complex<float>;
+using complex128 = ::pten::dtype::complex<double>;
 PT_REGISTER_KERNEL(dot,
                    CPU,
...
@@ -134,8 +134,8 @@ PT_REGISTER_KERNEL(add_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(add_double_grad,
                    CPU,
@@ -145,8 +145,8 @@ PT_REGISTER_KERNEL(add_double_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(add_triple_grad,
                    CPU,
@@ -156,8 +156,8 @@ PT_REGISTER_KERNEL(add_triple_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(subtract_grad,
                    CPU,
@@ -167,8 +167,8 @@ PT_REGISTER_KERNEL(subtract_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(subtract_double_grad,
                    CPU,
@@ -178,5 +178,5 @@ PT_REGISTER_KERNEL(subtract_double_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
...
@@ -29,10 +29,10 @@ PT_REGISTER_KERNEL(full,
                    int,
                    int64_t,
                    bool,
-                   paddle::platform::float16,
-                   paddle::platform::bfloat16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::bfloat16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(full_like,
                    CPU,
@@ -43,4 +43,4 @@ PT_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
-                   paddle::platform::float16) {}
+                   pten::dtype::float16) {}
...
@@ -113,11 +113,11 @@ DEFINE_CPU_ELEMENTWISE_OP(Multiply)
 }  // namespace pten
-using complex64 = ::paddle::platform::complex<float>;
-using complex128 = ::paddle::platform::complex<double>;
+using complex64 = ::pten::dtype::complex<float>;
+using complex128 = ::pten::dtype::complex<double>;
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
-// using bfloat16 = ::paddle::platform::bfloat16;
+// using bfloat16 = ::pten::dtype::bfloat16;
 PT_REGISTER_KERNEL(add_raw,
                    CPU,
                    ALL_LAYOUT,
@@ -166,7 +166,7 @@ PT_REGISTER_KERNEL(sum_raw,
                    bool,
                    float,
                    double,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    int,
                    int64_t,
                    complex64,
...
@@ -25,8 +25,8 @@ PT_REGISTER_KERNEL(matmul_grad,
                    pten::MatmulGradKernel,
                    float,
                    double,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(matmul_double_grad,
                    CPU,
@@ -34,8 +34,8 @@ PT_REGISTER_KERNEL(matmul_double_grad,
                    pten::MatmulDoubleGradKernel,
                    float,
                    double,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(matmul_triple_grad,
                    CPU,
@@ -43,5 +43,5 @@ PT_REGISTER_KERNEL(matmul_triple_grad,
                    pten::MatmulTripleGradKernel,
                    float,
                    double,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
...
@@ -26,5 +26,5 @@ PT_REGISTER_KERNEL(matmul,
                    pten::MatmulKernel,
                    float,
                    double,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
...
@@ -23,7 +23,7 @@
 #include "paddle/pten/api/lib/utils/storage.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/kernels/funcs/eigen/common.h"
-#include "paddle/pten/kernels/funcs/transpose.h"
+#include "paddle/pten/kernels/funcs/math_function.h"
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/operators/eigen/eigen_function.h"
 namespace pten {
@@ -80,7 +80,7 @@ void ReduceFunctor(const DeviceContext& context,
 inline void GetShuffledDim(const DDim& src_dims,
                            DDim* dst_dims,
                            const std::vector<int64_t>& reduced_dims,
-                           std::vector<int64_t>* perm_axis) {
+                           std::vector<int>* perm_axis) {
   // check if it's a reduced dim
   std::vector<bool> src_dims_check(src_dims.size(), false);
   size_t src_size = src_dims.size();
@@ -115,13 +115,13 @@ void GetShuffledInput(const DeviceContext& dev_ctx,
                       pten::DenseTensor* shuffled_input,
                       const std::vector<int64_t>& dims) {
   DDim shuffled_dims(input.dims());
-  std::vector<int64_t> perm_axis(input.dims().size());
+  std::vector<int> perm_axis(input.dims().size());
   GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis);
   shuffled_input->ResizeAndAllocate(shuffled_dims);
   dev_ctx.template Alloc<OutT>(shuffled_input);
-  pten::math::TransposeNormal<DeviceContext, OutT> trans;
+  pten::funcs::TransposeNormal<DeviceContext, OutT> trans;
   trans(dev_ctx, input, shuffled_input, perm_axis);
 }
...
@@ -45,10 +45,10 @@ PT_REGISTER_KERNEL(empty,
                    int,
                    int64_t,
                    bool,
-                   paddle::platform::float16,
-                   paddle::platform::bfloat16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::bfloat16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(empty_like,
                    CPU,
@@ -61,10 +61,10 @@ PT_REGISTER_KERNEL(empty_like,
                    int,
                    int64_t,
                    bool,
-                   paddle::platform::float16,
-                   paddle::platform::bfloat16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::bfloat16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PT_REGISTER_KERNEL(empty,
@@ -78,9 +78,9 @@ PT_REGISTER_KERNEL(empty,
                    int,
                    int64_t,
                    bool,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 PT_REGISTER_KERNEL(empty_like,
                    GPU,
@@ -93,8 +93,8 @@ PT_REGISTER_KERNEL(empty_like,
                    int,
                    int64_t,
                    bool,
-                   paddle::platform::float16,
-                   paddle::platform::bfloat16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::bfloat16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
 #endif
...
@@ -49,7 +49,7 @@ PT_REGISTER_KERNEL(flatten_grad,
                    ALL_LAYOUT,
                    pten::FlattenGradKernel,
                    float,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    double,
                    uint8_t,
                    int8_t,
@@ -64,7 +64,7 @@ PT_REGISTER_KERNEL(flatten_grad,
                    ALL_LAYOUT,
                    pten::FlattenGradKernel,
                    float,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    int8_t,
                    int,
                    int64_t) {}
...
@@ -76,7 +76,7 @@ PT_REGISTER_KERNEL(flatten,
                    ALL_LAYOUT,
                    pten::FlattenKernel,
                    float,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    double,
                    uint8_t,
                    int8_t,
@@ -88,7 +88,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
                    ALL_LAYOUT,
                    pten::FlattenWithXShape,
                    float,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    double,
                    uint8_t,
                    int8_t,
@@ -102,7 +102,7 @@ PT_REGISTER_KERNEL(flatten,
                    ALL_LAYOUT,
                    pten::FlattenKernel,
                    float,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    int8_t,
                    int,
                    int64_t) {}
@@ -112,7 +112,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
                    ALL_LAYOUT,
                    pten::FlattenWithXShape,
                    float,
-                   paddle::platform::float16,
+                   pten::dtype::float16,
                    int8_t,
                    int,
                    int64_t) {}
...
 add_subdirectory(eigen)
-cc_library(pten_transpose_cpu SRCS transpose.cc DEPS dense_tensor pten_context)
-if(WITH_GPU)
-  nv_library(pten_transpose_gpu SRCS transpose.cu DEPS dense_tensor malloc pten_context)
-elseif(WITH_ROCM)
-  hip_library(pten_transpose_gpu SRCS transpose.cu DEPS dense_tensor malloc pten_context)
-endif()
 function(math_library TARGET)
   # math_library is a function to create math library.
   # The interface is the same as cc_library.
@@ -47,10 +40,3 @@ function(math_library TARGET)
 endfunction()
 math_library(math_function DEPS blas dense_tensor tensor)
-cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
-if(WITH_GPU)
-  nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
-endif()
-if(WITH_ROCM)
-  hip_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
-endif()
...
@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once
 #include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/kernels/funcs/eigen/common.h"
+#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
 namespace pten {
 namespace funcs {
...
@@ -36,12 +36,12 @@ limitations under the License. */
 namespace pten {
 namespace funcs {
-using float16 = paddle::platform::float16;
+using float16 = pten::dtype::float16;
 template struct SetConstant<paddle::platform::CPUDeviceContext,
-                            paddle::platform::float16>;
+                            pten::dtype::float16>;
 template struct SetConstant<paddle::platform::CPUDeviceContext,
-                            paddle::platform::bfloat16>;
+                            pten::dtype::bfloat16>;
 template struct SetConstant<paddle::platform::CPUDeviceContext, float>;
 template struct SetConstant<paddle::platform::CPUDeviceContext, double>;
 template struct SetConstant<paddle::platform::CPUDeviceContext, int16_t>;
@@ -50,12 +50,12 @@ template struct SetConstant<paddle::platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<paddle::platform::CPUDeviceContext, bool>;
 template struct SetConstant<paddle::platform::CPUDeviceContext, uint8_t>;
 template struct SetConstant<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex<float>>;
+                            pten::dtype::complex<float>>;
 template struct SetConstant<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex<double>>;
+                            pten::dtype::complex<double>>;
-template struct SetConstant<pten::CPUContext, paddle::platform::float16>;
-template struct SetConstant<pten::CPUContext, paddle::platform::bfloat16>;
+template struct SetConstant<pten::CPUContext, pten::dtype::float16>;
+template struct SetConstant<pten::CPUContext, pten::dtype::bfloat16>;
 template struct SetConstant<pten::CPUContext, float>;
 template struct SetConstant<pten::CPUContext, double>;
 template struct SetConstant<pten::CPUContext, int16_t>;
@@ -63,15 +63,14 @@ template struct SetConstant<pten::CPUContext, int>;
 template struct SetConstant<pten::CPUContext, int64_t>;
 template struct SetConstant<pten::CPUContext, bool>;
 template struct SetConstant<pten::CPUContext, uint8_t>;
-template struct SetConstant<pten::CPUContext, paddle::platform::complex<float>>;
-template struct SetConstant<pten::CPUContext,
-                            paddle::platform::complex<double>>;
+template struct SetConstant<pten::CPUContext, pten::dtype::complex<float>>;
+template struct SetConstant<pten::CPUContext, pten::dtype::complex<double>>;
 #ifdef PADDLE_WITH_XPU
 template struct SetConstant<paddle::platform::XPUDeviceContext,
-                            paddle::platform::float16>;
+                            pten::dtype::float16>;
 template struct SetConstant<paddle::platform::XPUDeviceContext,
-                            paddle::platform::bfloat16>;
+                            pten::dtype::bfloat16>;
 template struct SetConstant<paddle::platform::XPUDeviceContext, float>;
 template struct SetConstant<paddle::platform::XPUDeviceContext, double>;
 template struct SetConstant<paddle::platform::XPUDeviceContext, uint8_t>;
@@ -80,17 +79,17 @@ template struct SetConstant<paddle::platform::XPUDeviceContext, int>;
 template struct SetConstant<paddle::platform::XPUDeviceContext, int64_t>;
 template struct SetConstant<paddle::platform::XPUDeviceContext, bool>;
 template struct SetConstant<paddle::platform::XPUDeviceContext,
-                            paddle::platform::complex<float>>;
+                            pten::dtype::complex<float>>;
 template struct SetConstant<paddle::platform::XPUDeviceContext,
-                            paddle::platform::complex<double>>;
+                            pten::dtype::complex<double>>;
 #endif
 #define DEFINE_CPU_TRANS(RANK) \
   template struct Transpose<paddle::platform::CPUDeviceContext, \
-                            paddle::platform::float16, \
+                            pten::dtype::float16, \
                             RANK>; \
   template struct Transpose<paddle::platform::CPUDeviceContext, \
-                            paddle::platform::bfloat16, \
+                            pten::dtype::bfloat16, \
                             RANK>; \
   template struct Transpose<paddle::platform::CPUDeviceContext, float, RANK>; \
   template struct Transpose<paddle::platform::CPUDeviceContext, double, RANK>; \
@@ -107,10 +106,26 @@ template struct SetConstant<paddle::platform::XPUDeviceContext,
                             RANK>; \
   template struct Transpose<paddle::platform::CPUDeviceContext, int8_t, RANK>; \
   template struct Transpose<paddle::platform::CPUDeviceContext, \
-                            paddle::platform::complex<float>, \
+                            pten::dtype::complex<float>, \
                             RANK>; \
   template struct Transpose<paddle::platform::CPUDeviceContext, \
-                            paddle::platform::complex<double>, \
-                            RANK>;
+                            pten::dtype::complex<double>, \
+                            RANK>; \
+  template struct Transpose<pten::CPUContext, pten::dtype::float16, RANK>; \
+  template struct Transpose<pten::CPUContext, pten::dtype::bfloat16, RANK>; \
+  template struct Transpose<pten::CPUContext, float, RANK>; \
+  template struct Transpose<pten::CPUContext, double, RANK>; \
+  template struct Transpose<pten::CPUContext, int, RANK>; \
+  template struct Transpose<pten::CPUContext, int64_t, RANK>; \
+  template struct Transpose<pten::CPUContext, bool, RANK>; \
+  template struct Transpose<pten::CPUContext, int16_t, RANK>; \
+  template struct Transpose<pten::CPUContext, uint8_t, RANK>; \
+  template struct Transpose<pten::CPUContext, int8_t, RANK>; \
+  template struct Transpose<pten::CPUContext, \
+                            pten::dtype::complex<float>, \
+                            RANK>; \
+  template struct Transpose<pten::CPUContext, \
+                            pten::dtype::complex<double>, \
+                            RANK>;
 DEFINE_CPU_TRANS(1);
@@ -120,41 +135,41 @@ DEFINE_CPU_TRANS(4);
 DEFINE_CPU_TRANS(5);
 DEFINE_CPU_TRANS(6);
-template <typename T>
-struct TransposeNormal<paddle::platform::CPUDeviceContext, T> {
-  void operator()(const paddle::platform::CPUDeviceContext& context,
-                  const paddle::framework::Tensor& in,
-                  paddle::framework::Tensor* out,
-                  const std::vector<int>& axis) {
-    const int rank = axis.size();
-    auto in_stride = paddle::framework::stride(in.dims());
-    auto out_stride = paddle::framework::stride(out->dims());
-    const T* in_ptr = in.data<T>();
-    T* out_ptr = out->data<T>();
-    auto transpose_helper = [&](int64_t beg, int64_t end) {
-      for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
-        int64_t in_idx = 0;
-        int64_t tmp_idx = out_idx;
-        // calculate the input index
-        for (int i = 0; i < rank; ++i) {
-          const int64_t coordinate = tmp_idx / out_stride[i];
-          tmp_idx -= coordinate * out_stride[i];
-          in_idx += coordinate * in_stride[axis[i]];
-        }
-        out_ptr[out_idx] = in_ptr[in_idx];
-      }
-    };
-    transpose_helper(0, out->numel());
-  }
-};
+template <typename DeviceContext, typename T>
+void TransposeNormal<DeviceContext, T>::operator()(
+    const DeviceContext& context,
+    const paddle::framework::Tensor& in,
+    paddle::framework::Tensor* out,
+    const std::vector<int>& axis) {
+  const int rank = axis.size();
+  auto in_stride = paddle::framework::stride(in.dims());
+  auto out_stride = paddle::framework::stride(out->dims());
+  const T* in_ptr = in.data<T>();
+  T* out_ptr = out->data<T>();
+  auto transpose_helper = [&](int64_t beg, int64_t end) {
+    for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
+      int64_t in_idx = 0;
+      int64_t tmp_idx = out_idx;
+      // calculate the input index
+      for (int i = 0; i < rank; ++i) {
+        const int64_t coordinate = tmp_idx / out_stride[i];
+        tmp_idx -= coordinate * out_stride[i];
+        in_idx += coordinate * in_stride[axis[i]];
+      }
+      out_ptr[out_idx] = in_ptr[in_idx];
+    }
+  };
+  transpose_helper(0, out->numel());
+}
 // define transpose normal
-#define DEFINE_CPU_TRANS_NORMAL(TYPE) \
-  template struct TransposeNormal<paddle::platform::CPUDeviceContext, TYPE>
+#define DEFINE_CPU_TRANS_NORMAL(TYPE) \
+  template struct TransposeNormal<paddle::platform::CPUDeviceContext, TYPE>; \
+  template struct TransposeNormal<pten::CPUContext, TYPE>
-DEFINE_CPU_TRANS_NORMAL(paddle::platform::float16);
-DEFINE_CPU_TRANS_NORMAL(paddle::platform::bfloat16);
+DEFINE_CPU_TRANS_NORMAL(pten::dtype::float16);
+DEFINE_CPU_TRANS_NORMAL(pten::dtype::bfloat16);
 DEFINE_CPU_TRANS_NORMAL(float);
 DEFINE_CPU_TRANS_NORMAL(double);
 DEFINE_CPU_TRANS_NORMAL(int);
@@ -163,8 +178,8 @@ DEFINE_CPU_TRANS_NORMAL(bool);
 DEFINE_CPU_TRANS_NORMAL(int16_t);
 DEFINE_CPU_TRANS_NORMAL(uint8_t);
 DEFINE_CPU_TRANS_NORMAL(int8_t);
-DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<float>);
-DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<double>);
+DEFINE_CPU_TRANS_NORMAL(pten::dtype::complex<float>);
+DEFINE_CPU_TRANS_NORMAL(pten::dtype::complex<double>);
 struct TensorSetConstantCPU {
   TensorSetConstantCPU(paddle::framework::Tensor* tensor, float value)
@@ -343,7 +358,7 @@ struct ElementwiseAddTo<paddle::platform::CPUDeviceContext, T> {
 };
 template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
-                                 paddle::platform::float16>;
+                                 pten::dtype::float16>;
 }  // namespace funcs
 }  // namespace pten
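The `TransposeNormal` body kept above is a stride walk: each flat output index is decomposed into coordinates using the output strides, and those coordinates are re-accumulated against the input strides permuted by `axis`. A standalone sketch of that index arithmetic (the `RowMajorStrides` helper is hypothetical, standing in for `paddle::framework::stride`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Compute row-major strides for a shape, e.g. {2, 3} -> {3, 1}.
std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> s(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
    s[i] = s[i + 1] * dims[i + 1];
  return s;
}

int main() {
  std::vector<int64_t> in_dims = {2, 3};  // 2x3 input
  std::vector<int> axis = {1, 0};         // plain transpose
  std::vector<int64_t> out_dims = {3, 2};
  auto in_stride = RowMajorStrides(in_dims);
  auto out_stride = RowMajorStrides(out_dims);
  std::vector<int> in = {1, 2, 3, 4, 5, 6}, out(6);
  for (int64_t out_idx = 0; out_idx < 6; ++out_idx) {
    int64_t in_idx = 0, tmp = out_idx;
    // Peel off one output coordinate per dimension, then map it through
    // the permutation onto the input strides, as TransposeNormal does.
    for (size_t i = 0; i < axis.size(); ++i) {
      const int64_t coord = tmp / out_stride[i];
      tmp -= coord * out_stride[i];
      in_idx += coord * in_stride[axis[i]];
    }
    out[out_idx] = in[in_idx];
  }
  for (int v : out) std::cout << v << ' ';  // prints: 1 4 2 5 3 6
}
```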
...@@ -27,13 +27,13 @@ limitations under the License. */ ...@@ -27,13 +27,13 @@ limitations under the License. */
namespace pten { namespace pten {
namespace funcs { namespace funcs {
using float16 = paddle::platform::float16; using float16 = pten::dtype::float16;
using bfloat16 = paddle::platform::bfloat16; using bfloat16 = pten::dtype::bfloat16;
template struct SetConstant<paddle::platform::CUDADeviceContext, template struct SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::float16>; pten::dtype::float16>;
template struct SetConstant<paddle::platform::CUDADeviceContext, template struct SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>; pten::dtype::bfloat16>;
template struct SetConstant<paddle::platform::CUDADeviceContext, float>; template struct SetConstant<paddle::platform::CUDADeviceContext, float>;
template struct SetConstant<paddle::platform::CUDADeviceContext, double>; template struct SetConstant<paddle::platform::CUDADeviceContext, double>;
template struct SetConstant<paddle::platform::CUDADeviceContext, uint8_t>; template struct SetConstant<paddle::platform::CUDADeviceContext, uint8_t>;
...@@ -42,12 +42,12 @@ template struct SetConstant<paddle::platform::CUDADeviceContext, int16_t>; ...@@ -42,12 +42,12 @@ template struct SetConstant<paddle::platform::CUDADeviceContext, int16_t>;
template struct SetConstant<paddle::platform::CUDADeviceContext, int64_t>; template struct SetConstant<paddle::platform::CUDADeviceContext, int64_t>;
template struct SetConstant<paddle::platform::CUDADeviceContext, bool>; template struct SetConstant<paddle::platform::CUDADeviceContext, bool>;
template struct SetConstant<paddle::platform::CUDADeviceContext, template struct SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>; pten::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDADeviceContext, template struct SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>; pten::dtype::complex<double>>;
template struct SetConstant<pten::GPUContext, paddle::platform::float16>; template struct SetConstant<pten::GPUContext, pten::dtype::float16>;
template struct SetConstant<pten::GPUContext, paddle::platform::bfloat16>; template struct SetConstant<pten::GPUContext, pten::dtype::bfloat16>;
template struct SetConstant<pten::GPUContext, float>; template struct SetConstant<pten::GPUContext, float>;
template struct SetConstant<pten::GPUContext, double>; template struct SetConstant<pten::GPUContext, double>;
template struct SetConstant<pten::GPUContext, uint8_t>; template struct SetConstant<pten::GPUContext, uint8_t>;
...@@ -55,14 +55,13 @@ template struct SetConstant<pten::GPUContext, int>; ...@@ -55,14 +55,13 @@ template struct SetConstant<pten::GPUContext, int>;
template struct SetConstant<pten::GPUContext, int16_t>; template struct SetConstant<pten::GPUContext, int16_t>;
template struct SetConstant<pten::GPUContext, int64_t>; template struct SetConstant<pten::GPUContext, int64_t>;
template struct SetConstant<pten::GPUContext, bool>; template struct SetConstant<pten::GPUContext, bool>;
template struct SetConstant<pten::GPUContext, paddle::platform::complex<float>>; template struct SetConstant<pten::GPUContext, pten::dtype::complex<float>>;
-template struct SetConstant<pten::GPUContext,
-                            paddle::platform::complex<double>>;
+template struct SetConstant<pten::GPUContext, pten::dtype::complex<double>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
paddle::platform::float16>; pten::dtype::float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
paddle::platform::bfloat16>; pten::dtype::bfloat16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>; template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>; template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>; template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
...@@ -71,9 +70,9 @@ template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>; ...@@ -71,9 +70,9 @@ template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>; template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>; template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
paddle::platform::complex<float>>; pten::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
paddle::platform::complex<double>>; pten::dtype::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \ #define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<paddle::platform::CUDADeviceContext, bool, RANK>; \ template struct Transpose<paddle::platform::CUDADeviceContext, bool, RANK>; \
...@@ -97,10 +96,24 @@ template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, ...@@ -97,10 +96,24 @@ template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
int64_t, \ int64_t, \
RANK>; \ RANK>; \
template struct Transpose<paddle::platform::CUDADeviceContext, \ template struct Transpose<paddle::platform::CUDADeviceContext, \
paddle::platform::complex<float>, \ pten::dtype::complex<float>, \
RANK>; \ RANK>; \
template struct Transpose<paddle::platform::CUDADeviceContext, \ template struct Transpose<paddle::platform::CUDADeviceContext, \
paddle::platform::complex<double>, \ pten::dtype::complex<double>, \
RANK>; \
template struct Transpose<pten::GPUContext, bool, RANK>; \
template struct Transpose<pten::GPUContext, float, RANK>; \
template struct Transpose<pten::GPUContext, double, RANK>; \
template struct Transpose<pten::GPUContext, float16, RANK>; \
template struct Transpose<pten::GPUContext, bfloat16, RANK>; \
template struct Transpose<pten::GPUContext, int8_t, RANK>; \
template struct Transpose<pten::GPUContext, int32_t, RANK>; \
template struct Transpose<pten::GPUContext, int64_t, RANK>; \
template struct Transpose<pten::GPUContext, \
pten::dtype::complex<float>, \
RANK>; \
template struct Transpose<pten::GPUContext, \
pten::dtype::complex<double>, \
RANK>; RANK>;
DEFINE_GPU_TRANS(1); DEFINE_GPU_TRANS(1);
...@@ -133,60 +146,53 @@ __global__ void TransposeNormalKernel(const T* in_ptr, ...@@ -133,60 +146,53 @@ __global__ void TransposeNormalKernel(const T* in_ptr,
} }
} }
-template <typename T>
-struct TransposeNormal<paddle::platform::CUDADeviceContext, T> {
-  void operator()(const paddle::platform::CUDADeviceContext& context,
-                  const paddle::framework::Tensor& in,
-                  paddle::framework::Tensor* out,
-                  const std::vector<int>& axis) {
-    const int rank = axis.size();
-    auto in_stride = paddle::framework::stride(in.dims());
-    auto out_stride = paddle::framework::stride(out->dims());
-    auto* in_ptr = in.data<T>();
-    auto* out_ptr = out->data<T>();
-    // copy in_stride, out_stride, axis to gpu device
-    const paddle::platform::CUDAPlace& cuda_place = context.GetPlace();
-    paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
-    size_t size = 3 * rank * sizeof(int64_t);
-    auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size);
-    auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size);
-    REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
-    REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
-    for (int i = 0; i < rank; ++i) {
-      cpu_buf[i] = in_stride[i];
-      cpu_buf[rank + i] = out_stride[i];
-      cpu_buf[2 * rank + i] = axis[i];
-    }
-    paddle::memory::Copy(
-        cuda_place, cuda_buf, cpu_place, cpu_buf, size, context.stream());
-    REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
-    REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
-    REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);
-    const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock();
-    const int MAX_GRID_DIM =
-        context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
-    int64_t elements = in.numel();
-    int block_size = (elements >= MAX_BLOCK_DIM)
-                         ? MAX_BLOCK_DIM
-                         : (1 << static_cast<int>(std::log2(elements)));
-    int grid_size = elements / block_size;
-    grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
-    TransposeNormalKernel<T><<<grid_size, block_size, 0, context.stream()>>>(
-        in_ptr,
-        out_ptr,
-        elements,
-        in_stride_ptr,
-        out_stride_ptr,
-        axis_ptr,
-        rank);
-  }
-};
+template <typename DeviceContext, typename T>
+void TransposeNormal<DeviceContext, T>::operator()(
+    const DeviceContext& context,
+    const paddle::framework::Tensor& in,
+    paddle::framework::Tensor* out,
+    const std::vector<int>& axis) {
+  const int rank = axis.size();
+  auto in_stride = paddle::framework::stride(in.dims());
+  auto out_stride = paddle::framework::stride(out->dims());
+  auto* in_ptr = in.data<T>();
+  auto* out_ptr = out->data<T>();
+
+  // copy in_stride, out_stride, axis to gpu device
+  const paddle::platform::CUDAPlace& cuda_place = context.GetPlace();
+  paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
+  size_t size = 3 * rank * sizeof(int64_t);
+  auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size);
+  auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size);
+  REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
+  REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
+  for (int i = 0; i < rank; ++i) {
+    cpu_buf[i] = in_stride[i];
+    cpu_buf[rank + i] = out_stride[i];
+    cpu_buf[2 * rank + i] = axis[i];
+  }
+  paddle::memory::Copy(
+      cuda_place, cuda_buf, cpu_place, cpu_buf, size, context.stream());
+  REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
+  REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
+  REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);
+  const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock();
+  const int MAX_GRID_DIM = context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
+  int64_t elements = in.numel();
+  int block_size = (elements >= MAX_BLOCK_DIM)
+                       ? MAX_BLOCK_DIM
+                       : (1 << static_cast<int>(std::log2(elements)));
+  int grid_size = elements / block_size;
+  grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
+  TransposeNormalKernel<T><<<grid_size, block_size, 0, context.stream()>>>(
+      in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, rank);
+}
 // define transpose normal
-#define DEFINE_GPU_TRANS_NORMAL(TYPE) \
-  template struct TransposeNormal<paddle::platform::CUDADeviceContext, TYPE>
+#define DEFINE_GPU_TRANS_NORMAL(TYPE)                                         \
+  template struct TransposeNormal<paddle::platform::CUDADeviceContext, TYPE>; \
+  template struct TransposeNormal<pten::GPUContext, TYPE>
DEFINE_GPU_TRANS_NORMAL(float16); DEFINE_GPU_TRANS_NORMAL(float16);
DEFINE_GPU_TRANS_NORMAL(bfloat16); DEFINE_GPU_TRANS_NORMAL(bfloat16);
...@@ -198,8 +204,8 @@ DEFINE_GPU_TRANS_NORMAL(bool); ...@@ -198,8 +204,8 @@ DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int16_t); DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t); DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t); DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<float>); DEFINE_GPU_TRANS_NORMAL(pten::dtype::complex<float>);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<double>); DEFINE_GPU_TRANS_NORMAL(pten::dtype::complex<double>);
struct TensorSetConstantGPU { struct TensorSetConstantGPU {
TensorSetConstantGPU(const paddle::platform::DeviceContext& context, TensorSetConstantGPU(const paddle::platform::DeviceContext& context,
...@@ -374,7 +380,7 @@ struct ElementwiseAddTo<paddle::platform::CUDADeviceContext, T> { ...@@ -374,7 +380,7 @@ struct ElementwiseAddTo<paddle::platform::CUDADeviceContext, T> {
}; };
template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext, template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
paddle::platform::float16>; pten::dtype::float16>;
} // namespace funcs } // namespace funcs
} // namespace pten } // namespace pten
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/funcs/transpose.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/common/complex.h"
#include "paddle/pten/common/float16.h"
namespace pten {
namespace math {
template <typename T>
struct TransposeNormal<CPUContext, T> {
// for dims >= 7 situation
void operator()(const CPUContext& dev_ctx,
const pten::DenseTensor& in,
pten::DenseTensor* out,
const std::vector<int64_t>& axis) {
const int rank = axis.size();
auto in_stride = pten::framework::stride(in.dims());
auto out_stride = pten::framework::stride(out->dims());
const T* in_ptr = in.data<T>();
T* out_ptr = dev_ctx.template Alloc<T>(out);
auto transpose_helper = [&](int64_t beg, int64_t end) {
for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
int64_t in_idx = 0;
int64_t tmp_idx = out_idx;
// calculate the input index
for (int i = 0; i < rank; ++i) {
const int64_t coordinate = tmp_idx / out_stride[i];
tmp_idx -= coordinate * out_stride[i];
in_idx += coordinate * in_stride[axis[i]];
}
out_ptr[out_idx] = in_ptr[in_idx];
}
};
transpose_helper(0, out->numel());
}
};
// define transpose normal
#define DEFINE_CPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<CPUContext, TYPE>
DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(int32_t);
DEFINE_CPU_TRANS_NORMAL(int64_t);
DEFINE_CPU_TRANS_NORMAL(float);
DEFINE_CPU_TRANS_NORMAL(double);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::float16);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::bfloat16);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<double>);
} // namespace math
} // namespace pten
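Note on the loop above: transpose_helper recovers, for every flat output offset, the matching input offset by peeling off coordinates with the output strides and re-accumulating them with the permuted input strides. A self-contained sketch of the same arithmetic for a 2x3 row-major tensor and axis = {1, 0} (all names are local to this illustration, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // in: 2x3 row-major {{0, 1, 2}, {3, 4, 5}}; out: its 3x2 transpose
  const std::vector<float> in = {0, 1, 2, 3, 4, 5};
  const std::vector<int64_t> axis = {1, 0};
  const std::vector<int64_t> in_stride = {3, 1};   // strides of the 2x3 input
  const std::vector<int64_t> out_stride = {2, 1};  // strides of the 3x2 output
  std::vector<float> out(in.size());
  const int rank = static_cast<int>(axis.size());
  for (int64_t out_idx = 0; out_idx < static_cast<int64_t>(out.size());
       ++out_idx) {
    int64_t in_idx = 0;
    int64_t tmp_idx = out_idx;
    for (int i = 0; i < rank; ++i) {
      const int64_t coordinate = tmp_idx / out_stride[i];
      tmp_idx -= coordinate * out_stride[i];
      in_idx += coordinate * in_stride[axis[i]];
    }
    out[out_idx] = in[in_idx];
  }
  // column-major readout of in: {0, 3, 1, 4, 2, 5}
  assert(out[1] == 3 && out[4] == 2);
  return 0;
}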
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/transpose.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/common/complex.h"
#include "paddle/pten/common/float16.h"
namespace pten {
namespace math {
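// REINTERPRET names a typed pointer (DST_PTR) over a raw buffer (SRC_PTR).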
#define REINTERPRET(T, DST_PTR, SRC_PTR) \
T* DST_PTR = reinterpret_cast<T*>(SRC_PTR)
template <typename T>
__global__ void TransposeNormalKernel(const T* in_ptr,
T* out_ptr,
int64_t element,
const int64_t* in_stride_ptr,
const int64_t* out_stride_ptr,
const int64_t* axis_ptr,
int rank) {
CUDA_KERNEL_LOOP(out_idx, element) {
int64_t in_idx = 0;
int64_t tmp_idx = out_idx;
for (int i = 0; i < rank; ++i) {
const int64_t coordinate = tmp_idx / out_stride_ptr[i];
tmp_idx -= coordinate * out_stride_ptr[i];
in_idx += coordinate * in_stride_ptr[axis_ptr[i]];
}
out_ptr[out_idx] = in_ptr[in_idx];
}
}
template <typename T>
struct TransposeNormal<GPUContext, T> {
// for dims >= 7 situation
void operator()(const GPUContext& dev_ctx,
const pten::DenseTensor& in,
pten::DenseTensor* out,
const std::vector<int64_t>& axis) {
const int rank = axis.size();
auto in_stride = pten::framework::stride(in.dims());
auto out_stride = pten::framework::stride(out->dims());
auto* in_ptr = in.data<T>();
T* out_ptr = dev_ctx.template Alloc<T>(out);
// copy in_stride, out_stride, axis to gpu device
const paddle::platform::CUDAPlace& cuda_place = dev_ctx.GetPlace();
paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
size_t size = 3 * rank * sizeof(int64_t);
auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size);
auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size);
REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
for (int i = 0; i < rank; ++i) {
cpu_buf[i] = in_stride[i];
cpu_buf[rank + i] = out_stride[i];
cpu_buf[2 * rank + i] = axis[i];
}
paddle::memory::Copy(
cuda_place, cuda_buf, cpu_place, cpu_buf, size, dev_ctx.stream());
REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);
const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM =
dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
int64_t elements = in.numel();
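// when under the block limit, round down to the largest power of two <= elements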
int block_size = (elements >= MAX_BLOCK_DIM)
? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(elements)));
int grid_size = elements / block_size;
grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
TransposeNormalKernel<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
in_ptr,
out_ptr,
elements,
in_stride_ptr,
out_stride_ptr,
axis_ptr,
rank);
}
};
// define transpose normal
#define DEFINE_GPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<GPUContext, TYPE>
DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(int32_t);
DEFINE_GPU_TRANS_NORMAL(int64_t);
DEFINE_GPU_TRANS_NORMAL(float);
DEFINE_GPU_TRANS_NORMAL(double);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::float16);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::bfloat16);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<double>);
} // namespace math
} // namespace pten
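The launch configuration in TransposeNormal<GPUContext, T> picks the largest power of two that does not exceed the element count (capped at the device block limit), then clamps the grid so the total thread count stays within what the device can physically run; CUDA_KERNEL_LOOP is a grid-stride loop, so any remainder is still covered. A worked sketch of that sizing arithmetic, with assumed device limits:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int MAX_BLOCK_DIM = 1024;  // assumed GetMaxThreadsPerBlock()
  const int MAX_GRID_DIM = 131072 / MAX_BLOCK_DIM;  // assumed physical threads
  int64_t elements = 1000;
  int block_size = (elements >= MAX_BLOCK_DIM)
                       ? MAX_BLOCK_DIM
                       : (1 << static_cast<int>(std::log2(elements)));
  int grid_size = static_cast<int>(elements / block_size);
  grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
  std::cout << block_size << ", " << grid_size << "\n";  // prints "512, 1"
  return 0;
}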
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
namespace pten {
namespace math {
template <typename DeviceContext, typename T>
struct TransposeNormal {
// for dims >= 7 situation
void operator()(const DeviceContext& dev_ctx,
const pten::DenseTensor& in,
pten::DenseTensor* out,
const std::vector<int64_t>& axis);
};
template <typename DeviceContext, typename T, int Rank>
struct Transpose {
void operator()(const DeviceContext& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto eigen_in = pten::EigenTensor<T, Rank>::From(in);
auto eigen_out = pten::EigenTensor<T, Rank>::From(*out);
auto* dev = dev_ctx.eigen_device();
// use 32bit index to speed up computation
bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) {
To32BitIndex(eigen_out).device(*dev) =
To32BitIndex(eigen_in).shuffle(permute);
} else {
eigen_out.device(*dev) = eigen_in.shuffle(permute);
}
}
};
} // namespace math
} // namespace pten
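For ranks handled by the Transpose functor above, the permutation is simply forwarded to Eigen's shuffle. A standalone illustration of the shuffle semantics (uses Eigen's unsupported Tensor module; independent of the patch):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> in(2, 3);
  in.setValues({{0, 1, 2}, {3, 4, 5}});
  Eigen::array<int, 2> permute = {1, 0};  // plays the role of the axis vector
  Eigen::Tensor<float, 2> out = in.shuffle(permute);
  std::cout << out.dimension(0) << "x" << out.dimension(1)   // 3x2
            << ", out(0, 1) = " << out(0, 1) << "\n";        // in(1, 0) == 3
  return 0;
}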
...@@ -72,16 +72,16 @@ void CastKernel(const Context& dev_ctx, ...@@ -72,16 +72,16 @@ void CastKernel(const Context& dev_ctx,
int16_t, \ int16_t, \
bool, \ bool, \
uint8_t, \ uint8_t, \
paddle::platform::float16, \ pten::dtype::float16, \
paddle::platform::complex<float>, \ pten::dtype::complex<float>, \
paddle::platform::complex<double>, \ pten::dtype::complex<double>, \
##__VA_ARGS__) { \ ##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \ kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \ paddle::experimental::DataType::UNDEFINED); \
} }
#if !defined(PADDLE_WITH_HIP) #if !defined(PADDLE_WITH_HIP)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, pten::dtype::bfloat16)
#else #else
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif #endif
...@@ -25,9 +25,9 @@ PT_REGISTER_KERNEL(conj, ...@@ -25,9 +25,9 @@ PT_REGISTER_KERNEL(conj,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ConjKernel, pten::ConjKernel,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>, pten::dtype::complex<double>,
float, float,
double, double,
int, int,
......
...@@ -120,7 +120,7 @@ PT_REGISTER_KERNEL(concat, ...@@ -120,7 +120,7 @@ PT_REGISTER_KERNEL(concat,
int64_t, int64_t,
int, int,
uint8_t, uint8_t,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::bfloat16, pten::dtype::bfloat16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
...@@ -28,5 +28,5 @@ PT_REGISTER_KERNEL(dot_grad, ...@@ -28,5 +28,5 @@ PT_REGISTER_KERNEL(dot_grad,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
...@@ -49,8 +49,8 @@ void DotKernel(const Context& dev_ctx, ...@@ -49,8 +49,8 @@ void DotKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::pten::dtype::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::pten::dtype::complex<double>;
PT_REGISTER_KERNEL(dot, PT_REGISTER_KERNEL(dot,
GPU, GPU,
......
...@@ -128,9 +128,9 @@ PT_REGISTER_KERNEL(add_grad, ...@@ -128,9 +128,9 @@ PT_REGISTER_KERNEL(add_grad,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(add_double_grad, PT_REGISTER_KERNEL(add_double_grad,
GPU, GPU,
...@@ -140,9 +140,9 @@ PT_REGISTER_KERNEL(add_double_grad, ...@@ -140,9 +140,9 @@ PT_REGISTER_KERNEL(add_double_grad,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(add_triple_grad, PT_REGISTER_KERNEL(add_triple_grad,
GPU, GPU,
...@@ -152,9 +152,9 @@ PT_REGISTER_KERNEL(add_triple_grad, ...@@ -152,9 +152,9 @@ PT_REGISTER_KERNEL(add_triple_grad,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(subtract_grad, PT_REGISTER_KERNEL(subtract_grad,
GPU, GPU,
...@@ -164,9 +164,9 @@ PT_REGISTER_KERNEL(subtract_grad, ...@@ -164,9 +164,9 @@ PT_REGISTER_KERNEL(subtract_grad,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(subtract_double_grad, PT_REGISTER_KERNEL(subtract_double_grad,
GPU, GPU,
...@@ -176,6 +176,6 @@ PT_REGISTER_KERNEL(subtract_double_grad, ...@@ -176,6 +176,6 @@ PT_REGISTER_KERNEL(subtract_double_grad,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
...@@ -24,6 +24,6 @@ PT_REGISTER_KERNEL(expand_grad, ...@@ -24,6 +24,6 @@ PT_REGISTER_KERNEL(expand_grad,
pten::ExpandGradKernel, pten::ExpandGradKernel,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
int, int,
int64_t) {} int64_t) {}
...@@ -25,7 +25,7 @@ PT_REGISTER_KERNEL(expand, ...@@ -25,7 +25,7 @@ PT_REGISTER_KERNEL(expand,
pten::ExpandKernel, pten::ExpandKernel,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
int, int,
int64_t, int64_t,
bool) {} bool) {}
...@@ -106,9 +106,9 @@ PT_REGISTER_KERNEL(full, ...@@ -106,9 +106,9 @@ PT_REGISTER_KERNEL(full,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
GPU, GPU,
...@@ -119,4 +119,4 @@ PT_REGISTER_KERNEL(full_like, ...@@ -119,4 +119,4 @@ PT_REGISTER_KERNEL(full_like,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16) {} pten::dtype::float16) {}
...@@ -91,9 +91,9 @@ DEFINE_CUDA_ELEMENTWISE_OP(Divide) ...@@ -91,9 +91,9 @@ DEFINE_CUDA_ELEMENTWISE_OP(Divide)
} // namespace pten } // namespace pten
using float16 = paddle::platform::float16; using float16 = pten::dtype::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::pten::dtype::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::pten::dtype::complex<double>;
PT_REGISTER_KERNEL(add_raw, PT_REGISTER_KERNEL(add_raw,
GPU, GPU,
......
...@@ -25,10 +25,10 @@ PT_REGISTER_KERNEL(matmul_grad, ...@@ -25,10 +25,10 @@ PT_REGISTER_KERNEL(matmul_grad,
pten::MatmulGradKernel, pten::MatmulGradKernel,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::bfloat16, pten::dtype::bfloat16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(matmul_double_grad, PT_REGISTER_KERNEL(matmul_double_grad,
GPU, GPU,
...@@ -36,9 +36,9 @@ PT_REGISTER_KERNEL(matmul_double_grad, ...@@ -36,9 +36,9 @@ PT_REGISTER_KERNEL(matmul_double_grad,
pten::MatmulDoubleGradKernel, pten::MatmulDoubleGradKernel,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
PT_REGISTER_KERNEL(matmul_triple_grad, PT_REGISTER_KERNEL(matmul_triple_grad,
GPU, GPU,
...@@ -46,6 +46,6 @@ PT_REGISTER_KERNEL(matmul_triple_grad, ...@@ -46,6 +46,6 @@ PT_REGISTER_KERNEL(matmul_triple_grad,
pten::MatmulTripleGradKernel, pten::MatmulTripleGradKernel,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
...@@ -26,7 +26,7 @@ PT_REGISTER_KERNEL(matmul, ...@@ -26,7 +26,7 @@ PT_REGISTER_KERNEL(matmul,
pten::MatmulKernel, pten::MatmulKernel,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
paddle::platform::bfloat16, pten::dtype::bfloat16,
paddle::platform::complex<float>, pten::dtype::complex<float>,
paddle::platform::complex<double>) {} pten::dtype::complex<double>) {}
...@@ -117,4 +117,4 @@ PT_REGISTER_KERNEL(norm_grad, ...@@ -117,4 +117,4 @@ PT_REGISTER_KERNEL(norm_grad,
pten::NormGradKernel, pten::NormGradKernel,
float, float,
double, double,
paddle::platform::float16) {} pten::dtype::float16) {}
...@@ -130,4 +130,4 @@ PT_REGISTER_KERNEL(norm, ...@@ -130,4 +130,4 @@ PT_REGISTER_KERNEL(norm,
pten::NormKernel, pten::NormKernel,
float, float,
double, double,
paddle::platform::float16) {} pten::dtype::float16) {}
...@@ -1004,15 +1004,14 @@ template <typename Tx, ...@@ -1004,15 +1004,14 @@ template <typename Tx,
typename Ty, typename Ty,
template <typename> class ReduceOp, template <typename> class ReduceOp,
typename TransformOp> typename TransformOp>
-static
-    typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
-                            void>::type
-    CubTensorReduceImpl(const Tx* x_data,
-                        Ty* y_data,
-                        const TransformOp& transform,
-                        int reduce_num,
-                        const paddle::platform::Place& place,
-                        gpuStream_t stream) {
+static typename std::enable_if<!std::is_same<Tx, pten::dtype::float16>::value,
+                               void>::type
+CubTensorReduceImpl(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const paddle::platform::Place& place,
+                    gpuStream_t stream) {
auto reducer = ReduceOp<Ty>(); auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data, cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform); transform);
...@@ -1048,15 +1047,14 @@ template <typename Tx, ...@@ -1048,15 +1047,14 @@ template <typename Tx,
typename Ty, typename Ty,
template <typename> class ReduceOp, template <typename> class ReduceOp,
typename TransformOp> typename TransformOp>
-static
-    typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
-                            void>::type
-    CubTensorReduceImpl(const Tx* x_data,
-                        Ty* y_data,
-                        const TransformOp& transform,
-                        int reduce_num,
-                        const paddle::platform::Place& place,
-                        gpuStream_t stream) {
+static typename std::enable_if<std::is_same<Tx, pten::dtype::float16>::value,
+                               void>::type
+CubTensorReduceImpl(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const paddle::platform::Place& place,
+                    gpuStream_t stream) {
PADDLE_THROW(pten::errors::InvalidArgument( PADDLE_THROW(pten::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce().")); "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
} }
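The two overloads above form a classic enable_if dispatch: for any concrete Tx exactly one return type is well-formed, so float16 inputs resolve to the throwing overload at compile time instead of instantiating the cub path. A minimal self-contained sketch of the pattern (the float16 below is a local stand-in, not the real type):

#include <iostream>
#include <type_traits>

struct float16 {};  // stand-in for pten::dtype::float16 in this sketch

template <typename T>
typename std::enable_if<!std::is_same<T, float16>::value, void>::type
ReduceImpl() {
  std::cout << "cub reduce path\n";
}

template <typename T>
typename std::enable_if<std::is_same<T, float16>::value, void>::type
ReduceImpl() {
  std::cout << "rejected: float16 is unsupported here\n";
}

int main() {
  ReduceImpl<float>();    // picks the first overload
  ReduceImpl<float16>();  // picks the second overload
  return 0;
}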
...@@ -1099,7 +1097,7 @@ void TensorReduceImpl(const pten::GPUContext& dev_ctx, ...@@ -1099,7 +1097,7 @@ void TensorReduceImpl(const pten::GPUContext& dev_ctx,
} }
config.SetOutputData(y_data, x.place(), &tmp); config.SetOutputData(y_data, x.place(), &tmp);
constexpr bool kIsTxFP16 = std::is_same<Tx, paddle::platform::float16>::value; constexpr bool kIsTxFP16 = std::is_same<Tx, pten::dtype::float16>::value;
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16; bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
if (use_cub_reduce) { if (use_cub_reduce) {
CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>( CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/float16.h" #include "paddle/pten/common/float16.h"
using float16 = paddle::platform::float16; using float16 = pten::dtype::float16;
PT_REGISTER_KERNEL( PT_REGISTER_KERNEL(
sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {} sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {}
...@@ -47,10 +47,9 @@ void FullLikeKernel(const Context& dev_ctx, ...@@ -47,10 +47,9 @@ void FullLikeKernel(const Context& dev_ctx,
auto value = val.to<float>(); auto value = val.to<float>();
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
float, float,
-    typename std::conditional<
-        std::is_same<T, paddle::platform::float16>::value,
-        float,
-        T>::type>::type;
+    typename std::conditional<std::is_same<T, pten::dtype::float16>::value,
+                              float,
+                              T>::type>::type;
auto common_type_value = static_cast<CommonType>(value); auto common_type_value = static_cast<CommonType>(value);
......
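The CommonType alias above widens the template type to float whenever T is float16, so the fill value can be range-checked in a type that can actually hold it. A compile-time sketch of the same type computation (again with a dummy float16):

#include <type_traits>

struct float16 {};  // stand-in for pten::dtype::float16 in this sketch

template <typename T>
using CommonType = typename std::common_type<
    float,
    typename std::conditional<std::is_same<T, float16>::value,
                              float,
                              T>::type>::type;

static_assert(std::is_same<CommonType<float16>, float>::value,
              "float16 is widened to float");
static_assert(std::is_same<CommonType<double>, double>::value,
              "double is kept as-is");

int main() { return 0; }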
...@@ -90,7 +90,7 @@ static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx, ...@@ -90,7 +90,7 @@ static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx,
DenseTensor output = EmptyLike<T, Context>(dev_ctx, input); DenseTensor output = EmptyLike<T, Context>(dev_ctx, input);
output.Resize({in_dims[1], in_dims[0], in_dims[2]}); output.Resize({in_dims[1], in_dims[0], in_dims[2]});
std::vector<int> axis = {1, 0, 2}; std::vector<int> axis = {1, 0, 2};
math::Transpose<Context, T, 3> trans; funcs::Transpose<Context, T, 3> trans;
trans(dev_ctx, input, &output, axis); trans(dev_ctx, input, &output, axis);
output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output; return output;
......
...@@ -78,8 +78,8 @@ void MultiplyKernel(const Context& dev_ctx, ...@@ -78,8 +78,8 @@ void MultiplyKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::pten::dtype::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::pten::dtype::complex<double>;
PT_REGISTER_KERNEL( PT_REGISTER_KERNEL(
mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
...@@ -91,7 +91,7 @@ PT_REGISTER_KERNEL(sum, ...@@ -91,7 +91,7 @@ PT_REGISTER_KERNEL(sum,
bool, bool,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
int, int,
int64_t, int64_t,
complex64, complex64,
...@@ -149,7 +149,7 @@ PT_REGISTER_KERNEL(mean, ...@@ -149,7 +149,7 @@ PT_REGISTER_KERNEL(mean,
float, float,
double, double,
bool, bool,
paddle::platform::float16) {} pten::dtype::float16) {}
PT_REGISTER_KERNEL(sum, PT_REGISTER_KERNEL(sum,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(sum, ...@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(sum,
bool, bool,
float, float,
double, double,
paddle::platform::float16, pten::dtype::float16,
int, int,
int64_t, int64_t,
complex64, complex64,
...@@ -172,7 +172,7 @@ PT_REGISTER_KERNEL(add, ...@@ -172,7 +172,7 @@ PT_REGISTER_KERNEL(add,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
...@@ -183,7 +183,7 @@ PT_REGISTER_KERNEL(subtract, ...@@ -183,7 +183,7 @@ PT_REGISTER_KERNEL(subtract,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(divide, PT_REGISTER_KERNEL(divide,
...@@ -194,7 +194,7 @@ PT_REGISTER_KERNEL(divide, ...@@ -194,7 +194,7 @@ PT_REGISTER_KERNEL(divide,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::float16, pten::dtype::float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
...@@ -206,7 +206,7 @@ PT_REGISTER_KERNEL(multiply, ...@@ -206,7 +206,7 @@ PT_REGISTER_KERNEL(multiply,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, pten::dtype::float16,
complex64, complex64,
complex128) {} complex128) {}
#endif #endif
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/all_context.h" #include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/transpose.h" #include "paddle/pten/kernels/funcs/math_function.h"
namespace pten { namespace pten {
...@@ -42,7 +42,7 @@ void CastDataLayout(const Context& dev_ctx, ...@@ -42,7 +42,7 @@ void CastDataLayout(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& axis, const std::vector<int>& axis,
DenseTensor* out) { DenseTensor* out) {
math::Transpose<Context, T, 4> trans4; funcs::Transpose<Context, T, 4> trans4;
trans4(dev_ctx, x, out, axis); trans4(dev_ctx, x, out, axis);
} }
......
...@@ -162,7 +162,7 @@ static void ScaleGPU(DataType kernel_dtype, ...@@ -162,7 +162,7 @@ static void ScaleGPU(DataType kernel_dtype,
break; break;
} }
case pten::DataType::FLOAT16: { case pten::DataType::FLOAT16: {
pten::ScaleKernel<paddle::platform::float16>( pten::ScaleKernel<pten::dtype::float16>(
dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out); dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
break; break;
} }
......
...@@ -13,3 +13,11 @@ cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils) ...@@ -13,3 +13,11 @@ cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils) cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS pten pten_api_utils) cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils) cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_math_function SRCS test_math_function.cc DEPS math_function)
if(WITH_GPU)
nv_test(test_math_function_gpu SRCS test_math_function.cu DEPS math_function)
endif()
if(WITH_ROCM)
hip_test(test_math_function_gpu SRCS test_math_function.cu DEPS math_function)
endif()
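Once the project is configured, these targets run like any other test, e.g. ctest -R test_math_function from the build directory (assuming a standard CMake/CTest setup).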
...@@ -11,9 +11,13 @@ ...@@ -11,9 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/pten/kernels/funcs/math_function.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/pten/kernels/funcs/math_function.h"
namespace pten {
namespace tests {
template <typename T> template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T> inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
...@@ -348,3 +352,6 @@ TEST(math_function, gemm_warp) { ...@@ -348,3 +352,6 @@ TEST(math_function, gemm_warp) {
GemmWarpTest<double>(8, 5, 6, 1.0, 0.0); GemmWarpTest<double>(8, 5, 6, 1.0, 0.0);
GemmWarpTest<double>(8, 5, 6, 2.0, 1.0); GemmWarpTest<double>(8, 5, 6, 2.0, 1.0);
} }
} // namespace tests
} // namespace pten
...@@ -11,12 +11,16 @@ ...@@ -11,12 +11,16 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/kernels/funcs/math_function.h" #include "paddle/pten/kernels/funcs/math_function.h"
-void fill_fp16_data(paddle::platform::float16* in_ptr,
+namespace pten {
+namespace tests {
+
+void fill_fp16_data(pten::dtype::float16* in_ptr,
size_t size, size_t size,
const std::vector<float>& data) { const std::vector<float>& data) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -28,7 +32,7 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, ...@@ -28,7 +32,7 @@ void fill_fp16_data(paddle::platform::float16* in_ptr,
size, size,
data.size())); data.size()));
for (size_t i = 0; i < data.size(); ++i) { for (size_t i = 0; i < data.size(); ++i) {
in_ptr[i] = paddle::platform::float16(data[i]); in_ptr[i] = pten::dtype::float16(data[i]);
} }
} }
...@@ -95,27 +99,26 @@ TEST(math_function, notrans_mul_trans_fp16) { ...@@ -95,27 +99,26 @@ TEST(math_function, notrans_mul_trans_fp16) {
return; return;
} }
paddle::platform::float16* input1_ptr = pten::dtype::float16* input1_ptr =
input1.mutable_data<paddle::platform::float16>({2, 3}, cpu_place); input1.mutable_data<pten::dtype::float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu);
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu);
out_gpu.mutable_data<paddle::platform::float16>({2, 2}, gpu_place); out_gpu.mutable_data<pten::dtype::float16>({2, 2}, gpu_place);
-  GetBlas<paddle::platform::float16>(context).MatMul(
-      input1_gpu,
-      false,
-      input2_gpu,
-      true,
-      paddle::platform::float16(1),
-      &out_gpu,
-      paddle::platform::float16(0));
+  GetBlas<pten::dtype::float16>(context).MatMul(input1_gpu,
+                                                false,
+                                                input2_gpu,
+                                                true,
+                                                pten::dtype::float16(1),
+                                                &out_gpu,
+                                                pten::dtype::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
paddle::platform::float16* out_ptr = out.data<paddle::platform::float16>(); pten::dtype::float16* out_ptr = out.data<pten::dtype::float16>();
context.Wait(); context.Wait();
EXPECT_EQ(static_cast<float>(out_ptr[0]), 5); EXPECT_EQ(static_cast<float>(out_ptr[0]), 5);
EXPECT_EQ(static_cast<float>(out_ptr[1]), 14); EXPECT_EQ(static_cast<float>(out_ptr[1]), 14);
...@@ -185,27 +188,26 @@ TEST(math_function, trans_mul_notrans_fp16) { ...@@ -185,27 +188,26 @@ TEST(math_function, trans_mul_notrans_fp16) {
return; return;
} }
paddle::platform::float16* input1_ptr = pten::dtype::float16* input1_ptr =
input1.mutable_data<paddle::platform::float16>({2, 3}, cpu_place); input1.mutable_data<pten::dtype::float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu);
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu);
out_gpu.mutable_data<paddle::platform::float16>({3, 3}, gpu_place); out_gpu.mutable_data<pten::dtype::float16>({3, 3}, gpu_place);
-  GetBlas<paddle::platform::float16>(context).MatMul(
-      input1_gpu,
-      true,
-      input2_gpu,
-      false,
-      paddle::platform::float16(1),
-      &out_gpu,
-      paddle::platform::float16(0));
+  GetBlas<pten::dtype::float16>(context).MatMul(input1_gpu,
+                                                true,
+                                                input2_gpu,
+                                                false,
+                                                pten::dtype::float16(1),
+                                                &out_gpu,
+                                                pten::dtype::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
paddle::platform::float16* out_ptr = out.data<paddle::platform::float16>(); pten::dtype::float16* out_ptr = out.data<pten::dtype::float16>();
context.Wait(); context.Wait();
EXPECT_EQ(static_cast<float>(out_ptr[0]), 9); EXPECT_EQ(static_cast<float>(out_ptr[0]), 9);
EXPECT_EQ(static_cast<float>(out_ptr[1]), 12); EXPECT_EQ(static_cast<float>(out_ptr[1]), 12);
...@@ -300,37 +302,37 @@ TEST(math_function, gemm_notrans_cublas_fp16) { ...@@ -300,37 +302,37 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
int m = 2; int m = 2;
int n = 3; int n = 3;
int k = 3; int k = 3;
paddle::platform::float16* input1_ptr = pten::dtype::float16* input1_ptr =
input1.mutable_data<paddle::platform::float16>({2, 3}, cpu_place); input1.mutable_data<pten::dtype::float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
paddle::platform::float16* input2_ptr = pten::dtype::float16* input2_ptr =
input2.mutable_data<paddle::platform::float16>({3, 4}, cpu_place); input2.mutable_data<pten::dtype::float16>({3, 4}, cpu_place);
fill_fp16_data( fill_fp16_data(
input2_ptr, input2.numel(), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); input2_ptr, input2.numel(), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
paddle::platform::float16* input3_ptr = pten::dtype::float16* input3_ptr =
input3.mutable_data<paddle::platform::float16>({2, 4}, cpu_place); input3.mutable_data<pten::dtype::float16>({2, 4}, cpu_place);
fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7}); fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu);
paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu); paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu);
paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu); paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu);
paddle::platform::float16* a = input1_gpu.data<paddle::platform::float16>(); pten::dtype::float16* a = input1_gpu.data<pten::dtype::float16>();
paddle::platform::float16* b = input2_gpu.data<paddle::platform::float16>(); pten::dtype::float16* b = input2_gpu.data<pten::dtype::float16>();
paddle::platform::float16* c = pten::dtype::float16* c =
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place); input3_gpu.mutable_data<pten::dtype::float16>(gpu_place);
GetBlas<paddle::platform::float16>(context).GEMM( GetBlas<pten::dtype::float16>(context).GEMM(
false, false,
false, false,
m, m,
n, n,
k, k,
static_cast<paddle::platform::float16>(1), static_cast<pten::dtype::float16>(1),
a, a,
3, 3,
b + 1, b + 1,
4, 4,
static_cast<paddle::platform::float16>(1), static_cast<pten::dtype::float16>(1),
c + 1, c + 1,
4); 4);
...@@ -429,37 +431,37 @@ TEST(math_function, gemm_trans_cublas_fp16) { ...@@ -429,37 +431,37 @@ TEST(math_function, gemm_trans_cublas_fp16) {
int m = 2; int m = 2;
int n = 3; int n = 3;
int k = 3; int k = 3;
paddle::platform::float16* input1_ptr = pten::dtype::float16* input1_ptr =
input1.mutable_data<paddle::platform::float16>({2, 3}, cpu_place); input1.mutable_data<pten::dtype::float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
paddle::platform::float16* input2_ptr = pten::dtype::float16* input2_ptr =
input2.mutable_data<paddle::platform::float16>({4, 3}, cpu_place); input2.mutable_data<pten::dtype::float16>({4, 3}, cpu_place);
fill_fp16_data( fill_fp16_data(
input2_ptr, input2.numel(), {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}); input2_ptr, input2.numel(), {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11});
paddle::platform::float16* input3_ptr = pten::dtype::float16* input3_ptr =
input3.mutable_data<paddle::platform::float16>({2, 4}, cpu_place); input3.mutable_data<pten::dtype::float16>({2, 4}, cpu_place);
fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7}); fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu);
paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu); paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu);
paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu); paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu);
paddle::platform::float16* a = input1_gpu.data<paddle::platform::float16>(); pten::dtype::float16* a = input1_gpu.data<pten::dtype::float16>();
paddle::platform::float16* b = input2_gpu.data<paddle::platform::float16>(); pten::dtype::float16* b = input2_gpu.data<pten::dtype::float16>();
paddle::platform::float16* c = pten::dtype::float16* c =
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place); input3_gpu.mutable_data<pten::dtype::float16>(gpu_place);
GetBlas<paddle::platform::float16>(context).GEMM( GetBlas<pten::dtype::float16>(context).GEMM(
false, false,
true, true,
m, m,
n, n,
k, k,
static_cast<paddle::platform::float16>(1), static_cast<pten::dtype::float16>(1),
a, a,
3, 3,
b + 3, b + 3,
3, 3,
static_cast<paddle::platform::float16>(1), static_cast<pten::dtype::float16>(1),
c + 1, c + 1,
4); 4);
...@@ -547,3 +549,6 @@ TEST(math_function, gemv) { ...@@ -547,3 +549,6 @@ TEST(math_function, gemv) {
GemvTest<float>(3, 13, true); GemvTest<float>(3, 13, true);
GemvTest<double>(3, 13, true); GemvTest<double>(3, 13, true);
} }
} // namespace tests
} // namespace pten