Unverified commit 806f8d2b authored by XiaociZhang, committed by GitHub

[Kunlun] Modify some legacy code on distributed training (#55515)

* [Kunlun] Modify some legacy code on distributed training

There were limitations on XPUs before, such as concat/split not being
supported and c_broadcast supporting only fp32. These limitations have
recently been lifted.

Multi-device profiling on XPU will also be supported by this PR.
Without this PR, devices that enable profiling issue a hanging broadcast,
eventually leading to a kernel timeout error.

* fix typo
Parent 284e0d12
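The fp32-only broadcast limitation described in the commit message is what the first hunk below removes on the C++ side. As an illustration only (not part of this commit), broadcasting a non-fp32 tensor on Kunlun with the updated stack could look like the following sketch; the launch command and device indices are assumptions:

# Illustrative sketch, not part of this commit. Assumes a multi-card job
# started with something like `python -m paddle.distributed.launch --xpus 0,1 demo.py`.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()               # BKCL backend on Kunlun/XPU
t = paddle.arange(10, dtype="int64")   # previously forced through an fp32 work-around
dist.broadcast(t, src=0)               # plain c_broadcast, no dtype conversion needed
print(dist.get_rank(), t.numpy())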
@@ -872,14 +872,8 @@ void ParallelExecutor::BCastParamsToDevices(
       std::vector<void *> buffers;
       buffers.reserve(member_->places_.size());
       size_t numel = main_tensor.numel();
-      // TODO(liuyuhui): BKCL only support parameters using float type,
-      // other parameters need to be strongly converted to float before
-      // broadcasting,
-      // but broadcast is equivalent to no type of operation, does not affect
-      // correctness.
-      BKCLDataType data_type = BKCL_FLOAT;
-      // BKCLDataType data_type =
-      // platform::ToBKCLDataType(framework::TransToProtoVarType(main_tensor.dtype()));
+      auto dtype = framework::TransToProtoVarType(main_tensor.dtype());
+      BKCLDataType data_type = platform::ToBKCLDataType(dtype);
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto place = member_->places_[i];
         void *buffer;
@@ -904,33 +898,21 @@ void ParallelExecutor::BCastParamsToDevices(
                           member_->places_.size()));
       {
         auto *bkcl_ctxs = member_->bkcl_ctxs_->DefaultFlatCtx();
-        platform::BKCLGroupGuard guard;
-        PADDLE_ENFORCE_EQ(
-            bkcl_group_start(),
-            BKCL_SUCCESS,
-            platform::errors::Unavailable("bkcl_group_start failed"));
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &bkcl_ctx = bkcl_ctxs->at(member_->places_[i]);
-          auto broadcast_numel = numel;
-          if (framework::TransToProtoVarType(main_tensor.dtype()) ==
-              framework::proto::VarType::INT64) {
-            broadcast_numel *= 2;
-          }
           PADDLE_ENFORCE_EQ(
               bkcl_broadcast(bkcl_ctx.comm(),
                              buffers[i],
                              buffers[i],
-                             broadcast_numel,
+                             numel,
                              data_type,
                              0,
                              NULL),
               BKCL_SUCCESS,
               platform::errors::Unavailable("bkcl_broadcast failed"));
         }
-        PADDLE_ENFORCE_EQ(
-            bkcl_group_end(),
-            BKCL_SUCCESS,
-            platform::errors::Unavailable("bkcl_group_end failed"));
+        bkcl_ctxs->WaitAll();
       }
 #else
       PADDLE_THROW(
......
@@ -819,8 +819,6 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
     }
   }

-// TODO(liuyuhui): If BKCL support non-blocking communication, it should be
-// fixed as same as multi gpus card training.
 void Reducer::MarkGroupReady(size_t group_index) {
   PADDLE_ENFORCE_GE(
       group_index,
......
@@ -397,10 +397,7 @@ class DataParallel(layers.Layer):
             ), "ProcessGroup must be an instance of Group in DataParallel."

         # sync buffer and params
-        # TODO(liuyuhui) Currently not support xpu. xpu is
-        # still broadcasting parameters when calling layer
-        if not paddle.is_compiled_with_xpu():
-            sync_params_buffers(self._layers)
+        sync_params_buffers(self._layers)

         self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
         # NOTE(shenliang03): We can set environment variables to control
......
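With the DataParallel change above, parameters and buffers are now synchronized at wrap time on XPU as well, the same as on other devices. A minimal usage sketch (not from this commit; the model and sizes are made up):

# Minimal sketch, not from this commit: DataParallel on Kunlun now syncs
# parameters/buffers once when the model is wrapped, instead of broadcasting
# inside every layer call.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(16, 4))   # sync_params_buffers runs here for XPU too
opt = paddle.optimizer.SGD(parameters=model.parameters())

x = paddle.randn([8, 16])
loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()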
@@ -1199,8 +1199,6 @@ class Layer:
         pass

     def _dygraph_call_func(self, *inputs, **kwargs):
-        from paddle.distributed import parallel_helper
-
         for forward_pre_hook in self._forward_pre_hooks.values():
             hook_result = forward_pre_hook(self, inputs)
             if hook_result is not None:
@@ -1212,17 +1210,6 @@ class Layer:
             with program_desc_tracing_guard(False):
                 self._build_once(*inputs, **kwargs)
-                # TODO(liuyuhui) Only xpu broadcast parameters here.
-                # The other device is to call _sync_params_buffers in DataParallel
-                # to realize the parameter synchronization among multiply cards.
-                if (
-                    parallel_helper._is_data_parallel_mode()
-                    and paddle.is_compiled_with_xpu()
-                ):
-                    parallel_helper._broadcast_parameters(
-                        self._parameters.values()
-                    )
             self._built = True

         if in_profiler_mode():
......
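The block removed above was the XPU-only per-call broadcast that could hang when only some ranks enabled profiling. With it gone, profiling a data-parallel run needs no XPU-specific handling; a hedged sketch under the same assumptions as the examples above:

# Hedged sketch, not from this commit. Profiler targets available for XPU
# depend on the Paddle build; the default Profiler configuration is used here.
import paddle
import paddle.distributed as dist
import paddle.profiler as profiler

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(16, 4))
x = paddle.randn([8, 16])

prof = profiler.Profiler()
prof.start()
for _ in range(3):
    model(x).mean().backward()
prof.stop()
prof.summary()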