From 87197f8c2e4d002fc39027c3d4ee99f4ead0ba2c Mon Sep 17 00:00:00 2001
From: liuyuhui
Date: Mon, 8 Feb 2021 10:53:34 +0800
Subject: [PATCH] [kunlun]fix sync in multi kunlun xpu dygraph training.
 (#30943)

---
 paddle/fluid/imperative/reducer.cc                  | 12 ++++++++++++
 .../tests/unittests/test_parallel_dygraph_mnist.py  |  4 ++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 9f296cbd5e1..8f55645b880 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -626,6 +626,18 @@ void Reducer::MarkGroupReady(size_t group_index) {
       // group.dense_tensors ---> group.dense_contents_
       group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order));
 
+// NOTE(liuyuhui): ConcatTensors runs on the communication stream, but BKCL
+// only supports communicating on the default stream, so there can be
+// synchronization problems between the two; a WaitComm is needed here
+// before the allreduce starts.
+// TODO(liuyuhui): Once BKCL supports events, this should be reworked as
+// non-blocking communication.
+#ifdef PADDLE_WITH_XPU_BKCL
+      if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
+        parallel_ctx_->WaitComm(run_order);
+      }
+#endif
+
       // Start allreduce
       parallel_ctx_->AllReduceByStream(
           group.dense_contents_, &(group.dense_contents_), run_order, false);
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
index faba479b32f..f21468f50c5 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
@@ -55,7 +55,7 @@ class TestParallelDygraphMnistXPU(TestDistBase):
         if fluid.core.is_compiled_with_xpu():
             self.check_with_place(
                 "parallel_dygraph_mnist.py",
-                delta=1e-1,
+                delta=1e-4,
                 check_error_log=True,
                 log_name=flag_name)
 
@@ -94,7 +94,7 @@ class TestFleetDygraphMnistXPU(TestDistBase):
         if fluid.core.is_compiled_with_xpu():
             self.check_with_place(
                 "parallel_dygraph_mnist.py",
-                delta=1e-1,
+                delta=1e-4,
                 check_error_log=True,
                 log_name=flag_name)
 
-- 
GitLab
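
For context on the WaitComm added above: ConcatTensors is enqueued on the per-ring communication stream, while BKCL issues its collectives only on the default stream, so without an explicit wait the allreduce could read group.dense_contents_ before the concat has finished writing it. The minimal, self-contained C++ sketch below models that ordering requirement. It is an illustration only, not Paddle or BKCL code: the "stream" is a plain worker thread, the WaitComm analogue is a future/promise barrier, and every name in it is hypothetical.

// sync_sketch.cc -- illustration only, NOT Paddle/BKCL code.
// Models the hazard the patch fixes: a fused buffer is filled on one
// "stream" (here: a worker thread) while the collective runs on another,
// so the consumer must wait for the producer before reducing.
#include <future>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::vector<int> dense_contents(4, 0);  // stands in for group.dense_contents_
  std::promise<void> concat_done;
  std::future<void> wait_comm = concat_done.get_future();

  // "Communication stream": the pretend-ConcatTensors fills the fused buffer.
  std::thread comm_stream([&] {
    for (int &v : dense_contents) v = 1;
    concat_done.set_value();  // signal that the buffer is ready
  });

  // "Default stream": without this wait (the WaitComm analogue), the
  // reduction below could observe a partially written buffer.
  wait_comm.wait();
  int sum = 0;
  for (int v : dense_contents) sum += v;  // pretend-allreduce (local reduce)
  std::cout << "reduced sum = " << sum << "\n";  // deterministically prints 4

  comm_stream.join();
  return 0;
}

Built with g++ -std=c++11 -pthread sync_sketch.cc, this always prints 4; deleting the wait_comm.wait() line reintroduces exactly the kind of producer/consumer race the WaitComm in the patch guards against.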