提交 2c188a20 编写于 作者: Y Yi Wang

Follow QingQing's suggestion

上级 e2e0fbd4
...@@ -20,23 +20,19 @@ class Tensor { ...@@ -20,23 +20,19 @@ class Tensor {
using paddle::platform::get_place; using paddle::platform::get_place;
public: public:
explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {}
explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {}
template <typename T> template <typename T>
const T* data() const { const T* data() const {
PADDLE_ASSERT(holder_ != nullptr); PADDLE_ASSERT(holder_ != nullptr,
PADDLE_ASSERT(holder_->Place() == place_); "Tensor::data must be called after Tensor::mutable_data");
PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T));
return static_cast<const T*>(holder->Ptr()); return static_cast<const T*>(holder->Ptr());
} }
template <typename T, // must be POD types template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type> typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data() { T* mutable_data(DDim dims, Place place) {
if (holder_ == nullptr || holder_->Place() != place_ || if (holder_ == nullptr || holder_->Place() != place ||
holder_->Size() < dims_.product() * sizeof(T)) { holder_->Size() < dims.product() * sizeof(T)) {
holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T))); holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T)));
} }
return static_cast<T*>(holder_->Ptr()); return static_cast<T*>(holder_->Ptr());
} }
...@@ -44,16 +40,7 @@ class Tensor { ...@@ -44,16 +40,7 @@ class Tensor {
template <typename T, // must be POD types template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type> typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(DDim dims) { T* mutable_data(DDim dims) {
dims_ = dims; return mutable_data<T>(dims, paddle::platform::get_place());
return mutable_data<T>();
}
template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(DDim dims, Place place) {
dims_ = dims;
place_ = place;
return mutable_data<T>();
} }
private: private:
...@@ -69,7 +56,7 @@ class Tensor { ...@@ -69,7 +56,7 @@ class Tensor {
template <typename T> template <typename T>
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place pl, size_t size) PlaceholderImpl(Place pl, size_t size)
: ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)), : ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
place_(pl), place_(pl),
size_(size) {} size_(size) {}
...@@ -83,8 +70,6 @@ class Tensor { ...@@ -83,8 +70,6 @@ class Tensor {
}; };
std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated. std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_; // could be smallers than the holder_->Size().
paddle::platform::Place place_;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册