提交 f86d35a2 编写于 作者: T typhoonzero

add sharable tensor

上级 494c262a
...@@ -98,6 +98,9 @@ class Tensor { ...@@ -98,6 +98,9 @@ class Tensor {
/*! The internal of two tensors share the same memory block. */ /*! The internal of two tensors share the same memory block. */
inline Tensor& ShareDataWith(const Tensor& src); inline Tensor& ShareDataWith(const Tensor& src);
/*! Share part of the memory of the two tensors */
inline Tensor& ShareDataWith(Tensor* src, size_t offset);
/** /**
* @brief Return a sub-tensor of the given tensor. * @brief Return a sub-tensor of the given tensor.
* *
...@@ -176,6 +179,32 @@ class Tensor { ...@@ -176,6 +179,32 @@ class Tensor {
std::type_index type_; std::type_index type_;
}; };
template <typename Place>
struct SharedPlaceholderImpl : public Placeholder {
SharedPlaceholderImpl(Place place, uint8_t* data, size_t size,
std::type_index type)
: ptr_(data), place_(place), size_(size), type_(type) {}
virtual size_t size() const { return size_; }
virtual platform::Place place() const { return place_; }
virtual void* ptr() const { return static_cast<void*>(ptr_); }
virtual std::type_index type() const { return type_; }
virtual void set_type(std::type_index type) { type_ = type; }
virtual void set_place(platform::Place place) { place_ = place; }
/*! the pointer of memory block. */
uint8_t* ptr_;
/*! the place of memory block. */
platform::Place place_;
/*! the size of memory block. */
size_t size_;
/* the current type of memory */
std::type_index type_;
};
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
std::shared_ptr<Placeholder> holder_; std::shared_ptr<Placeholder> holder_;
......
...@@ -162,6 +162,37 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) { ...@@ -162,6 +162,37 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
return *this; return *this;
} }
inline Tensor& Tensor::ShareDataWith(Tensor* src, size_t offset) {
// NOTE: data size is determined by current tensor shape and data type
src->check_memory_size();
PADDLE_ENFORCE_EQ(src->type(), this->type(),
"tensor data type must be the same when sharing data");
auto place = src->place();
auto type = src->type();
size_t size = src->numel() * SizeOfType(src->type());
auto* ref = static_cast<uint8_t*>(src->mutable_data(place)) + offset;
if (platform::is_cpu_place(place)) {
holder_.reset(new SharedPlaceholderImpl<platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), ref, size, type));
} else if (platform::is_gpu_place(place) ||
platform::is_cuda_pinned_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW(
"CUDAPlace or CUDAPinnedPlace is not supported in CPU-only mode.");
}
#else
if (platform::is_gpu_place(place)) {
holder_.reset(new SharedPlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), ref, size, type));
} else if (platform::is_cuda_pinned_place(place)) {
holder_.reset(new SharedPlaceholderImpl<platform::CUDAPinnedPlace>(
boost::get<platform::CUDAPinnedPlace>(place), ref, size, type));
}
}
#endif
return *this;
}
inline Tensor Tensor::Slice(int begin_idx, int end_idx) const { inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size(); check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0, PADDLE_ENFORCE_GE(begin_idx, 0,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册