From bfa55c9ddb150c939f33a4de840da9c193517f49 Mon Sep 17 00:00:00 2001 From: chengduo <30176695+chengduoZH@users.noreply.github.com> Date: Mon, 30 Sep 2019 16:23:01 +0800 Subject: [PATCH] Add place deps for fused_all_reduce_op_handle (#20077) test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 4 ++-- .../fluid/framework/details/fused_all_reduce_op_handle.cc | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 1e87eabc083..de13450d216 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -25,7 +25,7 @@ if(WITH_GPU) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - dynload_cuda variable_visitor) + dynload_cuda variable_visitor place) if(WITH_DGC) nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope @@ -46,7 +46,7 @@ else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - variable_visitor) + variable_visitor place) if(WITH_DISTRIBUTE) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim selected_rows_functor sendrecvop_rpc) diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc index dce4e36e02a..ddd6d10e5e1 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc @@ -189,8 +189,10 @@ void FusedAllReduceOpHandle::GetGradLoDTensor( auto var = local_scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(var, "%s is not found in local scope.", var_name); auto &lod_tensor = var->Get(); - PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx), - "%s(%d) is not in the right place.", var_name, scope_idx); + + PADDLE_ENFORCE_EQ( + platform::is_same_place(lod_tensor.place(), places_.at(scope_idx)), + true, "%s(%d) is not in the right place.", var_name, scope_idx); grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor)); } } -- GitLab