提交 da07ec18 编写于 作者: F fengjiayi

Update Tensor and PODDeleter's template parameter

1. Change PODDeleter's template parameter 'PlaceType' to 'Place'.

2. Limit PODDeleter and Tensor::mutable_data()'s `T` to POD type.
上级 6cd94cc7
...@@ -60,13 +60,15 @@ class Tensor { ...@@ -60,13 +60,15 @@ class Tensor {
offset_); offset_);
} }
template <typename T> template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims, platform::Place place) { T* mutable_data(DDim dims, platform::Place place) {
Resize(dims); Resize(dims);
return mutable_data<T>(place); return mutable_data<T>(place);
} }
template <typename T> template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(platform::Place place) { T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(product(dims_) > 0, PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call " "Tensor's numel must be larger than zero to call "
...@@ -150,7 +152,7 @@ class Tensor { ...@@ -150,7 +152,7 @@ class Tensor {
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(PlaceType place, size_t size) PlaceholderImpl(PlaceType place, size_t size)
: ptr_(static_cast<T*>(memory::Alloc(place, size)), : ptr_(static_cast<T*>(memory::Alloc(place, size)),
memory::PodDeleter<T, PlaceType>(place)), memory::PODDeleter<T, PlaceType>(place)),
place_(place), place_(place),
size_(size) {} size_(size) {}
...@@ -159,7 +161,7 @@ class Tensor { ...@@ -159,7 +161,7 @@ class Tensor {
virtual paddle::platform::Place place() const { return place_; } virtual paddle::platform::Place place() const { return place_; }
virtual std::type_index type() const { return std::type_index(typeid(T)); } virtual std::type_index type() const { return std::type_index(typeid(T)); }
std::unique_ptr<T, memory::PodDeleter<T, PlaceType>> ptr_; std::unique_ptr<T, memory::PODDeleter<T, PlaceType>> ptr_;
platform::Place place_; // record the place of ptr_. platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block. size_t size_; // size of the memory block.
}; };
......
...@@ -28,14 +28,16 @@ void Free(Place, void*); ...@@ -28,14 +28,16 @@ void Free(Place, void*);
template <class Place> template <class Place>
size_t Used(Place); size_t Used(Place);
template <typename T, typename PlaceType> template <typename T,
class PodDeleter { typename Place /* platform::GPUPlace or platform::CPUPlace */,
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
class PODDeleter {
public: public:
PodDeleter(PlaceType place) : place_(place) {} PODDeleter(Place place) : place_(place) {}
void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); } void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
private: private:
PlaceType place_; Place place_;
}; };
} // namespace memory } // namespace memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册