提交 efd98097 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #2926 from Canpio/dev_enable_tensor_test

refine tensor's conditional compilation, remove Tensor::numel_ and add DDim::size()
......@@ -117,6 +117,8 @@ int DDim::operator[](int idx) const {
return boost::apply_visitor(DynamicConstIndexer(idx), var);
}
ssize_t DDim::size() const { return arity(*this); }
bool DDim::operator==(DDim d) const {
if (var.which() != d.getVar().which()) {
return false;
......
......@@ -50,6 +50,8 @@ struct DDim {
DDimVar getVar() { return var; }
ssize_t size() const;
bool operator==(DDim d) const;
bool operator!=(DDim d) const;
......
......@@ -49,6 +49,7 @@ TEST(DDim, Equality) {
// arity of a DDim
EXPECT_EQ(paddle::framework::arity(ddim), 3);
EXPECT_EQ(ddim.size(), 3);
// product of a DDim
EXPECT_EQ(paddle::framework::product(vddim), 45);
......
......@@ -27,7 +27,7 @@ namespace framework {
class Tensor {
public:
Tensor() : numel_(0), offset_(0) {}
Tensor() : offset_(0) {}
template <typename T>
const T* data() const {
......@@ -44,30 +44,26 @@ class Tensor {
template <typename T>
T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(numel_ > 0,
"Tensor::numel_ must be larger than zero to call "
PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
if (holder_ == nullptr ||
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) {
|| holder_->size() < product(dims_) * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
} else if (platform::is_gpu_place(place)) {
#ifdef __CUDACC__
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
#else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device.");
#endif
} else {
PADDLE_ENFORCE(true, "Unknown 'place'.");
}
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
......@@ -88,7 +84,7 @@ class Tensor {
platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now.");
src.CheckDims<T>();
size_t size = src.numel_ * sizeof(T);
size_t size = product(src.dims_) * sizeof(T);
set_dims(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
......@@ -122,7 +118,6 @@ class Tensor {
return;
}
dims_ = dims;
numel_ = product(dims_);
}
DDim dims() const { return dims_; }
......@@ -170,16 +165,15 @@ class Tensor {
inline void CheckDims() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_;
size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area.
}; // namespace framework
};
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册