提交 55d30172 编写于 作者: F fengjiayi

Simplify Tensor implimentation

ATTENTION: some interfaces changed:
1. void Tensor::set_dims(const DDim& dims) ==> void Tensor::Resize(const DDim& dims).
2. void Tensor::ShareDataFrom(const Tensor& src)  ==> void Tensor::ShareDataWith(const Tensor& src)
3. DDim Tensor::dims() const ==> const DDim& Tensor::dims() const
上级 b90780c3
...@@ -40,21 +40,21 @@ class Tensor { ...@@ -40,21 +40,21 @@ class Tensor {
template <typename T> template <typename T>
const T* data() const { const T* data() const {
CheckDims<T>(); EnforceSufficientMemory<T>();
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
} }
template <typename T> template <typename T>
T* raw_data() const { T* raw_data() const {
CheckDims<T>(); EnforceSufficientMemory<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
template <typename T> template <typename T>
T* mutable_data(DDim dims, platform::Place place) { T* mutable_data(DDim dims, platform::Place place) {
set_dims(dims); Resize(dims);
return mutable_data<T>(place); return mutable_data<T>(place);
} }
...@@ -147,11 +147,9 @@ class Tensor { ...@@ -147,11 +147,9 @@ class Tensor {
} }
template <typename T> template <typename T>
void ShareDataFrom(const Tensor& src) { void ShareDataWith(const Tensor& src) {
src.CheckDims<T>(); src.EnforceSufficientMemory<T>();
holder_ = src.holder_; *this = src;
set_dims(src.dims());
offset_ = src.offset_;
} }
template <typename T> template <typename T>
...@@ -159,9 +157,9 @@ class Tensor { ...@@ -159,9 +157,9 @@ class Tensor {
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place), platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now."); "Tensor::CopyFrom only support CPU now.");
src.CheckDims<T>(); src.EnforceSufficientMemory<T>();
size_t size = product(src.dims_) * sizeof(T); size_t size = product(src.dims_) * sizeof(T);
set_dims(src.dims()); Resize(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>()); const void* src_ptr = static_cast<const void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place)); void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
memcpy(dst_ptr, src_ptr, size); memcpy(dst_ptr, src_ptr, size);
...@@ -169,34 +167,25 @@ class Tensor { ...@@ -169,34 +167,25 @@ class Tensor {
template <typename T> template <typename T>
Tensor Slice(const int& begin_idx, const int& end_idx) const { Tensor Slice(const int& begin_idx, const int& end_idx) const {
CheckDims<T>(); EnforceSufficientMemory<T>();
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
"Slice index is less than zero or out of bound."); PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx, PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index."); "Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1."); PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
std::vector<int> d = vectorize(dims_); int base = product(dims_) / dims_[0];
int base = 1;
for (size_t i = 1; i < d.size(); ++i) {
base *= d[i];
}
Tensor dst; Tensor dst;
dst.holder_ = holder_; dst.holder_ = holder_;
DDim dst_dims = dims_; DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx; dst_dims[0] = end_idx - begin_idx;
dst.set_dims(dst_dims); dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T); dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst; return dst;
} }
void set_dims(const DDim& dims) { void Resize(const DDim& dims) { dims_ = dims; }
if (dims == dims_) {
return;
}
dims_ = dims;
}
DDim dims() const { return dims_; } const DDim& dims() const { return dims_; }
private: private:
// Placeholder hides type T, so it doesn't appear as a template // Placeholder hides type T, so it doesn't appear as a template
...@@ -211,21 +200,9 @@ class Tensor { ...@@ -211,21 +200,9 @@ class Tensor {
template <typename T, typename PlaceType> template <typename T, typename PlaceType>
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
private:
template <typename PType>
class Deleter {
public:
Deleter(PType place) : place_(place) {}
void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); }
private:
PType place_;
};
public:
PlaceholderImpl(PlaceType place, size_t size) PlaceholderImpl(PlaceType place, size_t size)
: ptr_(static_cast<T*>(memory::Alloc(place, size)), : ptr_(static_cast<T*>(memory::Alloc(place, size)),
Deleter<PlaceType>(place)), memory::PodDeleter<T, PlaceType>(place)),
place_(place), place_(place),
size_(size) {} size_(size) {}
...@@ -234,13 +211,13 @@ class Tensor { ...@@ -234,13 +211,13 @@ class Tensor {
virtual paddle::platform::Place place() const { return place_; } virtual paddle::platform::Place place() const { return place_; }
virtual std::type_index type() const { return std::type_index(typeid(T)); } virtual std::type_index type() const { return std::type_index(typeid(T)); }
std::unique_ptr<T, Deleter<PlaceType>> ptr_; std::unique_ptr<T, memory::PodDeleter<T, PlaceType>> ptr_;
platform::Place place_; // record the place of ptr_. platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block. size_t size_; // size of the memory block.
}; };
template <typename T> template <typename T>
inline void CheckDims() const { inline void EnforceSufficientMemory() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first."); "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_, PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
...@@ -250,7 +227,11 @@ class Tensor { ...@@ -250,7 +227,11 @@ class Tensor {
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_; DDim dims_;
size_t offset_; // marks the begin of tensor data area. // A PlaceHolder may be shared by more than one tensor. Some of them may be
// slices of the others. So the offset_ is introduced here to indicate the
// byte offset between PlaceHolder::ptr_ and where tensor's data really
// begins.
size_t offset_;
template <bool less, size_t i, typename... args> template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl; friend struct paddle::pybind::details::CastToPyBufferImpl;
}; };
......
...@@ -19,7 +19,7 @@ TEST(Tensor, Dims) { ...@@ -19,7 +19,7 @@ TEST(Tensor, Dims) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor tt; Tensor tt;
tt.set_dims(make_ddim({2, 3, 4})); tt.Resize(make_ddim({2, 3, 4}));
DDim dims = tt.dims(); DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3); ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
...@@ -97,7 +97,7 @@ TEST(Tensor, MutableData) { ...@@ -97,7 +97,7 @@ TEST(Tensor, MutableData) {
#endif #endif
} }
TEST(Tensor, ShareDataFrom) { TEST(Tensor, ShareDataWith) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
{ {
...@@ -106,7 +106,7 @@ TEST(Tensor, ShareDataFrom) { ...@@ -106,7 +106,7 @@ TEST(Tensor, ShareDataFrom) {
// Try to share data form uninitialized tensor // Try to share data form uninitialized tensor
bool caught = false; bool caught = false;
try { try {
dst_tensor.ShareDataFrom<float>(src_tensor); dst_tensor.ShareDataWith<float>(src_tensor);
} catch (EnforceNotMet err) { } catch (EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
...@@ -119,7 +119,7 @@ TEST(Tensor, ShareDataFrom) { ...@@ -119,7 +119,7 @@ TEST(Tensor, ShareDataFrom) {
ASSERT_TRUE(caught); ASSERT_TRUE(caught);
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace()); src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
dst_tensor.ShareDataFrom<int>(src_tensor); dst_tensor.ShareDataWith<int>(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
...@@ -128,7 +128,7 @@ TEST(Tensor, ShareDataFrom) { ...@@ -128,7 +128,7 @@ TEST(Tensor, ShareDataFrom) {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace()); src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
dst_tensor.ShareDataFrom<int>(src_tensor); dst_tensor.ShareDataWith<int>(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
#endif #endif
......
...@@ -28,5 +28,15 @@ void Free(Place, void*); ...@@ -28,5 +28,15 @@ void Free(Place, void*);
template <class Place> template <class Place>
size_t Used(Place); size_t Used(Place);
template <typename T, typename PlaceType>
class PodDeleter {
public:
PodDeleter(PlaceType place) : place_(place) {}
void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
private:
PlaceType place_;
};
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,7 @@ protected: ...@@ -31,7 +31,7 @@ protected:
"Inputs/Outputs of AddOp must all be set"); "Inputs/Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of Add Op's dimension must be same."); "Two input of Add Op's dimension must be same.");
outputs[0]->set_dims(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
......
...@@ -33,7 +33,7 @@ protected: ...@@ -33,7 +33,7 @@ protected:
dim0[1] == dim1[0], dim0[1] == dim1[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output"); PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output");
outputs[0]->set_dims({dim0[0], dim1[1]}); outputs[0]->Resize({dim0[0], dim1[1]});
} }
}; };
......
...@@ -30,7 +30,7 @@ protected: ...@@ -30,7 +30,7 @@ protected:
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1"); PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1");
outputs[0]->set_dims(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
......
...@@ -24,7 +24,7 @@ protected: ...@@ -24,7 +24,7 @@ protected:
const std::vector<framework::Tensor *> &outputs) const override { const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
outputs[0]->set_dims(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
......
...@@ -25,7 +25,7 @@ protected: ...@@ -25,7 +25,7 @@ protected:
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax"); PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax");
outputs[0]->set_dims(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
......
...@@ -42,7 +42,7 @@ PYBIND11_PLUGIN(core) { ...@@ -42,7 +42,7 @@ PYBIND11_PLUGIN(core) {
[](const pd::Tensor& self) { return pd::vectorize(self.dims()); }) [](const pd::Tensor& self) { return pd::vectorize(self.dims()); })
.def("set_dims", .def("set_dims",
[](pd::Tensor& self, const std::vector<int>& dim) { [](pd::Tensor& self, const std::vector<int>& dim) {
self.set_dims(pd::make_ddim(dim)); self.Resize(pd::make_ddim(dim));
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor& self) { [](pd::Tensor& self) {
......
...@@ -86,7 +86,7 @@ void PyTensorSetFromArray( ...@@ -86,7 +86,7 @@ void PyTensorSetFromArray(
dims.push_back((int)array.shape()[i]); dims.push_back((int)array.shape()[i]);
} }
self.set_dims(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace()); auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
std::memcpy(dst, array.data(), sizeof(T) * array.size()); std::memcpy(dst, array.data(), sizeof(T) * array.size());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册