未验证 提交 a0678eb1 编写于 作者: W wanghuancoder 提交者: GitHub

Support fetch lodtensor array (#37580)

* suport fetch lodtensor array, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop
上级 74ca89ef
...@@ -258,11 +258,19 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -258,11 +258,19 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
for (auto& var_name_item : *ins_map_temp) { for (auto& var_name_item : *ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i]; auto var = var_name_item.second[i];
if (!(var->IsType<LoDTensor>() || var->IsType<SelectedRows>())) {
continue;
}
auto& var_name = new_ins[var_name_item.first].at(i); auto& var_name = new_ins[var_name_item.first].at(i);
auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); const Tensor* tensor_in;
if (var->IsType<LoDTensor>() || var->IsType<SelectedRows>()) {
tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
} else if (var->IsType<LoDTensorArray>()) {
tensor_in =
static_cast<const Tensor*>(&(var->Get<LoDTensorArray>()[0]));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type is %s, expect LoDTensor or SelectedRows or "
"LoDTensorArray.",
ToTypeName(var->Type())));
}
if (!tensor_in->IsInitialized()) { if (!tensor_in->IsInitialized()) {
continue; continue;
} }
......
...@@ -41,17 +41,17 @@ class MemcpyD2HFunctor { ...@@ -41,17 +41,17 @@ class MemcpyD2HFunctor {
void operator()(const framework::LoDTensor &lod_tensor) const { void operator()(const framework::LoDTensor &lod_tensor) const {
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>(); auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
CopyLoDTensor(lod_tensor, out_tensor);
}
if (dst_place_type_ == 1) { void operator()(const framework::LoDTensorArray &array) const {
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_, auto &out_array = *out_->GetMutable<framework::LoDTensorArray>();
&out_tensor); out_array.clear();
} else if (dst_place_type_ == 0) { out_array.resize(array.size());
framework::TensorCopySync(lod_tensor, platform::CPUPlace(), &out_tensor);
} else { for (size_t i = 0; i < array.size(); i++) {
PADDLE_THROW(platform::errors::Unimplemented( CopyLoDTensor(array[i], out_array[i]);
"memcpy dst_place_type: %d is not supported yet.", dst_place_type_));
} }
out_tensor.set_lod(lod_tensor.lod());
} }
void operator()(const framework::SelectedRows &rows) const { void operator()(const framework::SelectedRows &rows) const {
...@@ -69,6 +69,19 @@ class MemcpyD2HFunctor { ...@@ -69,6 +69,19 @@ class MemcpyD2HFunctor {
} }
private: private:
void CopyLoDTensor(const framework::LoDTensor &src,
framework::LoDTensor &dst) const { // NOLINT
if (dst_place_type_ == 1) {
framework::TensorCopy(src, platform::CUDAPinnedPlace(), dev_ctx_, &dst);
} else if (dst_place_type_ == 0) {
framework::TensorCopySync(src, platform::CPUPlace(), &dst);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"memcpy dst_place_type: %d is not supported yet.", dst_place_type_));
}
dst.set_lod(src.lod());
}
framework::Variable *out_; framework::Variable *out_;
const platform::DeviceContext &dev_ctx_; const platform::DeviceContext &dev_ctx_;
const int dst_place_type_; const int dst_place_type_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册