From f86c2bc6e156d3648239536419886eec7e59323a Mon Sep 17 00:00:00 2001 From: chengduo <30176695+chengduoZH@users.noreply.github.com> Date: Fri, 11 Oct 2019 01:22:36 -0500 Subject: [PATCH] Add place deps (#20386) test=release/1.6 --- paddle/fluid/framework/details/CMakeLists.txt | 4 ++-- .../fluid/framework/details/fused_all_reduce_op_handle.cc | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 1e87eabc08..de13450d21 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -25,7 +25,7 @@ if(WITH_GPU) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - dynload_cuda variable_visitor) + dynload_cuda variable_visitor place) if(WITH_DGC) nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope @@ -46,7 +46,7 @@ else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - variable_visitor) + variable_visitor place) if(WITH_DISTRIBUTE) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim selected_rows_functor sendrecvop_rpc) diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc index dce4e36e02..ddd6d10e5e 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc @@ -189,8 +189,10 @@ void FusedAllReduceOpHandle::GetGradLoDTensor( auto var = local_scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(var, "%s is not found in local scope.", var_name); auto &lod_tensor = var->Get(); - PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx), - "%s(%d) is not in the right place.", var_name, scope_idx); + + PADDLE_ENFORCE_EQ( + platform::is_same_place(lod_tensor.place(), places_.at(scope_idx)), + true, "%s(%d) is not in the right place.", var_name, scope_idx); grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor)); } } -- GitLab