未验证 提交 acbf9974 编写于 作者: A Aurelius84 提交者: GitHub

Fix lod in fetch_v2 (#37514)

上级 d5c51e62
...@@ -134,10 +134,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -134,10 +134,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) { const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type(); auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_; auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpreter::kMemcpyH2D) { if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_"; VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_pool_.Get(place_); dev_ctx = d2h_ctx_pool_.Get(place_);
} else if (op_type == interpreter::kMemcpyD2H) { } else if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_"; VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_pool_.Get(place_); dev_ctx = h2d_ctx_pool_.Get(place_);
} }
......
...@@ -157,6 +157,7 @@ class FetchV2Kernel { ...@@ -157,6 +157,7 @@ class FetchV2Kernel {
DeepCopy(src_item, fetch_var_name, dst_item); DeepCopy(src_item, fetch_var_name, dst_item);
} else { } else {
dst_item->ShareDataWith(src_item); dst_item->ShareDataWith(src_item);
dst_item->set_lod(src_item.lod());
} }
} else { } else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
...@@ -172,6 +173,7 @@ class FetchV2Kernel { ...@@ -172,6 +173,7 @@ class FetchV2Kernel {
DeepCopy(src_item[i], fetch_var_name, &dst_item[i]); DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
} else { } else {
dst_item[i].ShareDataWith(src_item[i]); dst_item[i].ShareDataWith(src_item[i]);
dst_item[i].set_lod(src_item[i].lod());
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册