diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 1e87eabc083b994ccb1010f5640d3aef79ee6025..de13450d216720899e77ecb61b1ce6e4317f8730 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -25,7 +25,7 @@ if(WITH_GPU) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - dynload_cuda variable_visitor) + dynload_cuda variable_visitor place) if(WITH_DGC) nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope @@ -46,7 +46,7 @@ else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - variable_visitor) + variable_visitor place) if(WITH_DISTRIBUTE) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim selected_rows_functor sendrecvop_rpc) diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc index dce4e36e02a4d22724be63b8774c593463dd4567..ddd6d10e5e116d8c82bb3137f72c39d5cc735f95 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc @@ -189,8 +189,10 @@ void FusedAllReduceOpHandle::GetGradLoDTensor( auto var = local_scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(var, "%s is not found in local scope.", var_name); auto &lod_tensor = var->Get(); - PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx), - "%s(%d) is not in the right place.", var_name, scope_idx); + + PADDLE_ENFORCE_EQ( + platform::is_same_place(lod_tensor.place(), places_.at(scope_idx)), + true, "%s(%d) is not in the right place.", var_name, scope_idx); grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor)); } }