提交 78bd815e 编写于 作者: F fengjiayi

refine conditional compilation and remove `numel_`

上级 1cd14f66
......@@ -27,7 +27,7 @@ namespace framework {
class Tensor {
public:
Tensor() : numel_(0), offset_(0) {}
Tensor() : offset_(0) {}
template <typename T>
const T* data() const {
......@@ -44,30 +44,26 @@ class Tensor {
template <typename T>
T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(numel_ > 0,
"Tensor::numel_ must be larger than zero to call "
PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
if (holder_ == nullptr ||
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) {
|| holder_->size() < product(dims_) * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
} else if (platform::is_gpu_place(place)) {
#ifdef __CUDACC__
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
#else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device.");
#endif
} else {
PADDLE_ENFORCE(true, "Unknown 'place'.");
}
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
......@@ -88,7 +84,7 @@ class Tensor {
platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now.");
src.CheckDims<T>();
size_t size = src.numel_ * sizeof(T);
size_t size = product(src.dims_) * sizeof(T);
set_dims(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
......@@ -122,7 +118,6 @@ class Tensor {
return;
}
dims_ = dims;
numel_ = product(dims_);
}
DDim dims() const { return dims_; }
......@@ -170,16 +165,15 @@ class Tensor {
inline void CheckDims() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_;
size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area.
}; // namespace framework
};
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册