Unverified commit c5232b4b, authored by Haohongxiang, committed by GitHub

[Dygraph] Support diff batch for sparse of EagerReducer (#42646)

* support diff batch for sparse of eagerreducer

* fix
Parent 7b828f71
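Summary of the change: `EagerReducer::AllReduceSparse` first gathers every rank's sparse-row count into `cpu_rows_num_ptr`. Until this commit the reducer only handled the case where all counts match and otherwise threw `Unimplemented`; the new else branch (third hunk below) lets ranks contribute different numbers of rows, which happens when batch sizes differ across cards. A minimal, framework-free sketch of the dispatch check, using a hypothetical `rows_per_rank` vector in place of the raw `cpu_rows_num_ptr` buffer:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the std::all_of test in the first hunk: take the fast path only
// when every rank reports the same number of sparse rows.
bool all_row_counts_equal(const std::vector<int64_t>& rows_per_rank) {
  return std::all_of(rows_per_rank.begin(), rows_per_rank.end(),
                     [&](int64_t rows) { return rows == rows_per_rank[0]; });
}

int main() {
  std::cout << all_row_counts_equal({8, 8, 8, 8}) << "\n";  // 1: fast path
  std::cout << all_row_counts_equal({8, 8, 8, 5}) << "\n";  // 0: new fallback
}
```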
@@ -901,6 +901,9 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
dev_ctx->Wait();
Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
std::vector<int64_t> dst_shape = src_value_tensor.shape();
if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_,
[&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) {
// During sparse communication, the number of rows on each card is the same.
@@ -940,8 +943,6 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
&dst_rows_vector);
dev_ctx->Wait();
Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
std::vector<int64_t> dst_shape = src_value_tensor.shape();
dst_shape[dst_shape.size() - 2] = rows_num;
auto dst_dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
paddle::experimental::full(IntArray(dst_shape), 0,
@@ -971,8 +972,58 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
*(src->mutable_value()) =
*(std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl()));
} else {
PADDLE_THROW(
platform::errors::Unimplemented("This case is not supported."));
std::vector<Tensor> rows_tensors;
std::vector<Tensor> values_tensors;
for (int i = 0; i < size_; ++i) {
std::vector<int64_t> value_tensor_shape = {
cpu_rows_num_ptr[i], dst_shape[dst_shape.size() - 1]};
Tensor rows_tensor = paddle::experimental::full(
IntArray({static_cast<int64_t>(cpu_rows_num_ptr[i])}), 0,
DataType::INT64, inner_place_);
Tensor values_tensor = paddle::experimental::full(
IntArray(value_tensor_shape), 0, src->value().dtype(), inner_place_);
std::vector<phi::DenseTensor> rows_dense_vector;
std::vector<phi::DenseTensor> values_dense_vector;
if (i == rank_) {
auto *rows_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(rows_tensor.impl())
.get();
framework::TensorFromVector<int64_t>(src_rows, *dev_ctx,
rows_dense_tensor);
values_tensor.set_impl(
std::make_shared<phi::DenseTensor>(src->value()));
}
rows_dense_vector.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(rows_tensor.impl()));
values_dense_vector.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(values_tensor.impl()));
auto b_opts = BroadcastOptions();
b_opts.source_rank = i;
process_group_->Broadcast(rows_dense_vector, rows_dense_vector, b_opts);
process_group_
->Broadcast(values_dense_vector, values_dense_vector, b_opts)
->Wait();
rows_tensors.push_back(rows_tensor);
values_tensors.push_back(values_tensor);
}
Tensor dst_rows_tensor =
paddle::experimental::concat(rows_tensors, phi::Scalar(0));
framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
auto *dst_rows_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl())
.get();
framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx,
&dst_rows_vector);
src->set_rows(dst_rows_vector);
Tensor dst_values_tensor =
paddle::experimental::concat(values_tensors, phi::Scalar(0));
*(src->mutable_value()) = *(
std::dynamic_pointer_cast<phi::DenseTensor>(dst_values_tensor.impl()));
}
}
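The new else branch implements a variable-length all-gather by composition: every rank allocates zero-filled rows/values buffers sized from `cpu_rows_num_ptr[i]`, fills in its own slot when `i == rank_`, has rank `i` broadcast both buffers, and finally concatenates everything in rank order with `paddle::experimental::concat`. Below is a purely illustrative, single-process model of that broadcast-and-concat pattern; `SparseShard` and `gather_variable_rows` are made-up names, and the per-rank vector stands in for what `process_group_->Broadcast` would deliver:

```cpp
#include <cstdint>
#include <vector>

// Made-up stand-in for one rank's SelectedRows gradient: row indices plus a
// flattened [rows x width] value block.
struct SparseShard {
  std::vector<int64_t> rows;
  std::vector<float> values;
};

// Single-process model of the fallback: after the broadcasts, every rank
// holds all shards, and the result is their rank-ordered concatenation,
// analogous to concat(rows_tensors) / concat(values_tensors) in the diff.
SparseShard gather_variable_rows(const std::vector<SparseShard>& per_rank) {
  SparseShard out;
  for (const SparseShard& shard : per_rank) {
    out.rows.insert(out.rows.end(), shard.rows.begin(), shard.rows.end());
    out.values.insert(out.values.end(), shard.values.begin(),
                      shard.values.end());
  }
  return out;
}
```

Because the loop broadcasts from source ranks 0 through size_ - 1 in order, every card ends up with the same rank-ordered rows and values, so the result is deterministic across the group.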
@@ -1461,6 +1461,7 @@ class TestDistBase(unittest.TestCase):
need_envs={},
log_name=""):
if self._dygraph and (self._gloo_mode or self._nccl2_mode):
need_envs.update({"FLAGS_enable_eager_mode": "1"})
with _test_eager_guard():
self.check_with_place_func(
model_file=model_file,
@@ -1468,6 +1469,7 @@
check_error_log=check_error_log,
need_envs=need_envs,
log_name=log_name)
need_envs.update({"FLAGS_enable_eager_mode": "0"})
self.check_with_place_func(
model_file=model_file,
delta=delta,
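The test-harness hunks above make the dygraph distributed tests run each model twice: once with `FLAGS_enable_eager_mode` set to "1" inside `_test_eager_guard()` to exercise the new `EagerReducer` path, then again with the flag reset to "0" for the legacy path. A tiny C++ sketch of the same toggle-run-restore pattern, for consistency with the rest of this page (`setenv` is POSIX; the launch steps are placeholders):

```cpp
#include <cstdlib>

int main() {
  // First pass: eager mode on, mirroring
  // need_envs.update({"FLAGS_enable_eager_mode": "1"}) in the diff.
  setenv("FLAGS_enable_eager_mode", "1", /*overwrite=*/1);
  // ... spawn the eager-mode distributed test here ...

  // Second pass: flag restored to "0" so the legacy reducer is tested too.
  setenv("FLAGS_enable_eager_mode", "0", /*overwrite=*/1);
  // ... spawn the legacy-mode distributed test here ...
  return 0;
}
```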