diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 8d658d509726d9e0ad7529da3b5fefd18100ef75..7fa662fbb5497a4d72912493589b819d4180403b 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -14,32 +14,39 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include "paddle/framework/ddim.h" +#include "paddle/framework/enforce.h" +#include "paddle/memory/memory.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/place.h" + namespace paddle { namespace framework { class Tensor { - using paddle::platform::Place; - public: template const T* data() const { PADDLE_ENFORCE(holder_ != nullptr, "Tensor::data must be called after Tensor::mutable_data"); - return static_cast(holder->Ptr()); + return static_cast(holder_->Ptr()); } template ::value>::type> - T* mutable_data(DDim dims, Place place) { + typename std::enable_if::value>::type* = nullptr> + T* mutable_data(DDim dims, paddle::platform::Place place) { if (holder_ == nullptr || holder_->Place() != place || - holder_->Size() < dims.product() * sizeof(T)) { - holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T))); + holder_->Size() < product(dims) * sizeof(T)) { + holder_.reset(new PlaceholderImpl(place, product(dims) * sizeof(T))); } return static_cast(holder_->Ptr()); } template ::value>::type> + typename std::enable_if::value>::type* = nullptr> T* mutable_data(DDim dims) { return mutable_data(dims, paddle::platform::get_place()); } @@ -50,24 +57,24 @@ class Tensor { struct Placeholder { virtual ~Placeholder() {} virtual void* Ptr() const = 0; - virtual Place Place() const = 0; + virtual paddle::platform::Place Place() const = 0; virtual size_t Size() const = 0; }; template struct PlaceholderImpl : public Placeholder { - PlaceholderImpl(Place pl, size_t size) + PlaceholderImpl(paddle::platform::Place pl, size_t size) : ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)), place_(pl), size_(size) {} virtual void* Ptr() const { return static_cast(ptr_.get()); } virtual size_t Size() const { return size_; } - virtual Place Place() const { return place_; } + virtual paddle::platform::Place Place() const { return place_; } std::unique_ptr ptr_; - Place place_; // record the place of ptr_. - size_t size_; // size of the memory block. + paddle::platform::Place place_; // record the place of ptr_. + size_t size_; // size of the memory block. }; std::unique_ptr holder_; // holds the memory block if allocated.