未验证 提交 acbf9974 编写于 作者: A Aurelius84 提交者: GitHub

Fix lod in fetch_v2 (#37514)

上级 d5c51e62
......@@ -134,10 +134,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpreter::kMemcpyH2D) {
if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_pool_.Get(place_);
} else if (op_type == interpreter::kMemcpyD2H) {
} else if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_pool_.Get(place_);
}
......
......@@ -157,6 +157,7 @@ class FetchV2Kernel {
DeepCopy(src_item, fetch_var_name, dst_item);
} else {
dst_item->ShareDataWith(src_item);
dst_item->set_lod(src_item.lod());
}
} else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
......@@ -172,6 +173,7 @@ class FetchV2Kernel {
DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
} else {
dst_item[i].ShareDataWith(src_item[i]);
dst_item[i].set_lod(src_item[i].lod());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册