From acbf997402b283f8e4b0f87ae69c1c064eedb889 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 24 Nov 2021 20:00:20 +0800 Subject: [PATCH] Fix lod in fetch_v2 (#37514) --- paddle/fluid/framework/new_executor/stream_analyzer.cc | 4 ++-- paddle/fluid/operators/controlflow/fetch_v2_op.cc | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index 3ac30456a2d..fdcd19b0309 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -134,10 +134,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( const OpFuncNode& op_func_node) { auto& op_type = op_func_node.operator_base_->Type(); auto* dev_ctx = op_func_node.dev_ctx_; - if (op_type == interpreter::kMemcpyH2D) { + if (op_type == interpreter::kMemcpyD2H) { VLOG(3) << "Get dev_ctx from d2h_context_pool_"; dev_ctx = d2h_ctx_pool_.Get(place_); - } else if (op_type == interpreter::kMemcpyD2H) { + } else if (op_type == interpreter::kMemcpyH2D) { VLOG(3) << "Get dev_ctx from h2d_context_pool_"; dev_ctx = h2d_ctx_pool_.Get(place_); } diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index 93035dddefe..29132f2930a 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -157,6 +157,7 @@ class FetchV2Kernel { DeepCopy(src_item, fetch_var_name, dst_item); } else { dst_item->ShareDataWith(src_item); + dst_item->set_lod(src_item.lod()); } } else { auto &src_item = fetch_var->Get(); @@ -172,6 +173,7 @@ class FetchV2Kernel { DeepCopy(src_item[i], fetch_var_name, &dst_item[i]); } else { dst_item[i].ShareDataWith(src_item[i]); + dst_item[i].set_lod(src_item[i].lod()); } } } -- GitLab