Unverified · Commit 806f8d2b · Authored by XiaociZhang, committed by GitHub

[Kunlun] Modify some legacy code on distributed training (#55515)

* [Kunlun] Modify some legacy code on distributed training

There were limitations on XPUs before; for example, concat/split was not
supported, and c_broadcast only supported fp32. These limitations have been
lifted recently (a user-level sketch follows below).

Multi-device profiling on XPU is also enabled by this PR. Without it,
devices that enable profiling issue a hanging broadcast, eventually leading
to a kernel timeout error (see the profiler sketch after the diff).

* fix typo
Parent 284e0d12
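The commit message above notes that c_broadcast on XPU was previously limited to fp32. As a user-level illustration of the lifted restriction, the minimal sketch below broadcasts a non-fp32 tensor with the public paddle.distributed API; the tensor shape, dtype, and launch setup are illustrative and not part of this PR (run one process per card, e.g. via python -m paddle.distributed.launch).

```python
# Minimal sketch (not code from this PR): broadcast a non-fp32 tensor across
# cards. The BKCL path now passes the tensor's native dtype to the broadcast
# instead of assuming fp32.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()                    # one process per card
rank = dist.get_rank()

t = paddle.full([16], rank, dtype='int64')  # toy int64 payload
dist.broadcast(t, src=0)                    # all ranks end up with rank 0's values
print(rank, t.numpy()[:4])
```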
@@ -872,14 +872,8 @@ void ParallelExecutor::BCastParamsToDevices(
std::vector<void *> buffers;
buffers.reserve(member_->places_.size());
size_t numel = main_tensor.numel();
// TODO(liuyuhui): BKCL only support parameters using float type,
// other parameters need to be strongly converted to float before
// broadcasting,
// but broadcast is equivalent to no type of operation, does not affect
// correctness.
BKCLDataType data_type = BKCL_FLOAT;
// BKCLDataType data_type =
// platform::ToBKCLDataType(framework::TransToProtoVarType(main_tensor.dtype()));
auto dtype = framework::TransToProtoVarType(main_tensor.dtype());
BKCLDataType data_type = platform::ToBKCLDataType(dtype);
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i];
void *buffer;
@@ -904,33 +898,21 @@ void ParallelExecutor::BCastParamsToDevices(
member_->places_.size()));
{
auto *bkcl_ctxs = member_->bkcl_ctxs_->DefaultFlatCtx();
PADDLE_ENFORCE_EQ(
bkcl_group_start(),
BKCL_SUCCESS,
platform::errors::Unavailable("bkcl_group_start failed"));
platform::BKCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &bkcl_ctx = bkcl_ctxs->at(member_->places_[i]);
auto broadcast_numel = numel;
if (framework::TransToProtoVarType(main_tensor.dtype()) ==
framework::proto::VarType::INT64) {
broadcast_numel *= 2;
}
PADDLE_ENFORCE_EQ(
bkcl_broadcast(bkcl_ctx.comm(),
buffers[i],
buffers[i],
broadcast_numel,
numel,
data_type,
0,
NULL),
BKCL_SUCCESS,
platform::errors::Unavailable("bkcl_broadcast failed"));
}
PADDLE_ENFORCE_EQ(
bkcl_group_end(),
BKCL_SUCCESS,
platform::errors::Unavailable("bkcl_group_end failed"));
bkcl_ctxs->WaitAll();
}
#else
PADDLE_THROW(
......
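The TODO removed in the hunk above documents the old workaround: every parameter was broadcast as BKCL_FLOAT, and int64 tensors were counted as twice as many elements so that the same bytes went over the wire. The short NumPy-only sketch below (no Paddle or BKCL calls) shows why that byte-level reinterpretation preserved values, and why passing the native dtype through ToBKCLDataType is the simpler equivalent.

```python
# Sketch: an int64 buffer occupies exactly twice as many 4-byte words as it
# has elements, so broadcasting it as 2 * numel float32 "elements" moves the
# same bytes; broadcast itself is type-agnostic, so values survive intact.
import numpy as np

x = np.arange(8, dtype=np.int64)
as_f32_words = x.view(np.float32)                 # reinterpret the same bytes
assert as_f32_words.size == 2 * x.size            # numel doubles, bytes identical
assert (as_f32_words.view(np.int64) == x).all()   # round-trips exactly
```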
@@ -819,8 +819,6 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
}
}
// TODO(liuyuhui): If BKCL support non-blocking communication, it should be
// fixed as same as multi gpus card training.
void Reducer::MarkGroupReady(size_t group_index) {
PADDLE_ENFORCE_GE(
group_index,
......
@@ -397,10 +397,7 @@ class DataParallel(layers.Layer):
), "ProcessGroup must be an instance of Group in DataParallel."
# sync buffer and params
# TODO(liuyuhui) Currently not support xpu. xpu is
# still broadcasting parameters when calling layer
if not paddle.is_compiled_with_xpu():
sync_params_buffers(self._layers)
sync_params_buffers(self._layers)
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control
......
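With the XPU special case removed, DataParallel synchronizes parameters and buffers once at construction on XPU, the same as on GPU. A minimal usage sketch follows; the toy model and optimizer are illustrative only.

```python
# Minimal sketch: the same data-parallel script now covers GPU and XPU;
# sync_params_buffers runs inside DataParallel.__init__ on both.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(8, 8))       # params/buffers synced here
opt = paddle.optimizer.SGD(parameters=model.parameters())

for _ in range(3):
    loss = model(paddle.randn([4, 8])).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()
```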
@@ -1199,8 +1199,6 @@ class Layer:
pass
def _dygraph_call_func(self, *inputs, **kwargs):
from paddle.distributed import parallel_helper
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
@@ -1212,17 +1210,6 @@ class Layer:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
# TODO(liuyuhui) Only xpu broadcast parameters here.
# The other device is to call _sync_params_buffers in DataParallel
# to realize the parameter synchronization among multiply cards.
if (
parallel_helper._is_data_parallel_mode()
and paddle.is_compiled_with_xpu()
):
parallel_helper._broadcast_parameters(
self._parameters.values()
)
self._built = True
if in_profiler_mode():
......
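With the per-forward-call XPU broadcast removed from Layer._dygraph_call_func, a rank that enables the profiler no longer issues a broadcast that the other ranks never match. A hedged sketch of profiling a data-parallel run with the public paddle.profiler API (targets and the scheduler window are illustrative):

```python
# Sketch: profile a few steps of a data-parallel loop. After this PR, XPU
# ranks do not broadcast parameters from inside the layer call, so profiling
# a subset of steps cannot leave the collective hanging.
import paddle
import paddle.distributed as dist
import paddle.profiler as profiler

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(8, 8))

prof = profiler.Profiler(targets=[profiler.ProfilerTarget.CPU],
                         scheduler=(2, 5))   # profile batches [2, 5)
prof.start()
for _ in range(6):
    model(paddle.randn([4, 8])).mean().backward()
    prof.step()
prof.stop()
prof.summary()
```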