From a0678eb1499f1dce8c19d6a3f4898989f89ddbfd Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Mon, 29 Nov 2021 13:56:02 +0800 Subject: [PATCH] Support fetch lodtensor array (#37580) * suport fetch lodtensor array, test=develop * refine, test=develop * refine, test=develop * refine, test=develop --- .../framework/new_executor/data_transfer.cc | 16 +++++++--- paddle/fluid/operators/memcpy_d2h_op.h | 31 +++++++++++++------ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index ca3647e7d3..bb007cd6ae 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -258,11 +258,19 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, for (auto& var_name_item : *ins_map_temp) { for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto var = var_name_item.second[i]; - if (!(var->IsType() || var->IsType())) { - continue; - } auto& var_name = new_ins[var_name_item.first].at(i); - auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); + const Tensor* tensor_in; + if (var->IsType() || var->IsType()) { + tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); + } else if (var->IsType()) { + tensor_in = + static_cast(&(var->Get()[0])); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Variable type is %s, expect LoDTensor or SelectedRows or " + "LoDTensorArray.", + ToTypeName(var->Type()))); + } if (!tensor_in->IsInitialized()) { continue; } diff --git a/paddle/fluid/operators/memcpy_d2h_op.h b/paddle/fluid/operators/memcpy_d2h_op.h index 6f9890d332..eefefea77b 100644 --- a/paddle/fluid/operators/memcpy_d2h_op.h +++ b/paddle/fluid/operators/memcpy_d2h_op.h @@ -41,17 +41,17 @@ class MemcpyD2HFunctor { void operator()(const framework::LoDTensor &lod_tensor) const { auto &out_tensor = *out_->GetMutable(); + CopyLoDTensor(lod_tensor, out_tensor); + } - if (dst_place_type_ == 1) { - framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_, - &out_tensor); - } else if (dst_place_type_ == 0) { - framework::TensorCopySync(lod_tensor, platform::CPUPlace(), &out_tensor); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "memcpy dst_place_type: %d is not supported yet.", dst_place_type_)); + void operator()(const framework::LoDTensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.clear(); + out_array.resize(array.size()); + + for (size_t i = 0; i < array.size(); i++) { + CopyLoDTensor(array[i], out_array[i]); } - out_tensor.set_lod(lod_tensor.lod()); } void operator()(const framework::SelectedRows &rows) const { @@ -69,6 +69,19 @@ class MemcpyD2HFunctor { } private: + void CopyLoDTensor(const framework::LoDTensor &src, + framework::LoDTensor &dst) const { // NOLINT + if (dst_place_type_ == 1) { + framework::TensorCopy(src, platform::CUDAPinnedPlace(), dev_ctx_, &dst); + } else if (dst_place_type_ == 0) { + framework::TensorCopySync(src, platform::CPUPlace(), &dst); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "memcpy dst_place_type: %d is not supported yet.", dst_place_type_)); + } + dst.set_lod(src.lod()); + } + framework::Variable *out_; const platform::DeviceContext &dev_ctx_; const int dst_place_type_; -- GitLab