未验证 提交 eca78a9f 编写于 作者: X xiongkun 提交者: GitHub

Support various length support for SelectedRows in GLOO::AllGather (#36637)

* In cpu parallel using gloo, add various length support for SelectedRows

* fix bug

* fix bugs

* fix by code review

* remove timeout
上级 db633aff
......@@ -27,6 +27,7 @@ limitations under the License. */
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/allgather.h>
#include <gloo/allgatherv.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/rendezvous/context.h>
......@@ -238,10 +239,25 @@ class GlooWrapper {
return ret;
}
// TODO(xiongkun03): support all gather array of
// NOTE(@xiongkun03): support all gather array of
// numbers with different length
// can use AllgathervOptions, may be work in different
// occasion. Need some survey.
// if the third argument is int, use allgather,
// if it is vector, use AllgathervOptions,
// which works in different length occasion.
template <typename T>
void AllGatherVector(T* input_ptr, T* output_ptr,
std::vector<size_t>& element_nums) { // NOLINT
CHECK_EQ(is_initialized_, true);
#ifdef PADDLE_WITH_GLOO
gloo::AllgathervOptions opts(context_);
opts.setInput(input_ptr, element_nums[rank_]);
opts.setOutput(output_ptr, element_nums);
gloo::allgatherv(opts);
#else
LOG(WARNING) << "AllGather does nothing when WITH_GLOO=OFF";
#endif
}
template <typename T>
void AllGatherVector(T* input_ptr, T* output_ptr,
size_t element_num) { // NOLINT
......
......@@ -53,15 +53,13 @@ void GLOOParallelContext::InitWithRingID(int ring_id) {
platform::errors::OutOfRange("Still not implement InitWithRingID"));
}
#define GLOO_CASE(type, T, gw) \
case type: { \
VLOG(4) << "Use the gloo all reduce to sync. SRC:" << src_tensor; \
std::vector<T> send_vector##T; \
framework::TensorToVector<T>(src_tensor, &send_vector##T); \
auto recv_vector##T = gw->AllReduce<T>(send_vector##T); \
framework::TensorFromVector<T>(recv_vector##T, dst_tensor); \
VLOG(4) << "DST:" << *dst_tensor; \
break; \
#define GLOO_CASE(type, T, gw) \
case type: { \
std::vector<T> send_vector##T; \
framework::TensorToVector<T>(src_tensor, &send_vector##T); \
auto recv_vector##T = gw->AllReduce<T>(send_vector##T); \
framework::TensorFromVector<T>(recv_vector##T, dst_tensor); \
break; \
}
void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
......@@ -118,7 +116,7 @@ void GLOOParallelContext::AllReduce(const framework::Tensor &src_tensor,
const auto *src_tensor_ptr = src_tensor.data<T>(); \
gw->AllGatherVector<T>(const_cast<T *>(src_tensor_ptr), \
reinterpret_cast<T *>(dst_tensor_ptr), \
value_sendcount); \
element_nums); \
break; \
}
......@@ -150,48 +148,31 @@ void GLOOParallelContext::AllReduce(const framework::SelectedRows &src,
auto *dst_rows_ptr = dst_rows->MutableData(place);
const int64_t *src_rows_ptr = src_rows.Data(place);
// VLOG(3) << "Selected Rows of src:" << string::join_strings(dst_rows, ',')
auto *dst_tensor = dst->mutable_value();
auto dims = src_tensor.dims();
dims[0] = rows_num;
auto feature_size = framework::product(dims) / dims[0];
dst_tensor->Resize(dims);
if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + nranks,
[&](size_t row) { return row == cpu_rows_num_ptr[0]; })) {
// During sparse communication, the number of each card is same.
// Because gloo wrapper utility class currently don't support
// broadcast, so we only deal the-same case.
VLOG(3) << "Use the gloo all reduce to sync. SRC:" << src_tensor;
// framework::SerializeToStream(VLOG(4), src);
VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";
auto value_sendcount = cpu_rows_num_ptr[0] * feature_size;
auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype);
gloo_wrapper->AllGatherVector<int64_t>(const_cast<int64_t *>(src_rows_ptr),
static_cast<int64_t *>(dst_rows_ptr),
rows_num_vector[0]);
switch (dtype) {
GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP32, float,
gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP64, double,
gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT32, int, gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT64, int64_t,
gloo_wrapper);
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid datatype for allreduce"));
}
std::vector<size_t> element_nums = rows_num_vector;
std::for_each(element_nums.begin(), element_nums.end(),
[feature_size](size_t &x) { x = x * feature_size; });
auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype);
gloo_wrapper->AllGatherVector<int64_t>(const_cast<int64_t *>(src_rows_ptr),
static_cast<int64_t *>(dst_rows_ptr),
rows_num_vector);
switch (dtype) {
GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP32, float, gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP64, double, gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT32, int, gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT64, int64_t,
gloo_wrapper);
default: {
PADDLE_THROW(
platform::errors::InvalidArgument("Invalid datatype for allreduce"));
}
VLOG(3) << "Selected Row DST:" << *dst_tensor;
VLOG(3) << "Selected Rows of DST:"
<< string::join_strings(std::vector<int64_t>(*dst_rows), ',');
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The number of each card is not the same, gloo only support the-same"
"batch division"));
}
}
......
......@@ -214,6 +214,7 @@ if (NOT WITH_GLOO)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_unused_variables_gloo)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height_gloo)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_gloo)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_diff_length_gloo)
endif()
if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
......
......@@ -515,10 +515,28 @@ class TestParallelDyGraphRunnerBase(object):
return batch
elif args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
# NOTE(@xiongkun03) args.diff_batch means batch length is different:
# such as : batch = [2,3,4,5], then the first rank will get [2] and
# the second rank will get [3,4,5].
# this function is for test sparse_embedding_differ_length
if hasattr(args, "diff_batch") and args.diff_batch:
assert len(
batch) > 2, "in differ_batch mode, len(batch) must > 2."
if paddle.distributed.get_rank() == 0:
new_batch.append(batch[0])
elif paddle.distributed.get_rank() == 1:
new_batch.extend([_ for _ in batch[1:]])
else:
raise NotImplementedError(
"Current TestParallelDyGraphRunnerBase don't support world_size > 2"
)
return new_batch
else:
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
else:
return batch
......@@ -699,6 +717,7 @@ def runtime_main(test_class):
parser.add_argument('--use_fleet_api', action='store_true')
parser.add_argument('--use_fleet_api_20', action='store_true')
parser.add_argument('--use_local_sgd', action='store_true')
parser.add_argument('--diff_batch', action='store_true')
parser.add_argument('--ut4grad_allreduce', action='store_true')
parser.add_argument(
'--hallreduce_inter_nranks', type=int, required=False, default=2)
......@@ -798,6 +817,7 @@ class TestDistBase(unittest.TestCase):
self._gloo_mode = False # now, support gloo backend
self._pipeline_mode = False
self._mp_mode = False
self._diff_batch = False
# FIXME(typhoonzero): I added this stupid argument to enable
# testing allreduce layers, which users can call layers.allreduce
# to accumulate tensors at anywhere. Find a better way to do this
......@@ -1100,6 +1120,8 @@ class TestDistBase(unittest.TestCase):
#assert self._use_reader_alloc == False, "gloo not support _use_reduce"
if self._save_model:
tr_cmd += " --save_model"
if self._diff_batch:
tr_cmd += " --diff_batch"
self.__use_cuda = False
self.__use_xpu = False
assert self.__use_cuda == False, "gloo not support use cuda"
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_sparse_embedding import TestSparseEmbedding
from parallel_dygraph_sparse_embedding_fp64 import TestSparseEmbeddingFP64
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphSparseEmdedding_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
self._diff_batch = True
def test_sparse_embedding(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册