From 8f0b3c05162cf28171a8ff2b93ca7c03088732ad Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 27 Sep 2019 13:14:01 +0800 Subject: [PATCH] the integrated communicator (#19849) * add a base class for the Communicator * add AsyncCommunicator Impl for async distributed training --- .../details/async_ssa_graph_executor.cc | 11 +- .../operators/distributed/communicator.cc | 189 +++++++++--------- .../operators/distributed/communicator.h | 114 ++++++++--- .../operators/distributed_ops/send_op.cc | 9 +- paddle/fluid/pybind/communicator_py.cc | 3 +- .../fluid/tests/unittests/test_dist_base.py | 55 +++-- .../fluid/tests/unittests/test_dist_ctr.py | 39 +++- .../tests/unittests/test_dist_fleet_ctr.py | 1 - .../tests/unittests/test_dist_simnet_bow.py | 2 +- 9 files changed, 259 insertions(+), 164 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 2e24707539..44e7b16bef 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -86,13 +86,10 @@ void ProcessGraph(std::vector graphs, Scope *scope) { if (send_varname_to_ctx.size() > 0) { VLOG(3) << "this is distribute mode, will use communicator"; - if (operators::distributed::Communicator::GetInstance() == nullptr) { - operators::distributed::Communicator::Init(send_varname_to_ctx, - recv_varname_to_ctx, scope); - operators::distributed::Communicator::GetInstance()->Start(); - } else { - VLOG(3) << "communicator has been initialized, skip"; - } + auto *instance = operators::distributed::Communicator::InitInstance< + operators::distributed::AsyncCommunicator>(send_varname_to_ctx, + recv_varname_to_ctx, scope); + if (!instance->IsRunning()) instance->Start(); } #endif } diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 683d4ca98a..a212aca34a 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -52,14 +52,16 @@ inline double GetCurrentUS() { return 1e+6 * time.tv_sec + time.tv_usec; } +std::once_flag Communicator::init_flag_; std::shared_ptr Communicator::communicator_(nullptr); -Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, - const RpcCtxMap &recv_varname_to_ctx, - Scope *recv_scope) - : send_varname_to_ctx_(send_varname_to_ctx), - recv_varname_to_ctx_(recv_varname_to_ctx), - recv_scope_(recv_scope) { +void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RpcCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + // get all send information from graph, build vars_to_send VLOG(0) << "communicator_independent_recv_thread: " << FLAGS_communicator_independent_recv_thread; @@ -98,7 +100,51 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, } } -Communicator::~Communicator() { +void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program, + Scope *param_scope) { + using RpcCtxMap = operators::distributed::RpcCtxMap; + VLOG(3) << "ProcessGraph"; + RpcCtxMap send_varname_to_ctx; + RpcCtxMap recv_varname_to_ctx; + for (auto *op : program.Block(0).AllOps()) { + VLOG(3) << "node name " << op->Type(); + if (op->Type() == "send") { + auto send_var_name = op->Input("X")[0]; + auto send_varnames = boost::get>( + op->GetNullableAttr("send_varnames")); + auto epmap = + boost::get>(op->GetNullableAttr("epmap")); + auto height_section = + boost::get>(op->GetNullableAttr("sections")); + auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); + send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( + send_var_name, send_varnames, epmap, height_section, trainer_id); + VLOG(3) << "find and init an send op: " + << send_varname_to_ctx[send_var_name]; + } else if (op->Type() == "recv") { + auto do_not_run = boost::get(op->GetNullableAttr("do_not_run")); + PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!"); + auto recv_var_name = op->Output("Out")[0]; + auto recv_varnames = boost::get>( + op->GetNullableAttr("recv_varnames")); + auto epmap = + boost::get>(op->GetNullableAttr("epmap")); + auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); + recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext( + recv_var_name, recv_varnames, epmap, {}, trainer_id); + } + } + + // init communicator here + if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) { + LOG(WARNING) << "no var need to send and recv!!"; + } + + operators::distributed::AsyncCommunicator::InitImpl( + send_varname_to_ctx, recv_varname_to_ctx, param_scope); +} + +AsyncCommunicator::~AsyncCommunicator() { if (FLAGS_v >= 3) { std::string msg("~Communicator"); fwrite(msg.c_str(), msg.length(), 1, stdout); @@ -112,7 +158,7 @@ Communicator::~Communicator() { } } -void Communicator::SendThread() { +void AsyncCommunicator::SendThread() { VLOG(3) << "SendThread start!"; while (running_) { std::vector> task_futures; @@ -175,50 +221,12 @@ void Communicator::SendThread() { VLOG(3) << "run send graph use time " << after_run_send_graph - before_run_send_graph; - RecvNonIndependent(); + Recv(); } VLOG(0) << "communicator stopped, send thread exit"; } -void Communicator::RecvNonIndependent() { - if (FLAGS_communicator_independent_recv_thread) { - return; - } - - auto grad_num = grad_num_.load(); - if (grad_num > 0) { - RecvAll(); - grad_num_.store(0); - } else { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } -} - -void Communicator::RecvAll() { - VLOG(3) << "parallel run recv graph"; - if (!running_) return; - auto before_send = GetCurrentUS(); - std::vector> task_futures; - task_futures.reserve(recv_varname_to_ctx_.size()); - for (auto &iter : recv_varname_to_ctx_) { - auto recv_task = [this, &iter] { - auto &var_name = iter.first; - VLOG(4) << "recv var " << var_name; - auto recv_functor = distributed::ParameterRecv(); - if (!FLAGS_communicator_fake_rpc) { - recv_functor(iter.second, *recv_scope_); - } - }; - task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); - } - for (auto &task : task_futures) { - task.wait(); - } - auto after_recv = GetCurrentUS(); - VLOG(1) << "run recv graph use time " << after_recv - before_send; -} - -void Communicator::RecvThread() { +void AsyncCommunicator::RecvThread() { VLOG(3) << "RecvThread start!"; while (running_) { auto grad_num = grad_num_.load(); @@ -233,8 +241,8 @@ void Communicator::RecvThread() { VLOG(0) << "communicator stopped, recv thread exit"; } -void Communicator::Send(const std::string &var_name, - const framework::Scope &scope) { +void AsyncCommunicator::Send(const std::string &var_name, + const framework::Scope &scope) { VLOG(3) << "communicator send " << var_name; // push var into send queue by var_name auto *grad_var = scope.FindVar(var_name); @@ -255,56 +263,45 @@ void Communicator::Send(const std::string &var_name, } } -void Communicator::Init(const paddle::framework::ProgramDesc &program, - Scope *param_scope) { - using RpcCtxMap = operators::distributed::RpcCtxMap; - VLOG(3) << "ProcessGraph"; - RpcCtxMap send_varname_to_ctx; - RpcCtxMap recv_varname_to_ctx; - for (auto *op : program.Block(0).AllOps()) { - VLOG(3) << "node name " << op->Type(); - if (op->Type() == "send") { - auto send_var_name = op->Input("X")[0]; - auto send_varnames = boost::get>( - op->GetNullableAttr("send_varnames")); - auto epmap = - boost::get>(op->GetNullableAttr("epmap")); - auto height_section = - boost::get>(op->GetNullableAttr("sections")); - auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); - send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( - send_var_name, send_varnames, epmap, height_section, trainer_id); - VLOG(3) << "find and init an send op: " - << send_varname_to_ctx[send_var_name]; - } else if (op->Type() == "recv") { - auto do_not_run = boost::get(op->GetNullableAttr("do_not_run")); - PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!"); - auto recv_var_name = op->Output("Out")[0]; - auto recv_varnames = boost::get>( - op->GetNullableAttr("recv_varnames")); - auto epmap = - boost::get>(op->GetNullableAttr("epmap")); - auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); - recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext( - recv_var_name, recv_varnames, epmap, {}, trainer_id); - } +void AsyncCommunicator::Recv() { + if (FLAGS_communicator_independent_recv_thread) { + return; } - // init communicator here - if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) { - LOG(WARNING) << "no var need to send and recv!!"; + auto grad_num = grad_num_.load(); + if (grad_num > 0) { + RecvAll(); + grad_num_.store(0); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - operators::distributed::Communicator::Init(send_varname_to_ctx, - recv_varname_to_ctx, param_scope); } -Communicator *Communicator::GetInstance() { return communicator_.get(); } - -std::shared_ptr Communicator::GetInstantcePtr() { - return communicator_; +void AsyncCommunicator::RecvAll() { + VLOG(3) << "parallel run recv graph"; + if (!running_) return; + auto before_send = GetCurrentUS(); + std::vector> task_futures; + task_futures.reserve(recv_varname_to_ctx_.size()); + for (auto &iter : recv_varname_to_ctx_) { + auto recv_task = [this, &iter] { + auto &var_name = iter.first; + VLOG(4) << "recv var " << var_name; + auto recv_functor = distributed::ParameterRecv(); + if (!FLAGS_communicator_fake_rpc) { + recv_functor(iter.second, *recv_scope_); + } + }; + task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); + } + for (auto &task : task_futures) { + task.wait(); + } + auto after_recv = GetCurrentUS(); + VLOG(1) << "run recv graph use time " << after_recv - before_send; } -void Communicator::Start() { +void AsyncCommunicator::Start() { VLOG(0) << "Communicator start"; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; @@ -313,15 +310,15 @@ void Communicator::Start() { running_ = true; // start send and recv thread send_thread_.reset( - new std::thread(std::bind(&Communicator::SendThread, this))); + new std::thread(std::bind(&AsyncCommunicator::SendThread, this))); if (FLAGS_communicator_independent_recv_thread) { recv_thread_.reset( - new std::thread(std::bind(&Communicator::RecvThread, this))); + new std::thread(std::bind(&AsyncCommunicator::RecvThread, this))); } } } -void Communicator::Stop() { +void AsyncCommunicator::Stop() { VLOG(0) << "Communicator stop"; running_ = false; if (!communicator_) { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index b3079f51c4..df1096e67e 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -161,27 +161,97 @@ using RpcCtxMap = std::unordered_map; class Communicator { public: - Communicator(const RpcCtxMap& send_varname_to_ctx, - const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope); + Communicator() {} + virtual ~Communicator() {} - ~Communicator(); + virtual void Start() = 0; + virtual void Stop() = 0; + virtual bool IsRunning() { return running_; } - void Start(); - void Stop(); + virtual void Send(const std::string& var_name, + const framework::Scope& scope) = 0; + virtual void Recv() = 0; - bool IsRunning() { return running_; } + virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, + Scope* recv_scope) = 0; - // send grad - void Send(const std::string& var_name, const framework::Scope& scope); + virtual void InitImpl(const paddle::framework::ProgramDesc& program, + Scope* recv_scope) = 0; - private: - // recv all parameter + static Communicator* GetInstance() { return communicator_.get(); } + + static std::shared_ptr GetInstantcePtr() { + return communicator_; + } + + template + static Communicator* InitInstance(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, + Scope* recv_scope) { + std::call_once(init_flag_, &Communicator::InitWithRpcCtx, + send_varname_to_ctx, recv_varname_to_ctx, recv_scope); + return communicator_.get(); + } + + // Init is called by InitInstance. + template + static void InitWithRpcCtx(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, + Scope* recv_scope) { + if (communicator_.get() == nullptr) { + communicator_.reset(new T()); + communicator_->InitImpl(send_varname_to_ctx, recv_varname_to_ctx, + recv_scope); + } + } + + template + static Communicator* InitInstance( + const paddle::framework::ProgramDesc& program, Scope* recv_scope) { + std::call_once(init_flag_, &Communicator::InitWithProgram, program, + recv_scope); + return communicator_.get(); + } + + template + static void InitWithProgram(const paddle::framework::ProgramDesc& program, + Scope* recv_scope) { + if (communicator_.get() == nullptr) { + communicator_.reset(new T()); + communicator_->InitImpl(program, recv_scope); + } + } + + protected: + bool running_ = false; + static std::shared_ptr communicator_; + static std::once_flag init_flag_; +}; + +class AsyncCommunicator : public Communicator { + public: + AsyncCommunicator() {} + ~AsyncCommunicator(); + void Start() override; + void Stop() override; + + void Send(const std::string& var_name, + const framework::Scope& scope) override; + void Recv() override; void RecvAll(); - void RecvNonIndependent(); + + void InitImpl(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, + Scope* recv_scope) override; + + void InitImpl(const paddle::framework::ProgramDesc& program, + Scope* recv_scope) override; + void SendThread(); void RecvThread(); - bool running_ = false; + private: std::unordered_map>>> send_varname_to_queue_; @@ -194,26 +264,6 @@ class Communicator { std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv - - // the following code is for initialize the commnunicator - public: - static void Init(const RpcCtxMap& send_varname_to_ctx, - const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) { - if (communicator_ == nullptr) { - communicator_.reset(new Communicator(send_varname_to_ctx, - recv_varname_to_ctx, recv_scope)); - } - } - - static void Init(const paddle::framework::ProgramDesc& program, - Scope* param_scope); - - static Communicator* GetInstance(); - - static std::shared_ptr GetInstantcePtr(); - - private: - static std::shared_ptr communicator_; }; } // namespace distributed diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index acb25b17d5..21f5d87fbe 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -48,14 +48,7 @@ class SendOp : public framework::OperatorBase { if (send_varnames.size() > 0) { PADDLE_ENFORCE_EQ(ins.size(), 1, ""); - if (distributed::Communicator::GetInstance() == nullptr) { - auto send_functor = distributed::ParameterSend(); - auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, - height_sections, trainer_id); - send_functor(rpc_ctx, scope, true); - } else { - distributed::Communicator::GetInstance()->Send(ins[0], scope); - } + distributed::Communicator::GetInstance()->Send(ins[0], scope); } else { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); diff --git a/paddle/fluid/pybind/communicator_py.cc b/paddle/fluid/pybind/communicator_py.cc index 5b576f06da..183ac4d217 100644 --- a/paddle/fluid/pybind/communicator_py.cc +++ b/paddle/fluid/pybind/communicator_py.cc @@ -26,6 +26,7 @@ namespace py = pybind11; using paddle::framework::ProgramDesc; using paddle::operators::distributed::Communicator; +using paddle::operators::distributed::AsyncCommunicator; using paddle::framework::Scope; namespace paddle { @@ -36,7 +37,7 @@ void BindCommunicator(py::module* m) { py::class_>(*m, "DistCommunicator") .def(py::init([](const ProgramDesc& program, Scope* param_scope) { - Communicator::Init(program, param_scope); + Communicator::InitInstance(program, param_scope); return Communicator::GetInstantcePtr(); })) .def("stop", &Communicator::Stop) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index c9230b68fe..1c697e4e66 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -75,11 +75,14 @@ class TestDistRunnerBase(object): sync_mode, dc_asgd=False, current_endpoint=None, - nccl_comm_num=1): + nccl_comm_num=1, + hogwild_mode=False): # NOTE: import fluid until runtime, or else forking processes will cause error. config = fluid.DistributeTranspilerConfig() config.enable_dc_asgd = dc_asgd config.sync_mode = sync_mode + config.runtime_split_send_recv = hogwild_mode + if nccl_comm_num > 1: config.nccl_comm_num = nccl_comm_num # config.runtime_split_send_recv = True @@ -89,6 +92,7 @@ class TestDistRunnerBase(object): program=main_program, pservers=pserver_endpoints, trainers=trainers, + sync_mode=sync_mode, current_endpoint=current_endpoint) return t @@ -96,9 +100,15 @@ class TestDistRunnerBase(object): self.lr = args.lr self.get_model(batch_size=args.batch_size) # NOTE: pserver should not call memory optimize - t = self.get_transpiler(args.trainer_id, - fluid.default_main_program(), args.endpoints, - args.trainers, args.sync_mode, args.dc_asgd) + + t = self.get_transpiler( + trainer_id=args.trainer_id, + main_program=fluid.default_main_program(), + pserver_endpoints=args.endpoints, + trainers=args.trainers, + sync_mode=args.sync_mode, + dc_asgd=args.dc_asgd, + hogwild_mode=args.hogwild) pserver_prog = t.get_pserver_program(args.current_endpoint) startup_prog = t.get_startup_program(args.current_endpoint, pserver_prog) @@ -120,7 +130,7 @@ class TestDistRunnerBase(object): dist_strategy = DistributedStrategy() dist_strategy.exec_strategy = exec_strategy - dist_strategy.fuse_memory_size = 1 #MB + dist_strategy.fuse_memory_size = 1 # MB dist_strategy.fuse_laryer_size = 1 if args.use_local_sgd: dist_strategy.use_local_sgd = True @@ -130,11 +140,11 @@ class TestDistRunnerBase(object): role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) print_to_err("gpu_fleet", "fleet.node_num:") - #"fleet.node_id:", fleet.node_id(), - #"fleet.trainer_num:", fleet.worker_num()) + # "fleet.node_id:", fleet.node_id(), + # "fleet.trainer_num:", fleet.worker_num()) test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ - self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy) + self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy) trainer_prog = fleet._origin_program dist_prog = fleet.main_program @@ -196,10 +206,15 @@ class TestDistRunnerBase(object): print_to_err( type(self).__name__, "begin to run transpile on trainer with pserver mode") - t = self.get_transpiler(args.trainer_id, - fluid.default_main_program(), - args.endpoints, args.trainers, - args.sync_mode, args.dc_asgd) + t = self.get_transpiler( + trainer_id=args.trainer_id, + main_program=fluid.default_main_program(), + pserver_endpoints=args.endpoints, + trainers=args.trainers, + sync_mode=args.sync_mode, + dc_asgd=args.dc_asgd, + hogwild_mode=args.hogwild) + trainer_prog = t.get_trainer_program() print_to_err( type(self).__name__, @@ -251,6 +266,9 @@ class TestDistRunnerBase(object): build_stra.enable_inplace = False build_stra.memory_optimize = False + if args.hogwild: + build_stra.async_mode = True + if args.enable_backward_deps: build_stra.enable_backward_optimizer_op_deps = True @@ -411,6 +429,7 @@ def runtime_main(test_class): parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--dc_asgd', action='store_true') + parser.add_argument('--hogwild', action='store_true') parser.add_argument( '--use_reader_alloc', action='store_true', required=False) parser.add_argument('--batch_size', required=False, type=int, default=2) @@ -467,6 +486,7 @@ class TestDistBase(unittest.TestCase): self._find_free_port(), self._find_free_port()) self._python_interp = sys.executable self._sync_mode = True + self._hogwild_mode = False self._enforce_place = None self._use_reduce = False self._dc_asgd = False # must use with async mode @@ -630,6 +650,9 @@ class TestDistBase(unittest.TestCase): if self._sync_mode: tr0_cmd += " --sync_mode" tr1_cmd += " --sync_mode" + if self._hogwild_mode: + tr0_cmd += " --hogwild" + tr1_cmd += " --hogwild" if self._use_reduce: tr0_cmd += " --use_reduce" tr1_cmd += " --use_reduce" @@ -703,8 +726,8 @@ class TestDistBase(unittest.TestCase): tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f" tr_cmd = tr_cmd % \ - (self._python_interp, model, self._ps_endpoints, - trainer_id, ep, update_method, self._lr) + (self._python_interp, model, self._ps_endpoints, + trainer_id, ep, update_method, self._lr) if self._use_reduce: tr_cmd += " --use_reduce" @@ -825,9 +848,9 @@ class TestDistBase(unittest.TestCase): required_envs["GLOG_v"] = "10" required_envs["GLOG_logtostderr"] = "1" - local_losses\ + local_losses \ = self._run_local(model_file, required_envs, - check_error_log) + check_error_log) if self._nccl2_mode: if self._nccl2_reduce_layer: tr0_losses, tr1_losses = self._run_cluster_nccl2( diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index 55234a8573..a108631df7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import print_function import os @@ -29,14 +30,13 @@ def skip_ci(func): return __func__ -@skip_ci class TestDistCTR2x2(TestDistBase): def _setup_config(self): self._sync_mode = True self._enforce_place = "CPU" def test_dist_ctr(self): - self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) + self.check_with_place("dist_ctr.py", delta=1e-2, check_error_log=False) @skip_ci @@ -54,5 +54,40 @@ class TestDistCTRWithL2Decay2x2(TestDistBase): need_envs=need_envs) +class TestDistCTR2x2_ASYNC(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._hogwild_mode = True + self._enforce_place = "CPU" + + def test_dist_ctr(self): + need_envs = { + "FLAGS_communicator_send_queue_size": "2", + "FLAGS_communicator_max_merge_var_num": "2", + "FLAGS_communicator_max_send_grad_num_before_recv": "2", + } + + self.check_with_place( + "dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs) + + +class TestDistCTR2x2_ASYNC2(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._hogwild_mode = True + self._enforce_place = "CPU" + + def test_dist_ctr(self): + need_envs = { + "FLAGS_communicator_send_queue_size": "2", + "FLAGS_communicator_max_merge_var_num": "2", + "FLAGS_communicator_max_send_grad_num_before_recv": "2", + "FLAGS_communicator_independent_recv_thread": "0" + } + + self.check_with_place( + "dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 9bad641a8c..acefd65b56 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -30,7 +30,6 @@ def skip_ci(func): return __func__ -@skip_ci class TestDistMnist2x2(TestFleetBase): def _setup_config(self): self._sync_mode = False diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py index 30a7ec095e..821974914b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -33,7 +33,7 @@ class TestDistSimnetBowDense2x2(TestDistBase): self.check_with_place( "dist_simnet_bow.py", delta=1e-5, - check_error_log=False, + check_error_log=True, need_envs=need_envs) -- GitLab