提交 78bd815e 编写于 作者: F fengjiayi

refine conditional compilation and remove `numel_`

上级 1cd14f66
...@@ -27,7 +27,7 @@ namespace framework { ...@@ -27,7 +27,7 @@ namespace framework {
class Tensor { class Tensor {
public: public:
Tensor() : numel_(0), offset_(0) {} Tensor() : offset_(0) {}
template <typename T> template <typename T>
const T* data() const { const T* data() const {
...@@ -44,30 +44,26 @@ class Tensor { ...@@ -44,30 +44,26 @@ class Tensor {
template <typename T> template <typename T>
T* mutable_data(platform::Place place) { T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(numel_ > 0, PADDLE_ENFORCE(product(dims_) > 0,
"Tensor::numel_ must be larger than zero to call " "Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."); "Tensor::mutable_data. Call Tensor::set_dim first.");
if (holder_ == nullptr || if (holder_ == nullptr ||
!(holder_->place() == !(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */ place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) { || holder_->size() < product(dims_) * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
} else if (platform::is_gpu_place(place)) {
#ifdef __CUDACC__ #ifdef __CUDACC__
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>( holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T))); boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
#else #else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device.");
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
#endif #endif
} else {
PADDLE_ENFORCE(true, "Unknown 'place'.");
}
offset_ = 0; offset_ = 0;
} }
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
...@@ -88,7 +84,7 @@ class Tensor { ...@@ -88,7 +84,7 @@ class Tensor {
platform::is_cpu_place(dst_place), platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now."); "Tensor::CopyFrom only support CPU now.");
src.CheckDims<T>(); src.CheckDims<T>();
size_t size = src.numel_ * sizeof(T); size_t size = product(src.dims_) * sizeof(T);
set_dims(src.dims()); set_dims(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>()); const void* src_ptr = static_cast<const void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place)); void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
...@@ -122,7 +118,6 @@ class Tensor { ...@@ -122,7 +118,6 @@ class Tensor {
return; return;
} }
dims_ = dims; dims_ = dims;
numel_ = product(dims_);
} }
DDim dims() const { return dims_; } DDim dims() const { return dims_; }
...@@ -170,16 +165,15 @@ class Tensor { ...@@ -170,16 +165,15 @@ class Tensor {
inline void CheckDims() const { inline void CheckDims() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first."); "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_, PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."); "first to re-allocate memory.");
} }
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_; DDim dims_;
size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area. size_t offset_; // marks the begin of tensor data area.
}; // namespace framework };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册