提交 bfa55c9d 编写于 作者: C chengduo 提交者: Tao Luo

Add place deps for fused_all_reduce_op_handle (#20077)

test=develop
上级 728ec1b4
......@@ -25,7 +25,7 @@ if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
dynload_cuda variable_visitor place)
if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
......@@ -46,7 +46,7 @@ else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
variable_visitor place)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_rpc)
......
......@@ -189,8 +189,10 @@ void FusedAllReduceOpHandle::GetGradLoDTensor(
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var, "%s is not found in local scope.", var_name);
auto &lod_tensor = var->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx),
"%s(%d) is not in the right place.", var_name, scope_idx);
PADDLE_ENFORCE_EQ(
platform::is_same_place(lod_tensor.place(), places_.at(scope_idx)),
true, "%s(%d) is not in the right place.", var_name, scope_idx);
grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册