From da07ec1886669deb71931d6bb949d20c01033605 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 21 Jul 2017 11:58:42 +0800 Subject: [PATCH] Update Tensor and PODDeleter's template parameter 1. Change PODDeleter's template parameter 'PlaceType' to 'Place'. 2. Limit PODDeleter and Tensor::mutable_data()'s `T` to POD type. --- paddle/framework/tensor.h | 10 ++++++---- paddle/memory/memory.h | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 7d0f74fc5bc..a36f375d2e4 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -60,13 +60,15 @@ class Tensor { offset_); } - template + template ::value>::type* = nullptr> T* mutable_data(DDim dims, platform::Place place) { Resize(dims); return mutable_data(place); } - template + template ::value>::type* = nullptr> T* mutable_data(platform::Place place) { PADDLE_ENFORCE(product(dims_) > 0, "Tensor's numel must be larger than zero to call " @@ -150,7 +152,7 @@ class Tensor { struct PlaceholderImpl : public Placeholder { PlaceholderImpl(PlaceType place, size_t size) : ptr_(static_cast(memory::Alloc(place, size)), - memory::PodDeleter(place)), + memory::PODDeleter(place)), place_(place), size_(size) {} @@ -159,7 +161,7 @@ class Tensor { virtual paddle::platform::Place place() const { return place_; } virtual std::type_index type() const { return std::type_index(typeid(T)); } - std::unique_ptr> ptr_; + std::unique_ptr> ptr_; platform::Place place_; // record the place of ptr_. size_t size_; // size of the memory block. }; diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index f5890fb8445..c4fe1e52203 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -28,14 +28,16 @@ void Free(Place, void*); template size_t Used(Place); -template -class PodDeleter { +template ::value>::type* = nullptr> +class PODDeleter { public: - PodDeleter(PlaceType place) : place_(place) {} + PODDeleter(Place place) : place_(place) {} void operator()(T* ptr) { Free(place_, static_cast(ptr)); } private: - PlaceType place_; + Place place_; }; } // namespace memory -- GitLab