提交 2c188a20 编写于 作者: Y Yi Wang

Follow QingQing's suggestion

上级 e2e0fbd4
......@@ -20,23 +20,19 @@ class Tensor {
using paddle::platform::get_place;
public:
explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {}
explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {}
template <typename T>
const T* data() const {
PADDLE_ASSERT(holder_ != nullptr);
PADDLE_ASSERT(holder_->Place() == place_);
PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T));
PADDLE_ASSERT(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data");
return static_cast<const T*>(holder->Ptr());
}
template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data() {
if (holder_ == nullptr || holder_->Place() != place_ ||
holder_->Size() < dims_.product() * sizeof(T)) {
holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T)));
T* mutable_data(DDim dims, Place place) {
if (holder_ == nullptr || holder_->Place() != place ||
holder_->Size() < dims.product() * sizeof(T)) {
holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T)));
}
return static_cast<T*>(holder_->Ptr());
}
......@@ -44,16 +40,7 @@ class Tensor {
template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(DDim dims) {
dims_ = dims;
return mutable_data<T>();
}
template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(DDim dims, Place place) {
dims_ = dims;
place_ = place;
return mutable_data<T>();
return mutable_data<T>(dims, paddle::platform::get_place());
}
private:
......@@ -69,7 +56,7 @@ class Tensor {
template <typename T>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place pl, size_t size)
: ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
: ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
place_(pl),
size_(size) {}
......@@ -83,8 +70,6 @@ class Tensor {
};
std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_; // could be smallers than the holder_->Size().
paddle::platform::Place place_;
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册