Commit 6129ab42 authored by Yi Wang, committed by GitHub

Merge pull request #2964 from Canpio/dev_refactor_tensor

Simplify Tensor implementation
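In short, this change renames Tensor::set_dims to Resize, CheckDims to EnforceSufficientMemory, and ShareDataFrom to ShareDataWith, restricts mutable_data to POD element types via std::enable_if, and replaces the nested Deleter with memory::PODDeleter. A minimal usage sketch of the renamed API (illustrative only; the header paths, the Example wrapper, and the concrete values are assumptions, not part of this diff):

  #include "paddle/framework/tensor.h"   // assumed header path
  #include "paddle/platform/place.h"     // assumed header path

  void Example() {
    using namespace paddle::framework;
    Tensor t;
    t.Resize(make_ddim({2, 3}));  // was t.set_dims(...)
    float* p = t.mutable_data<float>(paddle::platform::CPUPlace());  // POD types only
    p[0] = 1.0f;

    Tensor view;
    view.ShareDataWith<float>(t);  // was view.ShareDataFrom<float>(t)
    // view.data<float>() == t.data<float>(): both tensors share one Placeholder.
  }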
@@ -48,25 +48,27 @@ class Tensor {
   template <typename T>
   const T* data() const {
-    CheckDims<T>();
+    EnforceSufficientMemory<T>();
     return reinterpret_cast<const T*>(
         reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
   }

   template <typename T>
   T* data() {
-    CheckDims<T>();
+    EnforceSufficientMemory<T>();
     return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                                 offset_);
   }

-  template <typename T>
+  template <typename T,  // must be POD types
+            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
   T* mutable_data(DDim dims, platform::Place place) {
-    set_dims(dims);
+    Resize(dims);
     return mutable_data<T>(place);
   }

-  template <typename T>
+  template <typename T,  // must be POD types
+            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
   T* mutable_data(platform::Place place) {
     PADDLE_ENFORCE(product(dims_) > 0,
                    "Tensor's numel must be larger than zero to call "
@@ -95,11 +97,9 @@ class Tensor {
   }

   template <typename T>
-  void ShareDataFrom(const Tensor& src) {
-    src.CheckDims<T>();
-    holder_ = src.holder_;
-    set_dims(src.dims());
-    offset_ = src.offset_;
+  void ShareDataWith(const Tensor& src) {
+    src.EnforceSufficientMemory<T>();
+    *this = src;
   }

   template <typename T>
@@ -107,9 +107,9 @@ class Tensor {
     PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
                        platform::is_cpu_place(dst_place),
                    "Tensor::CopyFrom only support CPU now.");
-    src.CheckDims<T>();
+    src.EnforceSufficientMemory<T>();
     size_t size = product(src.dims_) * sizeof(T);
-    set_dims(src.dims());
+    Resize(src.dims());
     const void* src_ptr = static_cast<const void*>(src.data<T>());
     void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
     memcpy(dst_ptr, src_ptr, size);
@@ -117,34 +117,25 @@ class Tensor {
   template <typename T>
   Tensor Slice(const int& begin_idx, const int& end_idx) const {
-    CheckDims<T>();
-    PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
-                   "Slice index is less than zero or out of bound.");
+    EnforceSufficientMemory<T>();
+    PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
+    PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
     PADDLE_ENFORCE(begin_idx < end_idx,
                    "Begin index must be less than end index.");
     PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
-    std::vector<int> d = vectorize(dims_);
-    int base = 1;
-    for (size_t i = 1; i < d.size(); ++i) {
-      base *= d[i];
-    }
+    int base = product(dims_) / dims_[0];
     Tensor dst;
     dst.holder_ = holder_;
     DDim dst_dims = dims_;
     dst_dims[0] = end_idx - begin_idx;
-    dst.set_dims(dst_dims);
+    dst.Resize(dst_dims);
     dst.offset_ = offset_ + begin_idx * base * sizeof(T);
     return dst;
   }

-  void set_dims(const DDim& dims) {
-    if (dims == dims_) {
-      return;
-    }
-    dims_ = dims;
-  }
+  void Resize(const DDim& dims) { dims_ = dims; }

-  DDim dims() const { return dims_; }
+  const DDim& dims() const { return dims_; }

  private:
   // Placeholder hides type T, so it doesn't appear as a template
@@ -159,21 +150,9 @@ class Tensor {
   template <typename T, typename PlaceType>
   struct PlaceholderImpl : public Placeholder {
-   private:
-    template <typename PType>
-    class Deleter {
-     public:
-      Deleter(PType place) : place_(place) {}
-      void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); }
-
-     private:
-      PType place_;
-    };
-
-   public:
     PlaceholderImpl(PlaceType place, size_t size)
         : ptr_(static_cast<T*>(memory::Alloc(place, size)),
-               Deleter<PlaceType>(place)),
+               memory::PODDeleter<T, PlaceType>(place)),
           place_(place),
           size_(size) {}
@@ -182,13 +161,13 @@ class Tensor {
     virtual paddle::platform::Place place() const { return place_; }
     virtual std::type_index type() const { return std::type_index(typeid(T)); }

-    std::unique_ptr<T, Deleter<PlaceType>> ptr_;
+    std::unique_ptr<T, memory::PODDeleter<T, PlaceType>> ptr_;
     platform::Place place_;  // record the place of ptr_.
     size_t size_;            // size of the memory block.
   };

   template <typename T>
-  inline void CheckDims() const {
+  inline void EnforceSufficientMemory() const {
     PADDLE_ENFORCE(holder_ != nullptr,
                    "Tenosr holds no memory. Call Tensor::mutable_data first.");
     PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
@@ -198,7 +177,11 @@ class Tensor {
   std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
   DDim dims_;
-  size_t offset_;  // marks the begin of tensor data area.
+  // A PlaceHolder may be shared by more than one tensor. Some of them may be
+  // slices of the others. So the offset_ is introduced here to indicate the
+  // byte offset between PlaceHolder::ptr_ and where tensor's data really
+  // begins.
+  size_t offset_;
 };

 }  // namespace framework
......
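The new comment on offset_ documents the sharing semantics: one Placeholder may back several tensors, and Slice only adjusts dims_ and offset_. A small sketch of that behaviour (shapes and indices are illustrative, not taken from the diff):

  Tensor big;
  big.Resize(make_ddim({4, 3}));
  big.mutable_data<float>(paddle::platform::CPUPlace());

  // Rows [1, 3): base = product({4, 3}) / 4 = 3, so offset_ becomes
  // 1 * 3 * sizeof(float) bytes past big's data; holder_ is shared.
  Tensor rows = big.Slice<float>(1, 3);
  // rows.dims() == {2, 3} and rows.data<float>() == big.data<float>() + 3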
@@ -19,7 +19,7 @@ TEST(Tensor, Dims) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   Tensor tt;
-  tt.set_dims(make_ddim({2, 3, 4}));
+  tt.Resize(make_ddim({2, 3, 4}));
   DDim dims = tt.dims();
   ASSERT_EQ(arity(dims), 3);
   for (int i = 0; i < 3; ++i) {
@@ -97,7 +97,7 @@ TEST(Tensor, MutableData) {
 #endif
 }

-TEST(Tensor, ShareDataFrom) {
+TEST(Tensor, ShareDataWith) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   {
@@ -106,7 +106,7 @@ TEST(Tensor, ShareDataFrom) {
     // Try to share data form uninitialized tensor
     bool caught = false;
     try {
-      dst_tensor.ShareDataFrom<float>(src_tensor);
+      dst_tensor.ShareDataWith<float>(src_tensor);
     } catch (std::runtime_error& err) {
       caught = true;
       std::string msg =
@@ -119,7 +119,7 @@ TEST(Tensor, ShareDataFrom) {
     ASSERT_TRUE(caught);

     src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
-    dst_tensor.ShareDataFrom<int>(src_tensor);
+    dst_tensor.ShareDataWith<int>(src_tensor);
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
   }
@@ -128,7 +128,7 @@ TEST(Tensor, ShareDataFrom) {
     Tensor src_tensor;
     Tensor dst_tensor;
     src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
-    dst_tensor.ShareDataFrom<int>(src_tensor);
+    dst_tensor.ShareDataWith<int>(src_tensor);
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
   }
 #endif
......
@@ -28,5 +28,17 @@ void Free(Place, void*);
 template <class Place>
 size_t Used(Place);

+template <typename T, /* must be POD types */
+          typename Place /* platform::GPUPlace or platform::CPUPlace */,
+          typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
+class PODDeleter {
+ public:
+  PODDeleter(Place place) : place_(place) {}
+  void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
+
+ private:
+  Place place_;
+};
+
 }  // namespace memory
 }  // namespace paddle
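The new memory::PODDeleter is what PlaceholderImpl now passes to its unique_ptr (see the tensor.h hunk above). A minimal sketch of pairing memory::Alloc with PODDeleter (the wrapper function and buffer size are illustrative assumptions):

  #include <memory>

  void PODDeleterSketch() {
    paddle::platform::CPUPlace cpu;
    using Deleter = paddle::memory::PODDeleter<float, paddle::platform::CPUPlace>;
    std::unique_ptr<float, Deleter> buf(
        static_cast<float*>(paddle::memory::Alloc(cpu, 1024 * sizeof(float))),
        Deleter(cpu));
    // When buf is destroyed, Deleter::operator() calls memory::Free(cpu, ptr).
  }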
@@ -31,7 +31,7 @@ protected:
                    "Inputs/Outputs of AddOp must all be set");
     PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                    "Two input of Add Op's dimension must be same.");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
......
@@ -35,7 +35,7 @@ protected:
     PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
     PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
                    "label's dimension must be 1.");
-    outputs[0]->set_dims(framework::make_ddim({inputs[0]->dims()[0]}));
+    outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
   }
 };
......
@@ -33,7 +33,7 @@ protected:
         dim0[1] == dim1[0],
         "First matrix's width must be equal with second matrix's height.");
     PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output");
-    outputs[0]->set_dims({dim0[0], dim1[1]});
+    outputs[0]->Resize({dim0[0], dim1[1]});
   }
 };
......
@@ -30,7 +30,7 @@ protected:
     PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
     PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
     PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
......
@@ -31,7 +31,7 @@ protected:
     PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set");
     PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                    "Two input of SGD Op's dimension must be same.");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
......
@@ -24,7 +24,7 @@ protected:
       const std::vector<framework::Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
     PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
......
@@ -27,7 +27,7 @@ protected:
                    "The input of softmax op must be matrix");
     PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
......
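All the operator hunks above make the same mechanical substitution. After this change an InferShape body follows the pattern sketched below (sketch only; the signature is reconstructed from the sigmoid hunk and the inputs type is assumed):

  void InferShape(const std::vector<const framework::Tensor *> &inputs,
                  const std::vector<framework::Tensor *> &outputs) const override {
    PADDLE_ENFORCE(inputs.size() == 1, "this op takes exactly one input");
    PADDLE_ENFORCE(outputs.size() == 1, "this op produces exactly one output");
    outputs[0]->Resize(inputs[0]->dims());  // was outputs[0]->set_dims(...)
  }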
@@ -46,7 +46,7 @@ PYBIND11_PLUGIN(core) {
            [](const pd::Tensor& self) { return pd::vectorize(self.dims()); })
       .def("set_dims",
            [](pd::Tensor& self, const std::vector<int>& dim) {
-             self.set_dims(pd::make_ddim(dim));
+             self.Resize(pd::make_ddim(dim));
            })
       .def("alloc_float",
            [](pd::Tensor& self) {
......
@@ -86,7 +86,7 @@ void PyTensorSetFromArray(
     dims.push_back((int)array.shape()[i]);
   }

-  self.set_dims(framework::make_ddim(dims));
+  self.Resize(framework::make_ddim(dims));
   auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }
......