提交 dcfcf687 编写于 作者: F fengjiayi

Refactor Tensor::CopyFrom()

1. Add template T which indicates data type to `CopyFrom()`, `Slice()`
and `ShareDataFrom()` functions. This makes `CopyFrom()` code much clearer.

2. Add `set_dim()`.

3. `product(DDim)` transforms `DDim` to `vector<int>` first and then calculates
its product. That might be quite slow. Since `product(dims_)` is frequently
used in Tensor, we add a member variable `numel_` as a cache of the
product result.
TODO: refactor `product()` to make it more efficient.

4. Disable Tensor::operator=

5. Remove the limit of POD type, because `float16` and `int8` are not POD types.
上级 a1dc4311
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <type_traits>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h" #include "paddle/framework/enforce.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
...@@ -28,15 +27,15 @@ namespace framework { ...@@ -28,15 +27,15 @@ namespace framework {
class Tensor { class Tensor {
public: public:
Tensor() : offset_(0) { numel_ = product(dims_); } Tensor() : numel_(0), offset_(0) {}
Tensor& operator=(const Tensor& src) = delete; Tensor& operator=(const Tensor& src) = delete;
template <typename T> template <typename T>
const T* data() const { const T* data() const {
CheckDimsValidity(); CheckDimsValidity<T>();
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->Ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
} }
template <typename T> template <typename T>
...@@ -51,35 +50,40 @@ class Tensor { ...@@ -51,35 +50,40 @@ class Tensor {
"Tensor::numel_ must be larger than zero to call " "Tensor::numel_ must be larger than zero to call "
"Tensor::mutable_data."); "Tensor::mutable_data.");
if (holder_ == nullptr || if (holder_ == nullptr ||
!(holder_->Place() == !(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */ place) /* some versions of boost::variant don't have operator!= */
|| holder_->Size() < numel_ * sizeof(T) + offset_) { || holder_->size() < numel_ * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T))); holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T)));
offset_ = 0; offset_ = 0;
} }
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
template <typename T>
void ShareDataFrom(const Tensor& src) { void ShareDataFrom(const Tensor& src) {
src.CheckDimsValidity(); src.CheckDimsValidity<T>();
holder_ = src.holder_; holder_ = src.holder_;
dims_ = src.dims(); set_dims(src.dims());
numel_ = src.numel_;
offset_ = src.offset_; offset_ = src.offset_;
} }
template <typename T>
void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) { void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) {
src.CheckDimsValidity(); PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
size_t size = src.numel_ * src.holder_->TypeSize(); platform::is_cpu_place(dst_place),
holder_.reset(src.holder_->Clone(src.offset_, size, dst_place)); "Tensor::CopyFrom only support CPU now.");
dims_ = src.dims(); src.CheckDimsValidity<T>();
numel_ = src.numel_; size_t size = src.numel_ * sizeof(T);
offset_ = 0; set_dims(src.dims());
void* src_ptr = static_cast<void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
memcpy(dst_ptr, src_ptr, size);
} }
template <typename T>
Tensor Slice(const int& begin_idx, const int& end_idx) const { Tensor Slice(const int& begin_idx, const int& end_idx) const {
CheckDimsValidity(); CheckDimsValidity<T>();
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
"Slice index is less than zero or out of bound."); "Slice index is less than zero or out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx, PADDLE_ENFORCE(begin_idx < end_idx,
...@@ -95,7 +99,7 @@ class Tensor { ...@@ -95,7 +99,7 @@ class Tensor {
DDim dst_dims = dims_; DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx; dst_dims[0] = end_idx - begin_idx;
dst.set_dims(dst_dims); dst.set_dims(dst_dims);
dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize(); dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst; return dst;
} }
...@@ -115,12 +119,9 @@ class Tensor { ...@@ -115,12 +119,9 @@ class Tensor {
// parameter of Variable. // parameter of Variable.
struct Placeholder { struct Placeholder {
virtual ~Placeholder() {} virtual ~Placeholder() {}
virtual void* Ptr() const = 0; virtual void* ptr() const = 0;
virtual paddle::platform::Place Place() const = 0; virtual paddle::platform::Place place() const = 0;
virtual size_t Size() const = 0; virtual size_t size() const = 0;
virtual size_t TypeSize() const = 0;
virtual Placeholder* Clone(size_t begin, size_t size,
paddle::platform::Place place) const = 0;
}; };
template <typename T> template <typename T>
...@@ -144,32 +145,20 @@ class Tensor { ...@@ -144,32 +145,20 @@ class Tensor {
place_(place), place_(place),
size_(size) {} size_(size) {}
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); } virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t Size() const { return size_; } virtual size_t size() const { return size_; }
virtual paddle::platform::Place Place() const { return place_; } virtual paddle::platform::Place place() const { return place_; }
virtual size_t TypeSize() const { return sizeof(T); }
// TODO: Clone only support CPU now. GPU support is needed.
virtual Placeholder* Clone(size_t begin, size_t size,
paddle::platform::Place place) const {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(place_) &&
paddle::platform::is_cpu_place(place),
"PlaceholderImpl::Clone only support CPU now.");
PlaceholderImpl<T>* dst = new PlaceholderImpl<T>(place, size);
void* begin_ptr =
reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(Ptr()) + begin);
memcpy(dst->Ptr(), begin_ptr, size);
return dst;
}
std::unique_ptr<T, Deleter> ptr_; std::unique_ptr<T, Deleter> ptr_;
paddle::platform::Place place_; // record the place of ptr_. paddle::platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block. size_t size_; // size of the memory block.
}; };
inline void CheckDimsValidity() { template <typename T>
inline void CheckDimsValidity() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first."); "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->Size() > numel_ * sizeof(T) + offset_, PADDLE_ENFORCE(holder_->size() > numel_ * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."); "first to re-allocate memory.");
} }
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
TEST(Tensor, Dims) { TEST(Tensor, Dims) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor tt(make_ddim({2, 3, 4})); Tensor tt;
tt.set_dims(make_ddim({2, 3, 4}));
DDim dims = tt.dims(); DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3); ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
...@@ -35,7 +36,7 @@ TEST(Tensor, DataAssert) { ...@@ -35,7 +36,7 @@ TEST(Tensor, DataAssert) {
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr has not been initialized. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
const char* what = err.what(); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]); ASSERT_EQ(what[i], msg[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册