From c5232b4b537b5eb17f82195221cb57b63a9f5ebd Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Wed, 11 May 2022 11:11:46 +0800 Subject: [PATCH] [Dygraph] Support diff batch for sparse of EagerReducer (#42646) * support diff batch for sparse of eagerreducer * fix --- .../fluid/distributed/collective/reducer.cc | 59 +++++++++++++++++-- .../fluid/tests/unittests/test_dist_base.py | 2 + 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index a7c3e2208a..96009ce722 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -901,6 +901,9 @@ void EagerReducer::AllReduceSparse(EagerGroup *group, dev_ctx->Wait(); + Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value())); + std::vector<int64_t> dst_shape = src_value_tensor.shape(); + if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_, [&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) { // During sparse communication, the number of each card is same. 
@@ -940,8 +943,6 @@ void EagerReducer::AllReduceSparse(EagerGroup *group, &dst_rows_vector); dev_ctx->Wait(); - Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value())); - std::vector<int64_t> dst_shape = src_value_tensor.shape(); dst_shape[dst_shape.size() - 2] = rows_num; auto dst_dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>( paddle::experimental::full(IntArray(dst_shape), 0, @@ -971,8 +972,58 @@ void EagerReducer::AllReduceSparse(EagerGroup *group, *(src->mutable_value()) = *(std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl())); } else { - PADDLE_THROW( - platform::errors::Unimplemented("This case is not supported.")); + std::vector<Tensor> rows_tensors; + std::vector<Tensor> values_tensors; + + for (int i = 0; i < size_; ++i) { + std::vector<int64_t> value_tensor_shape = { + cpu_rows_num_ptr[i], dst_shape[dst_shape.size() - 1]}; + Tensor rows_tensor = paddle::experimental::full( + IntArray({static_cast<int64_t>(cpu_rows_num_ptr[i])}), 0, + DataType::INT64, inner_place_); + Tensor values_tensor = paddle::experimental::full( + IntArray(value_tensor_shape), 0, src->value().dtype(), inner_place_); + std::vector<phi::DenseTensor> rows_dense_vector; + std::vector<phi::DenseTensor> values_dense_vector; + + if (i == rank_) { + auto *rows_dense_tensor = + std::dynamic_pointer_cast<phi::DenseTensor>(rows_tensor.impl()) + .get(); + framework::TensorFromVector<int64_t>(src_rows, *dev_ctx, + rows_dense_tensor); + values_tensor.set_impl( + std::make_shared<phi::DenseTensor>(src->value())); + } + rows_dense_vector.push_back( + *std::dynamic_pointer_cast<phi::DenseTensor>(rows_tensor.impl())); + values_dense_vector.push_back( + *std::dynamic_pointer_cast<phi::DenseTensor>(values_tensor.impl())); + + auto b_opts = BroadcastOptions(); + b_opts.source_rank = i; + process_group_->Broadcast(rows_dense_vector, rows_dense_vector, b_opts); + process_group_ + ->Broadcast(values_dense_vector, values_dense_vector, b_opts) + ->Wait(); + rows_tensors.push_back(rows_tensor); + values_tensors.push_back(values_tensor); + } + + Tensor dst_rows_tensor = + paddle::experimental::concat(rows_tensors, phi::Scalar(0)); + framework::Vector<int64_t> 
dst_rows_vector(rows_num, 0); + auto *dst_rows_dense_tensor = + std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl()) + .get(); + framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx, + &dst_rows_vector); + src->set_rows(dst_rows_vector); + + Tensor dst_values_tensor = + paddle::experimental::concat(values_tensors, phi::Scalar(0)); + *(src->mutable_value()) = *( + std::dynamic_pointer_cast<phi::DenseTensor>(dst_values_tensor.impl())); } } diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 11972059c8..4f21b3220a 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -1461,6 +1461,7 @@ class TestDistBase(unittest.TestCase): need_envs={}, log_name=""): if self._dygraph and (self._gloo_mode or self._nccl2_mode): + need_envs.update({"FLAGS_enable_eager_mode": "1"}) with _test_eager_guard(): self.check_with_place_func( model_file=model_file, @@ -1468,6 +1469,7 @@ class TestDistBase(unittest.TestCase): check_error_log=check_error_log, need_envs=need_envs, log_name=log_name) + need_envs.update({"FLAGS_enable_eager_mode": "0"}) self.check_with_place_func( model_file=model_file, delta=delta, -- GitLab