Unverified commit 5c73a6ea authored by Zhanlue Yang, committed by GitHub

[Unify Tensors PR #5] framework::Tensor inherits from DenseTensor,test=allcases (#38632)

* Added shared_ptr<Allocation> member & corresponding interfaces to Storage

* Removed original pten::Allocation from Storage and adjusted the interfaces accordingly

* Fixed issues with storage offset

* Used place to malloc allocation for TensorStorage

* [Unify Tensors PR #3] Ported framework::Tensor interfaces to pten::DenseTensor

* Fixed issues with place

* Added comments

* Moved mutable_data with stream argument to DenseTensor

* Added set_offset interface

* Fixed CI issues,test=allcases

* [Unify Tensors PR #4] Port LoDTensor interfaces to DenseTensor

* Removed friend class EigenTensor/EigenMatrix/EigenVector from Tensor

* Modified framework::Tensor to inherit from DenseTensor

* Reverted changes to pten_layout() interface

* Removed friend classes

* Rearranged function calls from tensor.data<void>() to tensor.data()

* Fixed CI issues

* Fixed lite issues

* Fixed data() interface issues,test=allcases

* Resolved IsInitialized() issues

* Fixed ResetHolder() issues

* Fixed MKLDNN & Storage issues

* Resolved ShareBufferWith() issues

* Fixed LoD issues
Parent 046553c7
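
The bulk of the diff below is a mechanical call-site migration: with framework::Tensor now deriving from pten::DenseTensor, untyped access goes through the non-template data()/data() const overloads, while typed access stays data<T>(). A minimal sketch of the before/after pattern, using a hypothetical helper and tensor (illustration only, not code from this patch):

#include <cassert>
#include "paddle/fluid/framework/tensor.h"  // framework::Tensor, now derived from pten::DenseTensor

// Hypothetical helper showing the accessor spellings side by side.
void data_accessor_sketch() {
  paddle::framework::Tensor t;
  t.Resize(paddle::framework::make_ddim({8}));
  t.mutable_data<float>(paddle::platform::CPUPlace());

  // Before this PR, untyped access was spelled through the template:
  //   const void* raw = t.data<void>();
  // After it, call sites use the untyped overload inherited from pten::DenseTensor:
  const void* raw = t.data();
  const float* typed = t.data<float>();  // typed accessor is unchanged
  assert(raw == static_cast<const void*>(typed));
}
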
......@@ -103,19 +103,17 @@ void SerializeLodTensor(framework::Variable* var,
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(tensor->data<void>()),
data_len);
iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
} else {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
memory::Copy(
platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(),
tensor->numel() * framework::SizeOfType(tensor->type()), stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
......@@ -147,19 +145,17 @@ void SerializeSelectedRows(framework::Variable* var,
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(tensor->data<void>()),
data_len);
iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
} else {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
memory::Copy(
platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(),
tensor->numel() * framework::SizeOfType(tensor->type()), stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
......
......@@ -34,7 +34,7 @@ int GetMicroId(const platform::DeviceContext& ctx,
auto micro_id = -1;
auto* tensor = var->GetMutable<framework::LoDTensor>();
if (platform::is_cpu_place(tensor->place())) {
auto data = reinterpret_cast<const float*>(tensor->data<void>());
auto data = reinterpret_cast<const float*>(tensor->data());
micro_id = static_cast<int>(data[0]);
} else {
#ifdef PADDLE_WITH_CUDA
......@@ -43,11 +43,10 @@ int GetMicroId(const platform::DeviceContext& ctx,
char* temp_ptr = temp.data();
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
memory::Copy(
platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(),
tensor->numel() * framework::SizeOfType(tensor->type()), stream);
float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
......
......@@ -240,7 +240,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
platform::errors::InvalidArgument(
"Not find variable microbatch_id in scope."));
auto* tensor = var->GetMutable<framework::LoDTensor>();
auto data = reinterpret_cast<const float*>(tensor->data<void>());
auto data = reinterpret_cast<const float*>(tensor->data());
auto micro_id = static_cast<int>(data[0]);
int minibatch_index = micro_id / 10;
......
......@@ -91,7 +91,7 @@ endif()
cc_test(copy_same_tensor_test SRCS copy_same_tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(mixed_vector SRCS mixed_vector.cc DEPS device_context)
cc_library(mixed_vector SRCS mixed_vector.cc DEPS device_context place memory)
if(WITH_GPU)
nv_test(mixed_vector_test SRCS mixed_vector_test.cc mixed_vector_test.cu DEPS mixed_vector place memory device_context tensor)
......
......@@ -77,8 +77,8 @@ static bool CopySameTensorTestMain(const DDim &dims,
TensorCopySync(src_tensor, platform::CPUPlace(), &dst_cpu_tensor);
}
const void *ground_truth_ptr = src_cpu_tensor.data<void>();
const void *result_ptr = dst_cpu_tensor.data<void>();
const void *ground_truth_ptr = src_cpu_tensor.data();
const void *result_ptr = dst_cpu_tensor.data();
size_t byte_num = product(dims) * sizeof(T);
return std::memcmp(ground_truth_ptr, result_ptr, byte_num) == 0;
}
......
......@@ -45,7 +45,6 @@ void TransformData(const OpKernelType &expected_kernel_type,
Tensor out;
const DataLayout lin = kernel_type_for_var.data_layout_;
const DataLayout lout = expected_kernel_type.data_layout_;
// do layout transform
if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -153,7 +153,7 @@ void AllReduceOpHandle::AllReduceImpl(
"The place type of tensors of the same variable "
"in different local scopes should be equal."));
lod_tensor_data.emplace_back(lod_tensor.data<void>());
lod_tensor_data.emplace_back(lod_tensor.data());
places.emplace_back(lod_tensor.place());
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
......@@ -225,7 +225,7 @@ void AllReduceOpHandle::AllReduceFunc(
->GetMutable<LoDTensor>();
// Reduce All Tensor to trg in CPU
ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
ReduceBufferData func(lod_tensor_data, trg.data(), numel);
VisitDataType(trg.type(), func);
for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
......@@ -235,9 +235,9 @@ void AllReduceOpHandle::AllReduceFunc(
size_t size = numel * SizeOfType(trg.type());
RunAndRecordEvent(p, [&trg, var, p, size] {
auto dst_ptr = var->GetMutable<framework::LoDTensor>()->data<void>();
auto dst_ptr = var->GetMutable<framework::LoDTensor>()->data();
platform::CPUPlace cpu_place;
memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data<void>(), size);
memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data(), size);
});
}
}
......
......@@ -101,7 +101,7 @@ void BroadcastOpHandle::BroadcastOneVar(
void *send_recv_buffer = nullptr;
if (root_id == dst_id) {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
send_recv_buffer = const_cast<void *>(in_tensor.data());
out_handle = out_var_handle;
} else {
send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
......@@ -162,7 +162,7 @@ void BroadcastOpHandle::BroadcastOneVar(
void *send_recv_buffer = nullptr;
if (root_id == dst_id) {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
send_recv_buffer = const_cast<void *>(in_tensor.data());
out_handle = out_var_handle;
} else {
send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
......
......@@ -220,17 +220,17 @@ void FusedAllReduceOpHandle::FusedAllReduceFunc(
g_tensor.begin(), g_tensor.end(),
[](const std::pair<std::string, const LoDTensor *> &grad1,
const std::pair<std::string, const LoDTensor *> &grad2) -> bool {
return grad1.second->data<void>() < grad2.second->data<void>();
return grad1.second->data() < grad2.second->data();
});
size_t size_of_dtype = framework::SizeOfType(dtype);
for (size_t k = 1; k < g_tensor.size(); ++k) {
const void *cur_address = g_tensor.at(k - 1).second->data<void>();
const void *cur_address = g_tensor.at(k - 1).second->data();
int64_t len = g_tensor.at(k - 1).second->numel();
auto offset = platform::Alignment(len * size_of_dtype, places_[0]);
void *infer_next_address = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(cur_address) + offset);
const void *next_address = g_tensor.at(k).second->data<void>();
const void *next_address = g_tensor.at(k).second->data();
VLOG(10) << string::Sprintf(
"Input[%d](%s) address: 0X%02x, Input[%d](%s) address: 0X%02x, Infer "
......@@ -267,7 +267,7 @@ void FusedAllReduceOpHandle::FusedAllReduceFunc(
std::vector<const void *> lod_tensor_data;
lod_tensor_data.reserve(place_num);
for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
auto data = grads_tensor.at(scope_idx).at(0).second->data<void>();
auto data = grads_tensor.at(scope_idx).at(0).second->data();
lod_tensor_data.emplace_back(data);
}
std::vector<std::string> grad_var_names;
......
......@@ -159,7 +159,7 @@ void ReduceOpHandle::RunImpl() {
VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) {
if (reduce_sum_trg.data() != trg->data()) {
TensorCopy(reduce_sum_trg, platform::CPUPlace(), trg);
}
}
......@@ -181,7 +181,7 @@ void ReduceOpHandle::RunImpl() {
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
void *buffer = const_cast<void *>(lod_tensor.data<void>());
void *buffer = const_cast<void *>(lod_tensor.data());
void *recvbuffer = nullptr;
if (root_id == dev_id) {
recvbuffer =
......@@ -227,7 +227,7 @@ void ReduceOpHandle::RunImpl() {
int dev_id = BOOST_GET_CONST(platform::XPUPlace, p).device;
auto &bkcl_ctx = bkcl_ctxs_->at(dev_id);
void *buffer = const_cast<void *>(lod_tensor.data<void>());
void *buffer = const_cast<void *>(lod_tensor.data());
void *recvbuffer = nullptr;
if (root_id == dev_id) {
recvbuffer =
......
......@@ -146,7 +146,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i];
auto &in = *ins[i];
void *in_tensor_buf = const_cast<void *>(in.data<void>());
void *in_tensor_buf = const_cast<void *>(in.data());
auto &out = *outs[i];
float *out_tensor_buf = out.data<float>();
......@@ -175,7 +175,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
void *gather_buff = gathers[i]->data<void>();
void *gather_buff = gathers[i]->data();
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
......
......@@ -134,7 +134,7 @@ struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
// init data, data buffer
t_.data = const_cast<void *>(tensor.data<void>());
t_.data = const_cast<void *>(tensor.data());
// init device, DLDevice type with device_type and device_id
auto place = tensor.place();
......
......@@ -150,8 +150,7 @@ class AscendInstance {
VarTypeToGeType(tensor->type()));
tensor_desc.SetRealDimCnt(vec_dim.size());
const uint8_t *data =
reinterpret_cast<const uint8_t *>(tensor->data<void>());
const uint8_t *data = reinterpret_cast<const uint8_t *>(tensor->data());
std::vector<uint8_t> dst(numel * GeTypeSize(tensor->type()));
memcpy(dst.data(), data, GeTypeSize(tensor->type()) * numel);
ge::Tensor ge_tensor(tensor_desc, dst);
......
......@@ -112,20 +112,19 @@ void HeterWrapper::SerializeToReq(const std::string& varname, Scope* scope,
char* data_ptr = const_cast<char*>(req_data->data());
if (platform::is_cpu_place(tensor->place())) {
memcpy(data_ptr, tensor->data<void>(),
memcpy(data_ptr, tensor->data(),
tensor->numel() * SizeOfType(tensor->type()));
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
memory::Copy(platform::CPUPlace(), data_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * SizeOfType(tensor->type()), nullptr);
tensor->data(), tensor->numel() * SizeOfType(tensor->type()),
nullptr);
#endif
#ifdef PADDLE_WITH_XPU
memory::Copy(platform::CPUPlace(), data_ptr,
BOOST_GET_CONST(platform::XPUPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * SizeOfType(tensor->type()));
tensor->data(), tensor->numel() * SizeOfType(tensor->type()));
#endif
}
}
......
......@@ -339,7 +339,7 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request,
auto dev_id =
BOOST_GET_CONST(platform::CUDAPlace, thread_tensor->place()).device;
platform::CUDADeviceGuard guard(dev_id);
cudaMemset(thread_tensor->data<void>(), 0,
cudaMemset(thread_tensor->data(), 0,
thread_tensor->numel() * SizeOfType(thread_tensor->type()));
#endif
#ifdef PADDLE_WITH_XPU
......@@ -351,11 +351,11 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request,
platform::DeviceContext* dev_ctx = pool.Get(place);
const platform::XPUDeviceContext* xpu_ctx =
reinterpret_cast<const platform::XPUDeviceContext*>(dev_ctx);
xpu::memset(xpu_ctx->x_context(), thread_tensor->data<void>(), 0,
xpu::memset(xpu_ctx->x_context(), thread_tensor->data(), 0,
thread_tensor->numel() * SizeOfType(thread_tensor->type()));
#endif
} else {
memset(thread_tensor->data<void>(), 0,
memset(thread_tensor->data(), 0,
thread_tensor->numel() * SizeOfType(thread_tensor->type()));
}
}
......@@ -367,7 +367,7 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request,
auto dev_id =
BOOST_GET_CONST(platform::CUDAPlace, root_tensor->place()).device;
platform::CUDADeviceGuard guard(dev_id);
cudaMemset(root_tensor->data<void>(), 0,
cudaMemset(root_tensor->data(), 0,
root_tensor->numel() * SizeOfType(root_tensor->type()));
#endif
#ifdef PADDLE_WITH_XPU
......@@ -379,11 +379,11 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request,
platform::DeviceContext* dev_ctx = pool.Get(place);
const platform::XPUDeviceContext* xpu_ctx =
reinterpret_cast<const platform::XPUDeviceContext*>(dev_ctx);
xpu::memset(xpu_ctx->x_context(), root_tensor->data<void>(), 0,
xpu::memset(xpu_ctx->x_context(), root_tensor->data(), 0,
root_tensor->numel() * SizeOfType(root_tensor->type()));
#endif
} else {
memset(root_tensor->data<void>(), 0,
memset(root_tensor->data(), 0,
root_tensor->numel() * SizeOfType(root_tensor->type()));
}
}
......
......@@ -144,8 +144,8 @@ class LoDTensor : public Tensor {
*/
size_t NumLevels() const { return lod_.size(); }
/*
* Number of elements in a level.
*/
* Number of elements in a level.
*/
size_t NumElements(size_t level = 0) const {
PADDLE_ENFORCE_LT(
level, NumLevels(),
......
......@@ -71,7 +71,6 @@ ProgramDesc load_from_file(const std::string& file_name) {
fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size());
fin.close();
ProgramDesc program_desc(buffer);
return program_desc;
}
......
......@@ -788,7 +788,7 @@ void ParallelExecutor::BCastParamsToDevices(
void *buffer;
if (i == 0 && trainer_id == 0) {
buffer = const_cast<void *>(main_tensor.data<void>());
buffer = const_cast<void *>(main_tensor.data());
} else {
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
......@@ -831,7 +831,7 @@ void ParallelExecutor::BCastParamsToDevices(
void *buffer;
if (i == 0 && trainer_id == 0) {
buffer = const_cast<void *>(main_tensor.data<void>());
buffer = const_cast<void *>(main_tensor.data());
} else {
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
......
......@@ -101,20 +101,25 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE_EQ(desc_.ParseFromString(binary_str), true,
platform::errors::InvalidArgument(
"Failed to parse program_desc from binary string."));
VLOG(1) << 3333;
InitFromProto();
}
void ProgramDesc::InitFromProto() {
VLOG(1) << 4444;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
VLOG(1) << 5555;
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
VLOG(1) << 6666;
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
} else if (attr.type() == proto::AttrType::BLOCKS) {
VLOG(1) << 7777;
auto blks_idx = attr.blocks_idx();
std::vector<BlockDesc *> block_descs;
for (int blk_idx : blks_idx) {
......
......@@ -295,12 +295,12 @@ bool SaveTensorToDisk(const std::string& file_name,
// save tensor
uint64_t data_size =
tensor->numel() * framework::SizeOfType(tensor->type());
auto* data_ptr = tensor->data<void>();
auto* data_ptr = tensor->data();
if (platform::is_gpu_place(tensor->place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
data_ptr = temp.data<void>();
data_ptr = temp.data();
#else
PADDLE_THROW(platform::errors::Unavailable(
"Tensor is in CUDA device, but paddle not compiled with CUDA."));
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/pten/api/lib/utils/storage.h"
DECLARE_bool(use_stream_safe_cuda_allocator);
......@@ -26,148 +27,55 @@ class Allocation;
namespace paddle {
namespace framework {
extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(holder_, platform::errors::PreconditionNotMet(
"Tensor holds no memory. "
"Call Tensor::mutable_data firstly."));
size_t size = numel() * SizeOfType(type());
PADDLE_ENFORCE_LE(
size, memory_size(),
platform::errors::PreconditionNotMet(
"Tensor's dimension is out of bound."
"Tensor's dimension must be equal or less than the size of its "
"memory."
"But received Tensor's dimension is d%, memory's size is %d.",
size, memory_size()));
}
Tensor::Tensor(const proto::VarType::Type& dtype)
: type_(dtype),
offset_(0),
inplace_version_counter_(std::make_shared<TensorInplaceVersion>(0)) {}
size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
void* Tensor::mutable_data(const platform::Place& place,
proto::VarType::Type type, size_t requested_size) {
type_ = type;
PADDLE_ENFORCE_GE(
numel(), 0,
platform::errors::PreconditionNotMet(
"The Tensor's element number must be equal or greater than zero. "
"The Tensor's shape is [",
dims(), "] now"));
size_t size = numel() * SizeOfType(type);
if (requested_size && (requested_size > size)) {
size = requested_size;
}
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
// Reset holder first before re-allocate to save memory
holder_.reset();
holder_ = memory::AllocShared(place, size);
offset_ = 0;
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
void* Tensor::mutable_data(const platform::Place& place,
size_t requested_size) {
PADDLE_ENFORCE_NOT_NULL(this->holder_, platform::errors::PreconditionNotMet(
"The tensor is not initialized."));
return mutable_data(place, type_, requested_size);
}
void* Tensor::mutable_data(const platform::Place& place,
proto::VarType::Type type,
const platform::Stream& stream) {
type_ = type;
PADDLE_ENFORCE_GE(
numel(), 0,
platform::errors::PreconditionNotMet(
"The Tensor's element number must be equal or greater than zero. "
"The Tensor's shape is [",
dims(), "] now"));
size_t size = numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_ ||
!(platform::is_gpu_place(place) &&
memory::InSameStream(holder_, stream))) {
holder_.reset();
holder_ = memory::AllocShared(place, size, stream);
offset_ = 0;
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
Tensor& Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size();
*this = src;
return *this;
}
Tensor& Tensor::ShareInplaceVersionCounterWith(const Tensor& src) {
PADDLE_ENFORCE_NOT_NULL(
inplace_version_counter_,
platform::errors::PreconditionNotMet(
"Tensor does not hold inplace_version_counter_."));
inplace_version_counter_ = src.inplace_version_counter_;
return *this;
}
Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(
begin_idx, 0,
platform::errors::OutOfRange("The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(
end_idx, dims_[0],
platform::errors::OutOfRange("The end row index is out of bound."));
PADDLE_ENFORCE_GE(begin_idx, 0,
paddle::platform::errors::OutOfRange(
"The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(end_idx, meta_.dims[0],
paddle::platform::errors::OutOfRange(
"The end row index is out of bound."));
PADDLE_ENFORCE_LT(
begin_idx, end_idx,
platform::errors::InvalidArgument(
paddle::platform::errors::InvalidArgument(
"The start row index must be less than the end row index."
"But received the start index = %d, the end index = %d.",
begin_idx, end_idx));
if (dims_[0] == 1) {
if (meta_.dims[0] == 1) {
return *this;
} else {
size_t base = numel() / dims_[0];
size_t base = numel() / meta_.dims[0];
Tensor dst;
dst.holder_ = holder_;
dst.set_layout(layout_);
dst.type_ = type_;
DDim dst_dims = dims_;
dst.storage_ = pten::make_intrusive<paddle::experimental::SharedStorage>(
storage_->data_shared());
dst.meta_.layout = meta_.layout;
dst.meta_.dtype = meta_.dtype;
DDim dst_dims = meta_.dims;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * SizeOfType(type());
dst.meta_.offset = meta_.offset + begin_idx * base * SizeOf(dtype());
return dst;
}
}
std::vector<Tensor> Tensor::Split(int64_t split_size, int64_t axis) const {
check_memory_size();
PADDLE_ENFORCE_GE(dims_.size(), 0,
platform::errors::OutOfRange(
PADDLE_ENFORCE_GE(meta_.dims.size(), 0,
paddle::platform::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
split_size, 0,
platform::errors::OutOfRange(
paddle::platform::errors::OutOfRange(
"split expects split_size be non-negative, but got split_size is %d",
split_size));
int64_t numel_size = dims_[axis];
int64_t numel_size = meta_.dims[axis];
int64_t num_splits = 1;
if (split_size != 0) {
......@@ -187,49 +95,33 @@ std::vector<Tensor> Tensor::Split(int64_t split_size, int64_t axis) const {
std::vector<Tensor> Tensor::Chunk(int64_t chunks, int64_t axis) const {
check_memory_size();
PADDLE_ENFORCE_GE(dims_.size(), 0,
platform::errors::OutOfRange(
PADDLE_ENFORCE_GE(meta_.dims.size(), 0,
paddle::platform::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
chunks, 0,
platform::errors::OutOfRange(
paddle::platform::errors::OutOfRange(
"chunks expects to be greater than 0, but got chunks is %d", chunks));
int64_t numel_size = dims_[axis];
int64_t numel_size = meta_.dims[axis];
int64_t split_size = (numel_size + chunks - 1) / chunks;
return Split(split_size, axis);
}
Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
Tensor& Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size();
*this = src;
return *this;
}
Tensor& Tensor::ShareInplaceVersionCounterWith(const Tensor& src) {
PADDLE_ENFORCE_NOT_NULL(
inplace_version_counter_,
platform::errors::PreconditionNotMet(
"Tensor does not hold inplace_version_counter_."));
const DDim& Tensor::dims() const { return dims_; }
int64_t Tensor::numel() const { return product(dims_); }
void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
PADDLE_ENFORCE_EQ(
offset_, 0,
platform::errors::Fatal(
"Only the offset is supported to zero when the holder is reset."));
if (holder_) {
PADDLE_ENFORCE_LE(
numel() * SizeOfType(type()) + offset_, holder->size(),
paddle::platform::errors::InvalidArgument(
"The size of Holder is not enough to store the Tensor."));
}
holder_ = holder;
}
void Tensor::ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
const proto::VarType::Type& type) {
type_ = type;
ResetHolder(holder);
inplace_version_counter_ = src.inplace_version_counter_;
return *this;
}
void Tensor::set_type(const proto::VarType::Type& type) { type_ = type; }
} // namespace framework
} // namespace paddle
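
The rewritten Slice() above now builds the sub-tensor on top of the inherited DenseTensor state: it wraps the parent's allocation in a SharedStorage and only advances meta_.offset, so no element data is copied. A hedged usage sketch (hypothetical names and shapes, not taken from the patch):

#include <cassert>
#include "paddle/fluid/framework/tensor.h"

// Hypothetical example: slicing rows [1, 3) of a 4x3 float tensor.
void slice_sketch() {
  paddle::framework::Tensor t;
  t.Resize(paddle::framework::make_ddim({4, 3}));
  float* base = t.mutable_data<float>(paddle::platform::CPUPlace());

  paddle::framework::Tensor rows = t.Slice(1, 3);  // rows 1 and 2 share t's storage
  assert(rows.dims()[0] == 2);
  // The slice starts one row (3 floats) into the shared allocation.
  assert(rows.data<float>() == base + 3);
}
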
......@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/stream.h"
#include "paddle/pten/core/dense_tensor.h"
namespace paddle {
namespace memory {
namespace allocation {
......@@ -75,98 +77,10 @@ class LoDTensor;
Variable object but not a pointer.
*/
class TensorInplaceVersion {
class Tensor : public pten::DenseTensor {
public:
explicit TensorInplaceVersion(uint32_t inplace_version = 0)
: inplace_version_(inplace_version) {}
bool IsUnique() const { return inplace_version_ == 0; }
void Bump() { ++inplace_version_; }
uint32_t CurrentVersion() const { return inplace_version_; }
void SetInplaceVersionToZero() { inplace_version_ = 0; }
private:
uint32_t inplace_version_;
};
class Tensor {
#ifdef PADDLE_WITH_MKLDNN
public:
inline dnnl::memory::format_tag format() const { return format_; }
inline void set_format(const dnnl::memory::format_tag format) {
format_ = format;
}
protected:
/**
* @brief the detail format of memory block which have layout as kMKLDNN
*
* @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
* nChw16c, etc. For a MKLDNN memory block, layout will be set as
* DataLayout::kMKLDNN meanwhile detail memory format will be kept in
* this field.
*/
dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
#endif
public:
Tensor()
: type_(proto::VarType::FP32),
offset_(0),
inplace_version_counter_(std::make_shared<TensorInplaceVersion>(0)) {}
explicit Tensor(const proto::VarType::Type&);
/*! Return a pointer to mutable memory block. */
const void* data() const;
template <typename T>
T* data();
/*! Return a pointer to constant memory block. */
template <typename T>
const T* data() const;
inline bool IsInitialized() const;
/**
* @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation.
*/
template <typename T>
T* mutable_data(const platform::Place& place, size_t requested_size = 0);
void* mutable_data(const platform::Place& place, proto::VarType::Type type,
size_t requested_size = 0);
void* mutable_data(const platform::Place& place, size_t requested_size = 0);
void* mutable_data(const platform::Place& place, proto::VarType::Type type,
const platform::Stream& stream);
/**
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
* @param[in] requested_size The size of the block in bytes.
*
* @note If not exist, then allocation.
*/
template <typename T>
T* mutable_data(const DDim& dims, const platform::Place& place,
size_t requested_size = 0);
/*! Return the dimensions of the memory block. */
const DDim& dims() const;
/*! Return the numel of the memory block. */
int64_t numel() const;
/*! Resize the dimensions of the memory block. */
Tensor& Resize(const DDim& dims);
using DenseTensor = pten::DenseTensor;
using DenseTensor::DenseTensor;
/*! The internal of two tensors share the same memory block. */
Tensor& ShareDataWith(const Tensor& src);
......@@ -174,150 +88,16 @@ class Tensor {
/*! The internal of two tensors share the same inplace version counter. */
Tensor& ShareInplaceVersionCounterWith(const Tensor& src);
/**
* @brief Return a sub-tensor of the given tensor.
*
* @param[in] begin_idx The index of the start row(inclusive) to slice.
* The index number begins from 0.
* @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0.
*/
Tensor Slice(int64_t begin_idx, int64_t end_idx) const;
/**
* @brief Return a tensor list of the given tensor.
*
* @param[in] split_size The size of tensor to be split along axis.
* @param[in] axis The axis along which to split.
*/
std::vector<Tensor> Split(int64_t split_size, int64_t axis) const;
/**
* @brief Return a tensor list of the given tensor.
*
* @param[in] chunks The number of tensor to be split along axis.
* @param[in] axis The axis along which to split.
*/
std::vector<Tensor> Chunk(int64_t chunks, int64_t axis) const;
const platform::Place& place() const {
PADDLE_ENFORCE_NOT_NULL(
holder_,
platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::place() is called."));
return holder_->place();
Tensor& Resize(const DDim& dims) {
meta_.dims = dims;
return *this;
}
proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL(
holder_,
platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::type() is called."));
return type_;
}
/**
* [Add method get the saved type of tensor]
*
* After the introduction of complex number calculations, Ops that support
* complex number calculations generally support type promotion, such as
* x(float32) + y(complex64) = out(complex64), then the type of the grad
* tensor should be dout(complex64), dx(float32), dy (complex64), but the
* type of dx to be recognized to be float32 by the grad Op relay on the type
* of forward tensor x. But many of our ops have registered InplaceInferer,
* covering the tensor memory of x with out, so as to save storage.
*
* In this case, the dim and type information recorded by x still exist,
* but because x becomes an uninitialized tensor, The type of x record cannot
* be obtained with x.type(), but the type is still valid here, so we
* add saved_type(), This method SHOULD NOT be called by general scenarios.
*/
proto::VarType::Type saved_type() const { return type_; }
// memory size returns the holding memory size in byte.
size_t memory_size() const;
void check_memory_size() const;
DataLayout layout() const { return layout_; }
void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() {
holder_ = nullptr;
offset_ = 0;
}
void ShareBufferWith(const Tensor& tensor) {
holder_ = tensor.holder_;
offset_ = tensor.offset_;
// NOTE(chenfeiyu): when sharing buffer, by definition only holder
// to the memory allocation and offset should be shared. Shape,
// data type, layout, and other metadata associated with a Tensor
// should not be copied.
}
void ShareDataTypeWith(const Tensor& tensor) { type_ = tensor.type_; }
bool IsSharedBufferWith(const Tensor& src) const {
return holder_ && holder_ == src.Holder();
}
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
size_t offset() const { return offset_; }
void set_offset(size_t offset) { offset_ = offset; }
std::shared_ptr<memory::Allocation> MoveMemoryHolder() {
return std::move(holder_);
}
void ResetHolder(std::shared_ptr<memory::Allocation> holder);
void ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
const proto::VarType::Type& type);
void set_type(const proto::VarType::Type& type);
TensorInplaceVersion& InplaceVersionCounter() {
return *inplace_version_counter_;
}
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
proto::VarType::Type type_;
/**
* @brief points to elements dimensions.
*
* @note dims_ do not indicate the memory block size.
*/
DDim dims_;
/**
* @brief the layout of memory block, default is NHWC.
*
* @note the memory allocation order, describe how weight/data is stored
* For example, in 4-D Tensor(rank=4), there are three commonly
* used layout. They are
* NCHW, NHWC, CHWN.
* N,C,H,W for respectively the batch size, the number of
* feature maps, the height.
*/
// Fix me: here just change the default layout to kNCHW
// it doesn't fix the real issue, i.e. feeder should set up tensor layout
// according to actual input data
DataLayout layout_ = DataLayout::kNCHW;
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
* @note Some of them may be slices of the others. So the offset_
* is introduced here to indicate the byte offset between
* PlaceHolder::ptr_ and where the tensor data really begins.
*/
size_t offset_;
std::shared_ptr<TensorInplaceVersion> inplace_version_counter_;
};
} // namespace framework
......
......@@ -20,61 +20,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
inline const T* Tensor::data() const {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType();
PADDLE_ENFORCE_EQ(
valid, true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type, it holds %s, but desires to be %s.",
DataTypeToString(type_),
DataTypeToString(DataTypeTrait<T>::DataType())));
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
inline bool Tensor::IsInitialized() const { return holder_ != nullptr; }
template <typename T>
inline T* Tensor::data() {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType();
PADDLE_ENFORCE_EQ(
valid, true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type, it holds %s, but desires to be %s",
DataTypeToString(type_),
DataTypeToString(DataTypeTrait<T>::DataType())));
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
inline const void* Tensor::data() const {
check_memory_size();
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T>
inline T* Tensor::mutable_data(const DDim& dims, const platform::Place& place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
return mutable_data<T>(place, requested_size);
}
template <typename T>
inline T* Tensor::mutable_data(const platform::Place& place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(
mutable_data(place, DataTypeTrait<T>::DataType(), requested_size));
}
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
int rank = src.dims().size();
......
......@@ -45,7 +45,6 @@ TEST(Tensor, DataAssert) {
} catch (platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(ex_msg.find("holder_ should not be null") != std::string::npos);
EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call "
"Tensor::mutable_data firstly.") !=
std::string::npos);
......@@ -189,8 +188,6 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(ex_msg.find("holder_ should not be null") !=
std::string::npos);
EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call "
"Tensor::mutable_data firstly.") !=
std::string::npos);
......
......@@ -45,7 +45,6 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place();
......@@ -442,6 +441,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto src_place = src.place();
auto src_ptr = src.data();
auto dst_ptr = dst->mutable_data(dst_place, src.type());
VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr;
if (src_ptr == dst_ptr && src_place == dst_place) {
VLOG(3) << "Skip copy the same data from " << src_place << " to "
......
......@@ -72,7 +72,7 @@ class Variable {
private:
// This method hides type T, so it doesn't appear as a template parameter of
// Variable.
framework::TensorInplaceVersion* InplaceVersionCounter();
pten::TensorInplaceVersion* InplaceVersionCounter();
public:
void SetInplaceVersionToZero();
......@@ -114,8 +114,8 @@ class Variable {
std::shared_ptr<Placeholder> holder_;
};
inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
framework::TensorInplaceVersion* version_counter_ptr(nullptr);
inline pten::TensorInplaceVersion* Variable::InplaceVersionCounter() {
pten::TensorInplaceVersion* version_counter_ptr(nullptr);
if (IsType<framework::LoDTensor>()) {
version_counter_ptr =
&GetMutable<framework::LoDTensor>()->InplaceVersionCounter();
......
......@@ -60,7 +60,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));
const void *src_ptr = src.data<void>();
const void *src_ptr = src.data();
dst->Resize(src.dims());
auto *dst_ptr = dst->mutable_data(src.place(), src.type());
auto nccl_dtype = platform::ToNCCLDataType(src.type());
......@@ -129,7 +129,7 @@ static void AllReduce(const framework::SelectedRows &src,
auto feature_size = framework::product(dims) / dims[0];
dst_tensor->Resize(dims);
auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype);
const auto *src_tensor_ptr = src_tensor.data<void>();
const auto *src_tensor_ptr = src_tensor.data();
auto sizeof_dtype = framework::SizeOfType(dtype);
int64_t row_offset = 0;
......
......@@ -39,7 +39,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
platform::errors::Unimplemented(
"Dynamic graph mode does not support multi-CPU training yet."));
const void *src_ptr = src.data<void>();
const void *src_ptr = src.data();
dst->Resize(src.dims());
auto *dst_ptr = dst->mutable_data(src.place(), src.type());
auto bkcl_dtype = platform::ToBKCLDataType(src.type());
......@@ -158,7 +158,7 @@ void BKCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
platform::BKCLCommContext::Instance().Get(ring_id, place);
XPUStream stream = comm->stream();
void *src_ptr = src_tensor->data<void>();
void *src_ptr = src_tensor->data();
auto data_type = platform::ToBKCLDataType(src_tensor->type());
PADDLE_ENFORCE_EQ(bkcl_broadcast(comm->comm(), src_ptr, src_ptr,
......
......@@ -42,7 +42,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));
void *src_ptr = const_cast<void *>(src.data<void>());
void *src_ptr = const_cast<void *>(src.data());
dst->Resize(src.dims());
void *dst_ptr = dst->mutable_data(src.place(), src.type());
HcclDataType hccl_dtype = platform::ToHCCLDataType(src.type());
......@@ -168,7 +168,7 @@ void HCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
aclrtStream stream = comm->stream();
void *src_ptr =
reinterpret_cast<void *>(const_cast<void *>(src_tensor->data<void>()));
reinterpret_cast<void *>(const_cast<void *>(src_tensor->data()));
auto hccl_dtype = platform::ToHCCLDataType(src_tensor->type());
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(
src_ptr, src_tensor->numel(), hccl_dtype, 0, comm->comm(),
......
......@@ -143,7 +143,7 @@ void NCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
platform::NCCLCommContext::Instance().Get(ring_id, place);
gpuStream_t stream = comm->stream();
void *src_ptr = src_tensor->data<void>();
void *src_ptr = src_tensor->data();
auto nccl_dtype = platform::ToNCCLDataType(src_tensor->type());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
src_ptr, src_tensor->numel(), nccl_dtype, 0, comm->comm(), stream));
......
......@@ -176,8 +176,8 @@ static bool IsEqualVar(const framework::Variable& var1,
return false;
}
auto* t1_p = t1.data<void>();
auto* t2_p = t2.data<void>();
auto* t1_p = t1.data();
auto* t2_p = t2.data();
return std::memcmp(t1_p, t2_p,
t1.numel() * framework::SizeOfType(t1.type())) == 0;
}
......
......@@ -37,13 +37,13 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt;
if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.data.Reset(t->data(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} else if (t->type() == framework::proto::VarType::INT32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int32_t));
pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -210,7 +210,7 @@ void TensorCopyAsync(paddle::lite_api::Tensor* dst,
const size_t bytes =
static_cast<size_t>(src.numel()) * framework::SizeOfType(src.type());
dst->Resize(framework::vectorize(src.dims()));
const void* src_data = src.data<void>();
const void* src_data = src.data();
void* dst_data{nullptr};
dst_data = GetLiteTensorDataPtr(dst, GetLitePrecisionType(src.type()),
GetLiteTargetType(src.place()));
......@@ -242,7 +242,7 @@ void TensorCopyAsync(framework::LoDTensor* dst,
template <>
void TensorDataShare(paddle::lite_api::Tensor* dst, framework::LoDTensor* src) {
dst->Resize(framework::vectorize(src->dims()));
dst->ShareExternalMemory(src->data<void>(), src->memory_size(),
dst->ShareExternalMemory(src->data(), src->memory_size(),
GetLiteTargetType(src->place()));
dst->SetPrecision(GetLitePrecisionType(src->type()));
paddle::lite::LoD lite_lod;
......
......@@ -176,7 +176,7 @@ class LazyZerosNPU {
NpuOpRunner("ZerosLike", {*zero_tensor}, {*zero_tensor});
runner_zeros.Run(stream);
zero_tensor->check_memory_size();
zero_ptr = zero_tensor->data<void>();
zero_ptr = zero_tensor->data();
}
for (size_t i = 0; i < xs.size(); ++i) {
......
......@@ -260,8 +260,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
size_of_dtype
: len;
ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
<< " address: " << out_tensors[i]->data<void>() << " len: " << len
<< ", ";
<< " address: " << out_tensors[i]->data() << " len: " << len << ", ";
offset += len;
}
PADDLE_ENFORCE_EQ(
......@@ -300,9 +299,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
place, align_size) /
size_of_dtype
: static_cast<size_t>(size);
const void *ptr = lod_tensors[i]->IsInitialized()
? lod_tensors[i]->data<void>()
: nullptr;
const void *ptr =
lod_tensors[i]->IsInitialized() ? lod_tensors[i]->data() : nullptr;
VLOG(4) << size << " " << len;
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
<< ") "
......
......@@ -43,7 +43,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
int dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
auto* sendbuff = in->data<void>();
auto* sendbuff = in->data();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
......
......@@ -33,7 +33,7 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
const void* sendbuff = in->data();
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
......
......@@ -46,7 +46,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
"because this op can only be an In-Place operation."));
void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
send_recv_buffer, in->data<void>(),
send_recv_buffer, in->data(),
platform::errors::PreconditionNotMet("Currently, the broadcast op can "
"only be an In-Place operation."));
......
......@@ -52,7 +52,7 @@ class BKCLBroadcastOpKernel : public framework::OpKernel<T> {
"because this op can only be an In-Place operation."));
void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
send_recv_buffer, in->data<void>(),
send_recv_buffer, in->data(),
platform::errors::PreconditionNotMet("Currently, the broadcast op can "
"only be an In-Place operation."));
......
......@@ -213,7 +213,7 @@ class CReduceOpXPUKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
BKCLDataType dtype = platform::ToBKCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
const void* sendbuff = in->data();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
......@@ -276,7 +276,7 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
const void* sendbuff = in->data();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
......
......@@ -144,8 +144,8 @@ void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) {
static void AppendProposals(framework::Tensor* dst, int64_t offset,
const framework::Tensor& src) {
auto* out_data = dst->data<void>();
auto* to_add_data = src.data<void>();
auto* out_data = dst->data();
auto* to_add_data = src.data();
size_t size_of_t = framework::SizeOfType(src.type());
offset *= size_of_t;
std::memcpy(
......
......@@ -64,8 +64,8 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
auto *var_data = var->mutable_data<U>(ctx.GetPlace());
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data<void>());
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data<void>());
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data());
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data());
framework::proto::VarType::Type x_dtype = x->type();
framework::proto::VarType::Type scale_bias_dtype;
......
......@@ -48,7 +48,7 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_gpu_mat_data->ptr(),
boost::get<platform::CUDAPlace>(context.GetPlace()),
a.data<void>(), a.numel() * sizeof(T), context.stream());
a.data(), a.numel() * sizeof(T), context.stream());
gpu_mat = reinterpret_cast<const T*>(tmp_gpu_mat_data->ptr());
}
......
......@@ -492,9 +492,9 @@ class LambOpKernel : public framework::OpKernel<T> {
auto trust_ratio_div =
ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);
const void* param_ptr = param.template data<void>();
const void* param_ptr = param.data();
const void* master_param_ptr =
master_param ? master_param->template data<void>() : nullptr;
master_param ? master_param->data() : nullptr;
void* param_out_ptr = param_out.template mutable_data<T>(ctx.GetPlace());
void* master_param_out_ptr =
master_param_out
......
......@@ -132,7 +132,7 @@ void BufferedReader::ReadAsync(size_t i) {
memory::Copy(cuda_pinned_place, cuda_pinned_ptrs[i],
BOOST_GET_CONST(platform::CPUPlace, cpu[i].place()),
cpu[i].data<void>(), size);
cpu[i].data(), size);
cuda[i].set_lod(cpu[i].lod());
} else {
......@@ -175,7 +175,7 @@ void BufferedReader::ReadAsync(size_t i) {
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
auto cpu_place = cpu[i].place();
auto cpu_ptr = cpu[i].data<void>();
auto cpu_ptr = cpu[i].data();
auto gpu_ptr = gpu_ptrs[i];
auto size =
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
......@@ -239,7 +239,7 @@ void BufferedReader::ReadAsync(size_t i) {
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
auto cpu_place = cpu[i].place();
auto cpu_ptr = cpu[i].data<void>();
auto cpu_ptr = cpu[i].data();
auto npu_ptr = npu_ptrs[i];
auto size =
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
......
......@@ -587,15 +587,13 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data<void>(),
collapsed_output.data<void>()));
desc.get(), collapsed_input_conj.data(), collapsed_output.data()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.type());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output_conj.data<void>()));
desc.get(), collapsed_input.data(), collapsed_output_conj.data()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
......@@ -605,12 +603,10 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
} else {
if (forward) {
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
desc.get(), collapsed_input.data(), collapsed_output.data()));
} else {
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
desc.get(), collapsed_input.data(), collapsed_output.data()));
}
}
......
......@@ -115,22 +115,19 @@ void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_cufft_plan_raw(config, input_conj.data<void>(), output->data<void>(),
forward);
exec_cufft_plan_raw(config, input_conj.data(), output->data(), forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_cufft_plan_raw(config, input->data<void>(), out_conj.data<void>(),
forward);
exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_cufft_plan_raw(config, input->data<void>(), output->data<void>(),
forward);
exec_cufft_plan_raw(config, input->data(), output->data(), forward);
}
}
......@@ -227,22 +224,19 @@ void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_hipfft_plan_raw(config, input_conj.data<void>(), output->data<void>(),
forward);
exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_hipfft_plan_raw(config, input->data<void>(), out_conj.data<void>(),
forward);
exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_hipfft_plan_raw(config, input->data<void>(), output->data<void>(),
forward);
exec_hipfft_plan_raw(config, input->data(), output->data(), forward);
}
}
......
......@@ -405,7 +405,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (param_names_.count(x)) continue;
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
calib_data.emplace(x, t.data<void>());
calib_data.emplace(x, t.data());
}
temp_calibrator->setBatch(calib_data);
RunNativeImpl(scope, dev_place);
......
......@@ -271,7 +271,7 @@ void Compiler::LowerWeights(const framework::ir::Graph* graph,
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data<void>(), tensor_info};
popart::ConstVoidData const_data{tensor.data(), tensor_info};
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
tensors_.emplace(var_name, result);
......
......@@ -18,7 +18,7 @@ namespace paddle {
namespace platform {
namespace ipu {
void* PaddleIArray::data() { return tensor_->data<void>(); }
void* PaddleIArray::data() { return tensor_->data(); }
popart::DataType PaddleIArray::dataType() const {
return VarType2PopartType(tensor_->type());
......
......@@ -84,7 +84,7 @@ std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(
popart::TensorInfo tensor_info(dtype, shape);
return std::make_unique<popart::NDArrayWrapper<T>>(
reinterpret_cast<T *>(tensor.data<void>()), tensor_info);
reinterpret_cast<T *>(tensor.data()), tensor_info);
}
template <typename T>
......
......@@ -401,7 +401,7 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor,
}
aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
void *ptr = tensor.data<void>();
void *ptr = tensor.data();
VLOG(4) << "NPU ptr: " << ptr << ", size: " << tensor.memory_size();
auto *buffer = aclCreateDataBuffer(ptr, tensor.memory_size());
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -150,8 +150,8 @@ void FillNpuTensorWithConstant(Tensor *tensor, T val) {
*npu_pinned_ptr = val;
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, tensor->place()),
tensor->data<void>(), npu_pinned_place, npu_pinned_ptr,
sizeof(T), GetCurrentNPUStream());
tensor->data(), npu_pinned_place, npu_pinned_ptr, sizeof(T),
GetCurrentNPUStream());
auto npu_pinned_allocator =
static_cast<paddle::memory::allocation::NPUPinnedAllocator *>(
......
......@@ -792,7 +792,7 @@ void BindImperative(py::module *m_ptr) {
SetTensorFromPyArray<platform::CPUPlace>(&t, array,
platform::CPUPlace(), true);
// 3. allocate shared memory
void *data_ptr = t.data<void>();
void *data_ptr = t.data();
size_t data_size = t.numel() * framework::SizeOfType(t.type());
auto shared_writer_holder =
memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
......@@ -827,7 +827,7 @@ void BindImperative(py::module *m_ptr) {
SetTensorFromPyArray<platform::CPUPlace>(&t, array,
platform::CPUPlace(), true);
// 3. allocate shared memory
void *data_ptr = t.data<void>();
void *data_ptr = t.data();
size_t data_size = t.numel() * framework::SizeOfType(t.type());
auto shared_writer_holder =
memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
......@@ -1857,7 +1857,7 @@ void BindImperative(py::module *m_ptr) {
// 1. get LoDTensor
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
// 2. allocate shared memory
void *data_ptr = t->data<void>();
void *data_ptr = t->data();
size_t data_size = t->numel() * framework::SizeOfType(t->type());
auto shared_writer_holder =
memory::allocation::AllocateMemoryMapWriterAllocation(
......
......@@ -729,7 +729,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
numel *= py_dims[i];
}
const void *tensor_buf_ptr = tensor.data<void>();
const void *tensor_buf_ptr = tensor.data();
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type());
......
......@@ -83,8 +83,21 @@ class SharedStorage : public pten::Storage {
size_ = 0;
}
size_t size() const noexcept override { return size_; }
const paddle::platform::Place& place() const override { return place_; }
void set_data_shared(
const std::shared_ptr<paddle::memory::Allocation>& holder) override {
data_ = holder;
if (holder) {
size_ = holder->size();
place_ = holder->place();
}
}
size_t size() const noexcept override {
return data_ ? data_->size() : size_;
}
const paddle::platform::Place& place() const override {
return data_ ? data_->place() : place_;
}
bool OwnsMemory() const noexcept override { return false; }
const std::shared_ptr<paddle::memory::Allocation>& GetAllocation() {
......
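
Together with the DenseTensor copy constructor and operator= changes in the dense_tensor.cc hunk below, the new SharedStorage::set_data_shared() means that copying a tensor re-points its storage at the same underlying Allocation instead of aliasing the Storage object itself. A minimal, hypothetical usage sketch (construction path assumed, not taken from the patch):

#include <cassert>
#include "paddle/fluid/framework/tensor.h"  // framework::Tensor is a pten::DenseTensor

void copy_shares_allocation_sketch() {
  paddle::framework::Tensor a;
  a.Resize(paddle::framework::make_ddim({2, 3}));
  a.mutable_data<float>(paddle::platform::CPUPlace());

  // The copy gets a fresh SharedStorage whose data_shared() is set to a's
  // Allocation, so both tensors see the same buffer and no elements are copied.
  paddle::framework::Tensor b(a);
  assert(b.data() == a.data());
  assert(b.dims() == a.dims());
}
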
......@@ -41,12 +41,32 @@ DenseTensor::DenseTensor(intrusive_ptr<Storage> storage,
DenseTensor::DenseTensor(intrusive_ptr<Storage> storage, DenseTensorMeta&& meta)
: meta_(std::move(meta)), storage_(std::move(storage)) {}
DenseTensor::DenseTensor(const DenseTensor& other)
: meta_(other.meta()), storage_(copy_intrusive(other.storage_)) {}
DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
if (storage_ == nullptr) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
}
if (other.storage_ != nullptr && other.storage_->data_shared()) {
storage_->set_data_shared(other.storage_->data_shared());
}
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
#endif
}
DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
meta_ = other.meta();
storage_ = std::move(copy_intrusive(other.storage_));
if (storage_ == nullptr) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
}
if (other.storage_ != nullptr && other.storage_->data_shared()) {
storage_->set_data_shared(other.storage_->data_shared());
}
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
#endif
return *this;
}
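After this change, copy construction and copy assignment give the destination its own SharedStorage and only attach the source's shared_ptr<Allocation>; the Storage object itself is no longer shared. A hedged sketch of the observable semantics (Holder() is used here only as a proxy for the shared allocation; names outside this diff are assumed):

// Hedged sketch; `src` is assumed to be an initialized pten::DenseTensor.
void CopySharesAllocationOnly(const pten::DenseTensor& src) {
  pten::DenseTensor dst(src);           // meta_ copied, allocation shared
  CHECK(dst.meta() == src.meta());      // same dims/dtype/layout/offset
  CHECK(dst.Holder() == src.Holder());  // same underlying Allocation
  // The two tensors keep distinct Storage objects, which is what the updated
  // dense_tensor shallow_copy test at the bottom of this diff asserts.
}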
......@@ -138,22 +158,22 @@ T* DenseTensor::data() {
return reinterpret_cast<T*>(data());
}
const void* DenseTensor::data() const {
void* DenseTensor::data() {
PADDLE_ENFORCE_NOT_NULL(
storage_,
paddle::platform::errors::PreconditionNotMet(
"The storage must be valid when call the mutable data function."));
return reinterpret_cast<const void*>(
reinterpret_cast<uintptr_t>(storage_->data()) + meta_.offset);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(storage_->data()) +
meta_.offset);
}
void* DenseTensor::data() {
const void* DenseTensor::data() const {
PADDLE_ENFORCE_NOT_NULL(
storage_,
paddle::platform::errors::PreconditionNotMet(
"The storage must be valid when call the mutable data function."));
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(storage_->data()) +
meta_.offset);
return reinterpret_cast<const void*>(
reinterpret_cast<uintptr_t>(storage_->data()) + meta_.offset);
}
void DenseTensor::set_meta(DenseTensorMeta&& meta) {
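This pair of accessors is what the call-site hunks above converge on: the untyped template call data<void>() is replaced by plain data() / data() const overloads that return storage_->data() plus the metadata offset. A hedged call-site sketch (the FLOAT32 assumption only matters for the typed access line):

// Hedged sketch; `t` is assumed to be an initialized FLOAT32 DenseTensor.
void DataAccessSketch(pten::DenseTensor& t) {
  void* raw = t.data();            // was: t.data<void>()
  const pten::DenseTensor& ct = t;
  const void* craw = ct.data();    // const overload
  float* typed = t.data<float>();  // typed access is unchanged
  (void)raw; (void)craw; (void)typed;
}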
......@@ -174,12 +194,11 @@ void DenseTensor::set_meta(DenseTensorMeta&& meta) {
storage_ won't be initialized until the first
call to mutable_data(place)
*/
DenseTensor& DenseTensor::Resize(const DDim& dims) {
void DenseTensor::Resize(const DDim& dims) {
meta_.dims = dims;
if (storage_ != nullptr) {
mutable_data();
}
return *this;
}
void DenseTensor::ResetLoD(const LoD& lod) { meta_.lod = lod; }
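Resize() now returns void rather than DenseTensor&, so call sites that relied on chaining Resize with a follow-up call have to be split into two statements. An assumed call-site sketch (the DDim alias and helper names are assumptions, not taken from this diff):

// Hedged sketch of the call-site implication of the void return type.
void ResizeThenAllocate(pten::DenseTensor& t,
                        const paddle::framework::DDim& dims,
                        const paddle::platform::Place& place) {
  // Chaining such as `t.Resize(dims).mutable_data<float>(place)` no longer
  // compiles; split it into two statements:
  t.Resize(dims);
  float* data = t.mutable_data<float>(place, /*requested_size=*/0);
  (void)data;
}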
......@@ -211,36 +230,21 @@ DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128);
/* From framework::Tensor */
/* --------------------------- */
DenseTensor::DenseTensor() {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
inplace_version_counter_ = std::make_shared<TensorInplaceVersion>(0);
meta_ = DenseTensorMeta();
meta_.dtype = paddle::experimental::DataType::FLOAT32;
meta_.offset = 0;
}
DenseTensor::DenseTensor(const paddle::framework::proto::VarType::Type& dtype) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
inplace_version_counter_ = std::make_shared<TensorInplaceVersion>(0);
meta_ = DenseTensorMeta();
meta_.dtype = TransToPtenDataType(dtype);
meta_.offset = 0;
}
DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
src.check_memory_size();
*this = src;
return *this;
}
DenseTensor& DenseTensor::ShareInplaceVersionCounterWith(
const DenseTensor& src) {
PADDLE_ENFORCE_NOT_NULL(
inplace_version_counter_,
paddle::platform::errors::PreconditionNotMet(
"Tensor does not hold inplace_version_counter_."));
inplace_version_counter_ = src.inplace_version_counter_;
return *this;
}
size_t DenseTensor::memory_size() const {
if (storage_ == nullptr || storage_->data_shared() == nullptr) {
return 0UL;
......@@ -304,16 +308,15 @@ void DenseTensor::ResetHolder(
paddle::platform::errors::Fatal(
"Only the offset is supported to zero when the holder is reset."));
if (storage_ == nullptr) {
PADDLE_THROW(
paddle::platform::errors::Fatal("storage_ has to be initialized before "
"calling ResetHolder() interface."));
}
PADDLE_ENFORCE_NOT_NULL(
storage_,
paddle::platform::errors::PreconditionNotMet(
"The storage must be valid when call the mutable data function."));
if (storage_->data_shared()) {
PADDLE_ENFORCE_LE(
numel() * SizeOf(dtype()) + meta_.offset,
storage_->data_shared()->size(),
holder->size(),
paddle::platform::errors::InvalidArgument(
"The size of Holder is not enough to store the Tensor."));
}
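Note the semantic change in the size check: ResetHolder() now validates the incoming holder (numel() * SizeOf(dtype()) + offset must fit in holder->size()) instead of re-checking the allocation that is about to be replaced. A hedged usage sketch; the parameter type is assumed to follow Holder()'s return type:

// Hedged sketch of the strengthened precondition on the new holder.
void SwapHolder(pten::DenseTensor& t,
                const std::shared_ptr<paddle::memory::Allocation>& holder) {
  // Fails with InvalidArgument if the new holder cannot back the tensor,
  // i.e. numel() * SizeOf(dtype()) + offset exceeds holder->size().
  t.ResetHolder(holder);
}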
......@@ -333,95 +336,6 @@ void DenseTensor::set_type(
meta_.dtype = TransToPtenDataType(type);
}
DenseTensor DenseTensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx,
0,
paddle::platform::errors::OutOfRange(
"The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(end_idx,
meta_.dims[0],
paddle::platform::errors::OutOfRange(
"The end row index is out of bound."));
PADDLE_ENFORCE_LT(
begin_idx,
end_idx,
paddle::platform::errors::InvalidArgument(
"The start row index must be less than the end row index."
"But received the start index = %d, the end index = %d.",
begin_idx,
end_idx));
if (meta_.dims[0] == 1) {
return *this;
} else {
size_t base = numel() / meta_.dims[0];
DenseTensor dst;
dst.storage_ = std::move(copy_intrusive(storage_));
dst.meta_.layout = meta_.layout;
dst.meta_.dtype = meta_.dtype;
DDim dst_dims = meta_.dims;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.meta_.offset = meta_.offset + begin_idx * base * SizeOf(dtype());
return dst;
}
}
std::vector<DenseTensor> DenseTensor::Split(int64_t split_size,
int64_t axis) const {
check_memory_size();
PADDLE_ENFORCE_GE(meta_.dims.size(),
0,
paddle::platform::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
split_size,
0,
paddle::platform::errors::OutOfRange(
"split expects split_size be non-negative, but got split_size is %d",
split_size));
int64_t numel_size = meta_.dims[axis];
int64_t num_splits = 1;
if (split_size != 0) {
num_splits =
std::max<int64_t>((numel_size + split_size - 1) / split_size, 1);
}
std::vector<DenseTensor> splits(num_splits);
int64_t last_split_size = split_size - (split_size * num_splits - numel_size);
for (int64_t i = 0; i < num_splits; ++i) {
int64_t length = i < num_splits - 1 ? split_size : last_split_size;
splits[i] = Slice(i * split_size, i * split_size + length);
}
return splits;
}
std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
int64_t axis) const {
check_memory_size();
PADDLE_ENFORCE_GE(meta_.dims.size(),
0,
paddle::platform::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
chunks,
0,
paddle::platform::errors::OutOfRange(
"chunks expects to be greater than 0, but got chunks is %d", chunks));
int64_t numel_size = meta_.dims[axis];
int64_t split_size = (numel_size + chunks - 1) / chunks;
return Split(split_size, axis);
}
void* DenseTensor::mutable_data(const paddle::platform::Place& place,
paddle::framework::proto::VarType::Type type,
size_t requested_size) {
......@@ -447,23 +361,16 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
if (storage_->data_shared() == nullptr ||
!(storage_->data_shared()->place() == place) ||
storage_->data_shared()->size() < size + meta_.offset) {
// Reset holder first before re-allocate to save memory
storage_->Clear();
storage_->set_data_shared(paddle::memory::AllocShared(place, size));
meta_.offset = 0;
}
return reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(storage_->data_shared()->ptr()) +
meta_.offset);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(storage_->data()) +
meta_.offset);
}
void* DenseTensor::mutable_data(const paddle::platform::Place& place,
size_t requested_size) {
if (storage_ == nullptr) {
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"The tensor is not initialized."));
}
return mutable_data(place, type(), requested_size);
}
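For reference, the untyped overload above is the one the VarType-based call sites go through; it reallocates only when there is no allocation yet, the place changes, or the existing allocation is too small, and then returns storage_->data() plus the offset. A hedged sketch (the FP32 enum value is the only assumption beyond this diff):

// Hedged sketch of the untyped, VarType-driven overload.
void AllocateFp32Buffer(pten::DenseTensor& t,
                        const paddle::platform::Place& place) {
  void* buf = t.mutable_data(place, paddle::framework::proto::VarType::FP32,
                             /*requested_size=*/0);
  (void)buf;  // valid until a later call reallocates the storage
}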
......@@ -481,8 +388,12 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
"] now"));
size_t size = numel() * SizeOf(dtype());
if (storage_ == nullptr) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(place);
}
/* some versions of boost::variant don't have operator!= */
if (storage_ == nullptr || storage_->data_shared() == nullptr ||
if (storage_->data_shared() == nullptr ||
!(storage_->data_shared()->place() == place) ||
storage_->data_shared()->size() < size + meta_.offset ||
!(paddle::platform::is_gpu_place(place) &&
......@@ -491,9 +402,8 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
storage_->set_data_shared(paddle::memory::AllocShared(place, size, stream));
meta_.offset = 0;
}
return reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(storage_->data_shared()->ptr()) +
meta_.offset);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(storage_->data()) +
meta_.offset);
}
/* @jim19930609: The following "mutable_data" only supports specific dtypes
......@@ -506,7 +416,7 @@ inline T* DenseTensor::mutable_data(const DDim& dims,
const paddle::platform::Place& place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
meta_.dims = dims;
return mutable_data<T>(place, requested_size);
}
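The typed convenience overload now writes the dims straight into meta_ (instead of going through Resize) before allocating. A hedged sketch; make_ddim is an assumed helper from paddle::framework:

// Hedged sketch of the typed overload that sets dims and allocates in one call.
void AllocateTypedBuffer(pten::DenseTensor& t,
                         const paddle::platform::Place& place) {
  float* buf = t.mutable_data<float>(paddle::framework::make_ddim({2, 3}),
                                     place, /*requested_size=*/0);
  (void)buf;  // meta_.dims is now {2, 3}; the dtype follows from float
}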
......@@ -518,6 +428,13 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
place, paddle::framework::DataTypeTrait<T>::DataType(), requested_size));
}
void DenseTensor::ShareBufferWith(const DenseTensor& tensor) {
if (storage_ != nullptr && tensor.storage_ != nullptr) {
storage_->set_data_shared(tensor.storage_->data_shared());
}
meta_.offset = tensor.meta().offset;
}
#define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DenseTensor::mutable_data( \
const DDim& dims, \
......
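ShareBufferWith() moves out of the header and, like the copy operations above, shares only the Allocation via set_data_shared() instead of copying the Storage intrusive pointer; IsSharedBufferWith() in the header diff below is adjusted to compare data_shared() accordingly. A hedged sketch of the resulting contract:

// Hedged sketch of the buffer-sharing contract after this change.
void ShareBuffers(pten::DenseTensor& dst, const pten::DenseTensor& src) {
  dst.ShareBufferWith(src);
  CHECK(dst.IsSharedBufferWith(src));   // compares the shared allocations
  CHECK(dst.Holder() == src.Holder());  // same Allocation, distinct Storage
}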
......@@ -157,7 +157,7 @@ class DenseTensor : public TensorBase,
/// \param dims The new dims of the dense tensor.
/// \param lod The new lod of the dense tensor.
// void Resize(const DDim& dims);
DenseTensor& Resize(const DDim& dims);
void Resize(const DDim& dims);
/// \brief Change the lod information in the metadata.
/// \param lod The new lod of the dense tensor.
......@@ -204,7 +204,7 @@ class DenseTensor : public TensorBase,
private:
friend class CompatibleDenseTensorUtils;
private:
protected:
DenseTensorMeta meta_;
intrusive_ptr<Storage> storage_;
......@@ -228,7 +228,7 @@ class DenseTensor : public TensorBase,
explicit DenseTensor(const paddle::framework::proto::VarType::Type& dtype);
inline bool IsInitialized() const {
return storage_ != nullptr && storage_->data() != nullptr;
return storage_ != nullptr && storage_->data_shared() != nullptr;
}
template <typename T>
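IsInitialized() now asks for an attached allocation rather than just a readable data pointer, so a tensor that owns a default SharedStorage but has not been allocated yet is reported as uninitialized. A hedged sketch; make_ddim and the CPU place are assumptions for illustration:

// Hedged sketch; relies on the default constructor defined in dense_tensor.cc.
void InitializationCheck(const paddle::platform::Place& place) {
  pten::DenseTensor t;        // owns a SharedStorage, holds no allocation yet
  CHECK(!t.IsInitialized());  // data_shared() is still nullptr
  t.mutable_data<float>(paddle::framework::make_ddim({4}), place,
                        /*requested_size=*/0);
  CHECK(t.IsInitialized());   // an Allocation is now attached
}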
......@@ -256,18 +256,6 @@ class DenseTensor : public TensorBase,
paddle::framework::proto::VarType::Type type,
const paddle::platform::Stream& stream);
/*! The internal of two tensors share the same memory block. */
DenseTensor& ShareDataWith(const DenseTensor& src);
/*! The internal of two tensors share the same inplace version counter. */
DenseTensor& ShareInplaceVersionCounterWith(const DenseTensor& src);
DenseTensor Slice(int64_t begin_idx, int64_t end_idx) const;
std::vector<DenseTensor> Split(int64_t split_size, int64_t axis) const;
std::vector<DenseTensor> Chunk(int64_t chunks, int64_t axis) const;
/* @jim19930609: Remove dependency on protobuf after Tensor Unification.
*/
paddle::framework::proto::VarType::Type type() const;
......@@ -288,17 +276,17 @@ class DenseTensor : public TensorBase,
meta_.offset = 0;
}
void ShareBufferWith(const DenseTensor& tensor) {
storage_ = std::move(copy_intrusive(tensor.storage_));
meta_.offset = tensor.meta().offset;
}
void ShareBufferWith(const DenseTensor& tensor);
void ShareDataTypeWith(const DenseTensor& tensor) {
meta_.dtype = tensor.meta().dtype;
}
bool IsSharedBufferWith(const DenseTensor& src) const {
return IsSharedWith(src);
if (storage_ == nullptr || src.storage_ == nullptr) return false;
if (storage_->data_shared() == src.storage_->data_shared()) return true;
return false;
}
const std::shared_ptr<paddle::memory::Allocation> Holder() const {
......@@ -325,7 +313,7 @@ class DenseTensor : public TensorBase,
return *inplace_version_counter_;
}
private:
protected:
std::shared_ptr<TensorInplaceVersion> inplace_version_counter_;
/* @jim19930609: This is a hack
......@@ -365,6 +353,7 @@ class DenseTensor : public TensorBase,
Will be adjusted/removed/moved in the near future
*/
public:
explicit DenseTensor(const LoD& lod);
void set_lod(const LoD& lod);
......
......@@ -60,7 +60,7 @@ class Storage : public intrusive_ref_counter<Storage> {
return data_;
}
void set_data_shared(
virtual void set_data_shared(
const std::shared_ptr<paddle::memory::Allocation>& holder) {
data_ = holder;
}
......
......@@ -40,8 +40,8 @@ class intrusive_ptr {
rhs.reset();
}
intrusive_ptr<T>& operator=(intrusive_ptr<T>&& rhs) {
px = std::move(rhs.px);
intrusive_ptr& operator=(intrusive_ptr&& rhs) {
swap(rhs);
return *this;
}
......
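The move-assignment change is more than style: with a raw T* px member, `px = std::move(rhs.px)` just copies the pointer, so the object previously held by the target is never released and both pointers then refer to rhs's object, inviting a double release. Swapping instead hands the old object to rhs, whose destructor releases it exactly once. A simplified, self-contained sketch of the shape of the fix (not the real class, which also maintains the reference count in its destructor):

#include <utility>

// Simplified stand-in for the real pten::intrusive_ptr, for illustration only.
template <typename T>
struct intrusive_ptr_sketch {
  T* px = nullptr;

  intrusive_ptr_sketch& operator=(intrusive_ptr_sketch&& rhs) {
    // Buggy form: `px = std::move(rhs.px);`
    //   * the object previously held by *this is never released (leak), and
    //   * rhs still points at the newly assigned object (double release risk).
    // Fixed form: swap, so the old object is released when rhs is destroyed.
    std::swap(px, rhs.px);
    return *this;
  }
};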
......@@ -133,7 +133,10 @@ TEST(dense_tensor, shallow_copy) {
DenseTensor tensor_1(tensor_0);
CHECK(tensor_0.meta() == tensor_1.meta());
CHECK(tensor_0.release() == tensor_1.release());
  // Copy constructor now shares the underlying shared_ptr<Allocation> instead
  // of the Storage object
CHECK(tensor_0.release() != tensor_1.release());
}
} // namespace tests
......