提交 efd98097 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #2926 from Canpio/dev_enable_tensor_test

refine tensor's conditional compilation, remove Tensor::numel_ and add DDim::size()
...@@ -117,6 +117,8 @@ int DDim::operator[](int idx) const { ...@@ -117,6 +117,8 @@ int DDim::operator[](int idx) const {
return boost::apply_visitor(DynamicConstIndexer(idx), var); return boost::apply_visitor(DynamicConstIndexer(idx), var);
} }
ssize_t DDim::size() const { return arity(*this); }
bool DDim::operator==(DDim d) const { bool DDim::operator==(DDim d) const {
if (var.which() != d.getVar().which()) { if (var.which() != d.getVar().which()) {
return false; return false;
......
...@@ -50,6 +50,8 @@ struct DDim { ...@@ -50,6 +50,8 @@ struct DDim {
DDimVar getVar() { return var; } DDimVar getVar() { return var; }
ssize_t size() const;
bool operator==(DDim d) const; bool operator==(DDim d) const;
bool operator!=(DDim d) const; bool operator!=(DDim d) const;
......
...@@ -49,6 +49,7 @@ TEST(DDim, Equality) { ...@@ -49,6 +49,7 @@ TEST(DDim, Equality) {
// arity of a DDim // arity of a DDim
EXPECT_EQ(paddle::framework::arity(ddim), 3); EXPECT_EQ(paddle::framework::arity(ddim), 3);
EXPECT_EQ(ddim.size(), 3);
// product of a DDim // product of a DDim
EXPECT_EQ(paddle::framework::product(vddim), 45); EXPECT_EQ(paddle::framework::product(vddim), 45);
......
...@@ -27,7 +27,7 @@ namespace framework { ...@@ -27,7 +27,7 @@ namespace framework {
class Tensor { class Tensor {
public: public:
Tensor() : numel_(0), offset_(0) {} Tensor() : offset_(0) {}
template <typename T> template <typename T>
const T* data() const { const T* data() const {
...@@ -44,30 +44,26 @@ class Tensor { ...@@ -44,30 +44,26 @@ class Tensor {
template <typename T> template <typename T>
T* mutable_data(platform::Place place) { T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(numel_ > 0, PADDLE_ENFORCE(product(dims_) > 0,
"Tensor::numel_ must be larger than zero to call " "Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."); "Tensor::mutable_data. Call Tensor::set_dim first.");
if (holder_ == nullptr || if (holder_ == nullptr ||
!(holder_->place() == !(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */ place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) { || holder_->size() < product(dims_) * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
} else if (platform::is_gpu_place(place)) {
#ifdef __CUDACC__ #ifdef __CUDACC__
switch (place.which()) { holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
case 0: boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
#else #else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device.");
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
#endif #endif
} else {
PADDLE_ENFORCE(true, "Unknown 'place'.");
}
offset_ = 0; offset_ = 0;
} }
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
...@@ -88,7 +84,7 @@ class Tensor { ...@@ -88,7 +84,7 @@ class Tensor {
platform::is_cpu_place(dst_place), platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now."); "Tensor::CopyFrom only support CPU now.");
src.CheckDims<T>(); src.CheckDims<T>();
size_t size = src.numel_ * sizeof(T); size_t size = product(src.dims_) * sizeof(T);
set_dims(src.dims()); set_dims(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>()); const void* src_ptr = static_cast<const void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place)); void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
...@@ -122,7 +118,6 @@ class Tensor { ...@@ -122,7 +118,6 @@ class Tensor {
return; return;
} }
dims_ = dims; dims_ = dims;
numel_ = product(dims_);
} }
DDim dims() const { return dims_; } DDim dims() const { return dims_; }
...@@ -170,16 +165,15 @@ class Tensor { ...@@ -170,16 +165,15 @@ class Tensor {
inline void CheckDims() const { inline void CheckDims() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first."); "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_, PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."); "first to re-allocate memory.");
} }
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_; DDim dims_;
size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area. size_t offset_; // marks the begin of tensor data area.
}; // namespace framework };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册