diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index d5ca061944f33939cea59a5275e691b1966194fa..1d9f1bd6e417e30f0799f0bbed1739cedb4e8fbf 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() { int root_id = boost::get(in_tensor.place()).device; std::vector> broadcast_calls; + int type = platform::ToNCCLDataType(in_tensor.type()); + size_t numel = static_cast(in_tensor.numel()); + for (auto out_var_handle : out_var_handles) { Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) ->FindVar(out_var_handle->name_); @@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() { send_recv_buffer = const_cast(in_tensor.data()); out_handle = out_var_handle; } else { - send_recv_buffer = - VariableVisitor::GetMutableTensor(out_var).mutable_data( - out_var_handle->place_); + send_recv_buffer = VariableVisitor::GetMutableTensor(out_var) + .Resize(in_tensor.dims()) + .mutable_data(out_var_handle->place_); } - int type = platform::ToNCCLDataType(in_tensor.type()); - size_t numel = static_cast(in_tensor.numel()); broadcast_calls.emplace_back( [send_recv_buffer, numel, type, root_id, &nccl_ctx] { PADDLE_ENFORCE(platform::dynload::ncclBcast(