提交 dcfcf687 编写于 作者: F fengjiayi

Refactor Tensor::CopyFrom()

1. Add template T which indicates data type to `CopyFrom()`, `Slice()`
and `ShareData()` functions. This makes `CopyData()` code much clearer.

2. Add `set_dim()`.

3. `product(DDim)` transforms `DDim` to `vector<int>` first and then calculate
its product. That might be quite slow. For `product(dims_)` is frequently
used in Tensor, we add a mumber variable `numel_` as a cache of the
product result.
TODO: refactor `product()` to make it more efficient.

4. Unable Tensor::operator=

5. Remove the limit of POD type, because `float16` and `int8` are not POD type.
上级 a1dc4311
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <cstdint>
#include <cstring>
#include <memory>
#include <type_traits>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/memory/memory.h"
......@@ -28,15 +27,15 @@ namespace framework {
class Tensor {
public:
Tensor() : offset_(0) { numel_ = product(dims_); }
Tensor() : numel_(0), offset_(0) {}
Tensor& operator=(const Tensor& src) = delete;
template <typename T>
const T* data() const {
CheckDimsValidity();
CheckDimsValidity<T>();
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->Ptr()) + offset_);
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
template <typename T>
......@@ -51,35 +50,40 @@ class Tensor {
"Tensor::numel_ must be larger than zero to call "
"Tensor::mutable_data.");
if (holder_ == nullptr ||
!(holder_->Place() ==
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->Size() < numel_ * sizeof(T) + offset_) {
|| holder_->size() < numel_ * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T)));
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) +
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T>
void ShareDataFrom(const Tensor& src) {
src.CheckDimsValidity();
src.CheckDimsValidity<T>();
holder_ = src.holder_;
dims_ = src.dims();
numel_ = src.numel_;
set_dims(src.dims());
offset_ = src.offset_;
}
template <typename T>
void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) {
src.CheckDimsValidity();
size_t size = src.numel_ * src.holder_->TypeSize();
holder_.reset(src.holder_->Clone(src.offset_, size, dst_place));
dims_ = src.dims();
numel_ = src.numel_;
offset_ = 0;
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now.");
src.CheckDimsValidity<T>();
size_t size = src.numel_ * sizeof(T);
set_dims(src.dims());
void* src_ptr = static_cast<void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
memcpy(dst_ptr, src_ptr, size);
}
template <typename T>
Tensor Slice(const int& begin_idx, const int& end_idx) const {
CheckDimsValidity();
CheckDimsValidity<T>();
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
"Slice index is less than zero or out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
......@@ -95,7 +99,7 @@ class Tensor {
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.set_dims(dst_dims);
dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize();
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
}
......@@ -115,12 +119,9 @@ class Tensor {
// parameter of Variable.
struct Placeholder {
virtual ~Placeholder() {}
virtual void* Ptr() const = 0;
virtual paddle::platform::Place Place() const = 0;
virtual size_t Size() const = 0;
virtual size_t TypeSize() const = 0;
virtual Placeholder* Clone(size_t begin, size_t size,
paddle::platform::Place place) const = 0;
virtual void* ptr() const = 0;
virtual paddle::platform::Place place() const = 0;
virtual size_t size() const = 0;
};
template <typename T>
......@@ -144,32 +145,20 @@ class Tensor {
place_(place),
size_(size) {}
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t Size() const { return size_; }
virtual paddle::platform::Place Place() const { return place_; }
virtual size_t TypeSize() const { return sizeof(T); }
// TODO: Clone only support CPU now. GPU support is needed.
virtual Placeholder* Clone(size_t begin, size_t size,
paddle::platform::Place place) const {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(place_) &&
paddle::platform::is_cpu_place(place),
"PlaceholderImpl::Clone only support CPU now.");
PlaceholderImpl<T>* dst = new PlaceholderImpl<T>(place, size);
void* begin_ptr =
reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(Ptr()) + begin);
memcpy(dst->Ptr(), begin_ptr, size);
return dst;
}
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t size() const { return size_; }
virtual paddle::platform::Place place() const { return place_; }
std::unique_ptr<T, Deleter> ptr_;
paddle::platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
};
inline void CheckDimsValidity() {
template <typename T>
inline void CheckDimsValidity() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->Size() > numel_ * sizeof(T) + offset_,
PADDLE_ENFORCE(holder_->size() > numel_ * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
......
......@@ -18,7 +18,8 @@
TEST(Tensor, Dims) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor tt(make_ddim({2, 3, 4}));
Tensor tt;
tt.set_dims(make_ddim({2, 3, 4}));
DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
......@@ -35,7 +36,7 @@ TEST(Tensor, DataAssert) {
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr has not been initialized. Call Tensor::mutable_data first.";
"Tenosr holds no memory. Call Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册