From 58f7695ab2022319cc6a9e95348530496ed8d104 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 23 May 2019 15:35:08 +0800 Subject: [PATCH] Async exe support communicator (#17386) Async exe support communicator --- .../details/async_ssa_graph_executor.cc | 66 +++--- .../multi_devices_graph_pass.h | 2 +- paddle/fluid/framework/operator.cc | 2 +- .../operators/distributed/communicator.cc | 108 ++++++++-- .../operators/distributed/communicator.h | 25 +-- .../operators/distributed_ops/recv_op.cc | 4 +- paddle/fluid/pybind/CMakeLists.txt | 24 ++- paddle/fluid/pybind/communicator_py.cc | 47 ++++ paddle/fluid/pybind/communicator_py.h | 27 +++ paddle/fluid/pybind/pybind.cc | 7 + python/paddle/fluid/communicator.py | 88 ++++++++ .../fluid/incubate/fleet/base/fleet_base.py | 26 +-- .../fluid/incubate/fleet/base/role_maker.py | 42 +++- .../incubate/fleet/collective/__init__.py | 9 +- .../distributed_transpiler/__init__.py | 66 +++--- .../fleet/parameter_server/pslib/__init__.py | 19 +- .../incubate/fleet/tests/cluster_train.sh | 33 +++ .../fleet/tests/ctr_dataset_reader.py | 100 +++++++++ .../incubate/fleet/tests/fleet_deep_ctr.py | 204 ++++++++++++++++++ python/paddle/fluid/optimizer.py | 15 +- python/paddle/fluid/tests/CMakeLists.txt | 4 + .../paddle/fluid/tests/test_communicator.py | 32 +++ .../fluid/transpiler/distribute_transpiler.py | 4 +- 23 files changed, 805 insertions(+), 149 deletions(-) create mode 100644 paddle/fluid/pybind/communicator_py.cc create mode 100644 paddle/fluid/pybind/communicator_py.h create mode 100644 python/paddle/fluid/communicator.py create mode 100644 python/paddle/fluid/incubate/fleet/tests/cluster_train.sh create mode 100644 python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py create mode 100644 python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py create mode 100644 python/paddle/fluid/tests/test_communicator.py diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 7f63c07b18f..ce7849cb419 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -51,45 +51,39 @@ void ProcessGraph(std::vector graphs, Scope *scope) { VLOG(3) << "ProcessGraph"; RpcCtxMap send_varname_to_ctx; RpcCtxMap recv_varname_to_ctx; - for (auto i = 0; i < graphs.size(); ++i) { - std::vector nodes_to_delete; - for (auto &node : graphs[i]->Nodes()) { - VLOG(3) << "node name " << node->Name(); - if (node && node->IsOp()) { - if (node->Name() == "send") { - auto send_var_name = node->Op()->Input("X")[0]; - auto send_varnames = boost::get>( - node->Op()->GetNullableAttr("send_varnames")); - auto epmap = boost::get>( - node->Op()->GetNullableAttr("epmap")); - auto height_section = boost::get>( - node->Op()->GetNullableAttr("sections")); - auto trainer_id = - boost::get(node->Op()->GetNullableAttr("trainer_id")); - send_varname_to_ctx[send_var_name] = - operators::distributed::RpcContext(send_var_name, send_varnames, - epmap, height_section, - trainer_id); - VLOG(3) << "find and init an send op: " - << send_varname_to_ctx[send_var_name]; - } else if (node->Name() == "recv") { - auto recv_var_name = node->Op()->Output("Out")[0]; - auto recv_varnames = boost::get>( - node->Op()->GetNullableAttr("recv_varnames")); - auto epmap = boost::get>( - node->Op()->GetNullableAttr("epmap")); - auto trainer_id = - boost::get(node->Op()->GetNullableAttr("trainer_id")); - recv_varname_to_ctx[recv_var_name] = - 
operators::distributed::RpcContext(recv_var_name, recv_varnames, - epmap, {}, trainer_id); - nodes_to_delete.push_back(node); - VLOG(3) << "find and remove an recv op: " - << recv_varname_to_ctx[recv_var_name]; - } + for (auto &node : graphs[0]->Nodes()) { + VLOG(3) << "node name " << node->Name(); + if (node && node->IsOp()) { + if (node->Name() == "send") { + auto send_var_name = node->Op()->Input("X")[0]; + auto send_varnames = boost::get>( + node->Op()->GetNullableAttr("send_varnames")); + auto epmap = boost::get>( + node->Op()->GetNullableAttr("epmap")); + auto height_section = boost::get>( + node->Op()->GetNullableAttr("sections")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); + send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( + send_var_name, send_varnames, epmap, height_section, trainer_id); + VLOG(3) << "find and init an send op: " + << send_varname_to_ctx[send_var_name]; + } else if (node->Name() == "recv") { + auto recv_var_name = node->Op()->Output("Out")[0]; + auto recv_varnames = boost::get>( + node->Op()->GetNullableAttr("recv_varnames")); + auto epmap = boost::get>( + node->Op()->GetNullableAttr("epmap")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); + recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext( + recv_var_name, recv_varnames, epmap, {}, trainer_id); + VLOG(3) << "find and remove an recv op: " + << recv_varname_to_ctx[recv_var_name]; } } } + // init communicator here if (send_varname_to_ctx.size() > 0) { VLOG(3) << "this is distribute mode, will use communicator"; diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h index 3434d45f142..a377bbf6b7d 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h @@ -130,7 +130,7 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override { if (node->Op()->Type() == "recv") { VLOG(1) << "set recv op do_not_run to true"; - node->Op()->SetAttr("do_not_run", true); + node->Op()->SetAttr("do_not_run", 1); node->Op()->Flush(); } else if (node->Name() == "lookup_table" || node->Name() == "nce" || node->Name() == "hierarchical_sigmoid") { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index fa6de326bc1..8d4623468b9 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1142,7 +1142,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( t = &(var->Get().value()); } if (t != nullptr) { - PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized", + PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized", input.first, i); proto::VarType::Type tmp = t->type(); PADDLE_ENFORCE( diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 6e1015d320b..3a185667e7a 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/distributed/communicator.h" #include +#include #include // NOLINT #include // NOLINT @@ -50,8 +51,7 @@ inline double GetCurrentUS() { return 1e+6 * time.tv_sec + time.tv_usec; } -std::unique_ptr Communicator::communicator_(nullptr); -std::once_flag Communicator::init_flag_; +std::shared_ptr Communicator::communicator_(nullptr); Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx, @@ -84,11 +84,17 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, } Communicator::~Communicator() { - VLOG(3) << "~Communicator"; + if (FLAGS_v >= 3) { + std::string msg("~Communicator"); + fwrite(msg.c_str(), msg.length(), 1, stdout); + } running_ = false; if (send_thread_) send_thread_->join(); if (recv_thread_) recv_thread_->join(); - VLOG(3) << "~Communicator done"; + if (FLAGS_v >= 3) { + std::string msg("~Communicator done"); + fwrite(msg.c_str(), msg.length(), 1, stdout); + } } void Communicator::SendThread() { @@ -144,7 +150,7 @@ void Communicator::SendThread() { task_futures.emplace_back( send_threadpool_->enqueue(std::move(send_task))); } else { - VLOG(3) << var_name << " queue empty"; + VLOG(4) << var_name << " queue empty"; } } for (auto &task_f : task_futures) { @@ -160,17 +166,19 @@ void Communicator::SendThread() { RecvAll(); } } + VLOG(0) << "communicator stopped, send thread exit"; } void Communicator::RecvAll() { VLOG(3) << "parallel run recv graph"; + if (!running_) return; auto before_send = GetCurrentUS(); std::vector> task_futures; task_futures.reserve(recv_varname_to_ctx_.size()); for (auto &iter : recv_varname_to_ctx_) { auto recv_task = [this, &iter] { auto &var_name = iter.first; - VLOG(3) << "recv var " << var_name; + VLOG(4) << "recv var " << var_name; auto recv_functor = distributed::ParameterRecv(); if (!FLAGS_communicator_fake_rpc) { recv_functor(iter.second, *recv_scope_); @@ -197,6 +205,7 @@ void Communicator::RecvThread() { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } } + VLOG(0) << "communicator stopped, recv thread exit"; } void Communicator::Send(const std::string &var_name, @@ -212,17 +221,90 @@ void Communicator::Send(const std::string &var_name, queue->Push(tmp_grad_var); } +void Communicator::Init(const paddle::framework::ProgramDesc &program, + Scope *param_scope) { + using RpcCtxMap = operators::distributed::RpcCtxMap; + VLOG(3) << "ProcessGraph"; + RpcCtxMap send_varname_to_ctx; + RpcCtxMap recv_varname_to_ctx; + for (auto *op : program.Block(0).AllOps()) { + VLOG(3) << "node name " << op->Type(); + if (op->Type() == "send") { + auto send_var_name = op->Input("X")[0]; + auto send_varnames = boost::get>( + op->GetNullableAttr("send_varnames")); + auto epmap = + boost::get>(op->GetNullableAttr("epmap")); + auto height_section = + boost::get>(op->GetNullableAttr("sections")); + auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); + send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( + send_var_name, send_varnames, epmap, height_section, trainer_id); + VLOG(3) << "find and init an send op: " + << send_varname_to_ctx[send_var_name]; + } else if (op->Type() == "recv") { + auto do_not_run = boost::get(op->GetNullableAttr("do_not_run")); + PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!"); + auto recv_var_name = op->Output("Out")[0]; + auto recv_varnames = boost::get>( + op->GetNullableAttr("recv_varnames")); + auto epmap = + boost::get>(op->GetNullableAttr("epmap")); + auto trainer_id = 
boost::get(op->GetNullableAttr("trainer_id")); + recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext( + recv_var_name, recv_varnames, epmap, {}, trainer_id); + } + } + + // init communicator here + if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) { + LOG(WARNING) << "no var need to send and recv!!"; + } + operators::distributed::Communicator::Init(send_varname_to_ctx, + recv_varname_to_ctx, param_scope); +} + Communicator *Communicator::GetInstance() { return communicator_.get(); } +std::shared_ptr Communicator::GetInstantcePtr() { + return communicator_; +} + void Communicator::Start() { - running_ = true; - // start send and recv thread - send_thread_.reset( - new std::thread(std::bind(&Communicator::SendThread, this))); - if (FLAGS_communicator_independent_recv_thread) { - recv_thread_.reset( - new std::thread(std::bind(&Communicator::RecvThread, this))); + VLOG(0) << "Communicator start"; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + VLOG(1) << "start send thread and recv thread"; + running_ = true; + // start send and recv thread + send_thread_.reset( + new std::thread(std::bind(&Communicator::SendThread, this))); + if (FLAGS_communicator_independent_recv_thread) { + recv_thread_.reset( + new std::thread(std::bind(&Communicator::RecvThread, this))); + } + } +} + +void Communicator::Stop() { + VLOG(0) << "Communicator stop"; + running_ = false; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + if (send_thread_) { + VLOG(1) << "stop send thread"; + send_thread_->join(); + send_thread_.reset(nullptr); + } + if (recv_thread_) { + VLOG(1) << "stop recv thread"; + recv_thread_->join(); + recv_thread_.reset(nullptr); + } } + VLOG(0) << "Communicator stop done"; } } // namespace distributed diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 37c39eb1511..17f68fb4f1b 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -165,6 +165,7 @@ class Communicator { ~Communicator(); void Start(); + void Stop(); // send grad void Send(const std::string& var_name, const framework::Scope& scope); @@ -181,8 +182,8 @@ class Communicator { send_varname_to_queue_; RpcCtxMap send_varname_to_ctx_; RpcCtxMap recv_varname_to_ctx_; - std::unique_ptr send_thread_; - std::unique_ptr recv_thread_; + std::unique_ptr send_thread_{nullptr}; + std::unique_ptr recv_thread_{nullptr}; Scope* recv_scope_; // should be global scope std::unique_ptr send_scope_; // an independent scope std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; @@ -193,25 +194,21 @@ class Communicator { public: static void Init(const RpcCtxMap& send_varname_to_ctx, const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) { - InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope); - } - - static Communicator* GetInstance(); - - private: - // Init is called by GetInstance. 
- static void InitImpl(const RpcCtxMap& send_varname_to_ctx, - const RpcCtxMap& recv_varname_to_ctx, - Scope* recv_scope) { if (communicator_ == nullptr) { communicator_.reset(new Communicator(send_varname_to_ctx, recv_varname_to_ctx, recv_scope)); } } + static void Init(const paddle::framework::ProgramDesc& program, + Scope* param_scope); + + static Communicator* GetInstance(); + + static std::shared_ptr GetInstantcePtr(); + private: - static std::once_flag init_flag_; - static std::unique_ptr communicator_; + static std::shared_ptr communicator_; }; } // namespace distributed diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 8e9846b1fc8..b871859dbb1 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -36,7 +36,7 @@ class RecvOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - bool do_not_run = Attr("do_not_run"); + int do_not_run = Attr("do_not_run"); if (do_not_run) { VLOG(3) << "recv do not run!"; return; @@ -132,7 +132,7 @@ This operator can get variables from server side. "(vector) " "the splited parameter varnames to be recved from pserver") .SetDefault(std::vector{}); - AddAttr("do_not_run", "if recv need to really run").SetDefault(false); + AddAttr("do_not_run", "if recv need to really run").SetDefault(0); } }; diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d709508a6d5..eeee507110c 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -5,7 +5,29 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) endif() -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc fleet_wrapper_py.cc nccl_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc) + +if (WITH_DISTRIBUTE) + list(APPEND PYBIND_DEPS communicator) +endif() + +set(PYBIND_SRCS + pybind.cc + exception.cc + protobuf.cc + const_value.cc + recordio.cc + reader_py.cc + async_executor_py.cc + fleet_wrapper_py.cc + nccl_wrapper_py.cc + data_set_py.cc + imperative.cc + ir.cc + inference_api.cc) + +if (WITH_DISTRIBUTE) + list(APPEND PYBIND_SRCS communicator_py.cc) +endif() if(WITH_PYTHON) if(WITH_AMD_GPU) diff --git a/paddle/fluid/pybind/communicator_py.cc b/paddle/fluid/pybind/communicator_py.cc new file mode 100644 index 00000000000..1d4052358b3 --- /dev/null +++ b/paddle/fluid/pybind/communicator_py.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/pybind/communicator_py.h" + +#include +#include + +#include "paddle/fluid/framework/program_desc.h" +#include "pybind11/pybind11.h" + +#include "paddle/fluid/operators/distributed/communicator.h" + +namespace py = pybind11; + +using paddle::framework::ProgramDesc; +using paddle::operators::distributed::Communicator; +using paddle::framework::Scope; + +namespace paddle { +namespace pybind { + +void BindCommunicator(py::module* m) { + // Communicator is already used by nccl, change to DistCommunicator + py::class_>(*m, + "DistCommunicator") + .def(py::init([](const ProgramDesc& program, Scope* param_scope) { + Communicator::Init(program, param_scope); + return Communicator::GetInstantcePtr(); + })) + .def("stop", &Communicator::Stop) + .def("start", &Communicator::Start); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/communicator_py.h b/paddle/fluid/pybind/communicator_py.h new file mode 100644 index 00000000000..374c74bdafe --- /dev/null +++ b/paddle/fluid/pybind/communicator_py.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace pybind { + +void BindCommunicator(pybind11::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c2b8e8874fc..2555c0e0729 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -77,6 +77,10 @@ limitations under the License. */ #include "paddle/fluid/platform/gpu_info.h" #endif +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/fluid/pybind/communicator_py.h" +#endif + #include "pybind11/stl.h" DEFINE_bool(reader_queue_speed_test_mode, false, @@ -1547,6 +1551,9 @@ All parameter, weight, gradient are variables in Paddle. BindNode(&m); BindInferenceApi(&m); BindDataset(&m); +#ifdef PADDLE_WITH_DISTRIBUTE + BindCommunicator(&m); +#endif } } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py new file mode 100644 index 00000000000..7d0db90b6ad --- /dev/null +++ b/python/paddle/fluid/communicator.py @@ -0,0 +1,88 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .executor import global_scope +from . 
import core +from .framework import Program + +__all__ = ['Communicator'] + + +class Communicator(object): + def __init__(self, program): + """ + Communicator is used for async distribute training in distribute_transpiler mode. + It's a wrapper of a cpp class Communicator and should be used inside fleet API. + + Args: + program(Program): the trainers program after transpile of distribute_transpiler. + It's used by communicator to extract the information to do communication. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + prog = fluid.Program() + comm = fluid.communicator.Communicator(prog) + comm.start() + comm.stop() + """ + # set all recv op to not_run mode + assert isinstance(program, Program) + for op in program.block(0).ops: + if op.type == "recv": + op._set_attr('do_not_run', True) + self.communicator_ = core.DistCommunicator(program.desc, global_scope()) + + def start(self): + """ + Start communicator. Should call before training process. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + prog = fluid.Program() + comm = fluid.communicator.Communicator(prog) + comm.start() + comm.stop() + """ + self.communicator_.start() + + def stop(self): + """ + Stop communicator. Should call after training process. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + prog = fluid.Program() + comm = fluid.communicator.Communicator(prog) + comm.start() + comm.stop() + """ + self.communicator_.stop() diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index f2f72b0f505..0396cb6d5fd 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -15,14 +15,14 @@ from __future__ import print_function import abc - from enum import Enum -from paddle.fluid.optimizer import SGD +import paddle.fluid as fluid from paddle.fluid.executor import Executor +from paddle.fluid.optimizer import SGD -from role_maker import RoleMakerBase from role_maker import MPISymetricRoleMaker +from role_maker import RoleMakerBase from role_maker import UserDefinedRoleMaker @@ -48,7 +48,6 @@ class Fleet(object): __metaclass__ = abc.ABCMeta def __init__(self, mode): - assert isinstance(mode, Mode) self._is_initialized = False self._mode = mode self._optimizer = None @@ -79,9 +78,9 @@ class Fleet(object): Get current total worker number. Returns: - int: worker number + int: worker numbers """ - return len(self._role_maker.get_trainer_endpoints()) + return self._role_maker.worker_num() def is_worker(self): """ @@ -173,21 +172,19 @@ class Fleet(object): end += length return files[start:end] - def init(self, executor, role_maker=None): + def init(self, role_maker=None): """ should be called only once in user's python scripts, init() will initialize RoleMaker which is used for identifying current node's role, e.g. worker, server, etc. Args: - executor(Executor): The executor to run fleet. role_maker(RoleMakerBase): subclass of RoleMakerBase. 
Returns: None """ - if not isinstance(executor, Executor): - raise ValueError("executor must be an instance of Executor") + self._executor = Executor(fluid.CPUPlace()) if role_maker and not isinstance(role_maker, RoleMakerBase): raise ValueError("role_maker must be an instance of RoleMakerBase") @@ -215,23 +212,20 @@ class Fleet(object): pass @abc.abstractmethod - def run_server(self, ): + def run_server(self): pass @abc.abstractmethod def stop_worker(self): pass - @abc.abstractmethod - def stop(self): - pass - @abc.abstractmethod def distributed_optimizer(self, optimizer, strategy=None): pass @abc.abstractmethod def save_inference_model(self, + executor, dirname, feeded_var_names, target_vars, @@ -240,7 +234,7 @@ class Fleet(object): pass @abc.abstractmethod - def save_persistables(self, dirname, main_program=None): + def save_persistables(self, executor, dirname, main_program=None): pass diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index 0c1c44cc15f..ae32fa039d1 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -61,6 +61,15 @@ class RoleMakerBase(object): """ raise NotImplementedError("Please implement this method in child class") + def worker_num(self): + """ + Get current total worker number. + + Returns: + int: worker number + """ + raise NotImplementedError("Please implement this method in child class") + def worker_index(self): """ Get current worker id. @@ -197,6 +206,9 @@ class MPISymetricRoleMaker(MPIRoleMaker): return self.is_worker() and 0 == self.worker_index() return False + def worker_num(self): + return self._worker_num() + def is_worker(self): """ return whether current process is worker assigned by role maker @@ -293,10 +305,29 @@ class UserDefinedRoleMaker(RoleMakerBase): """ super(UserDefinedRoleMaker, self).__init__() - self._current_id = current_id - self._role = role - self._worker_num = worker_num - self._server_endpoints = server_endpoints + if not isinstance(current_id, int): + raise TypeError("current_id must be as int") + else: + if current_id < 0: + raise ValueError("current_id must be gather or equal 0") + self._current_id = current_id + + if not isinstance(role, Role): + raise TypeError("role must be as Role") + else: + self._role = role + + if not isinstance(worker_num, int): + raise TypeError("worker_num must be as int") + else: + if worker_num < 0: + raise ValueError("worker_num must be gather or equal 0") + self._worker_num = worker_num + + if not isinstance(server_endpoints, list): + raise TypeError("server_endpoints must be as string list") + else: + self._server_endpoints = server_endpoints def is_worker(self): return self._role == Role.WORKER @@ -312,3 +343,6 @@ class UserDefinedRoleMaker(RoleMakerBase): def server_index(self): return self._current_id + + def worker_num(self): + return self._worker_num diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index e381a0d8c71..c63fa2dc2f8 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -47,17 +47,12 @@ class Collective(Fleet): logging.warn( "You should not call 'stop_worker' method for collective mode.") - def stop(self): - """ - stop(): will be called after a user finishes his/her training task. 
- """ - logging.warn("You should not call 'stop' method for collective mode.") - def distributed_optimizer(self, optimizer, strategy=None): self._optimizer = CollectiveOptimizer(optimizer, strategy) return self._optimizer def save_inference_model(self, + executor, dirname, feeded_var_names=None, target_vars=None, @@ -67,7 +62,7 @@ class Collective(Fleet): self._executor, main_program, None, None, export_for_deployment) - def save_persistables(self, dirname, main_program=None): + def save_persistables(self, executor, dirname, main_program=None): io.save_persistables(self._executor, dirname, main_program, None) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py index b2ed351da8c..3e0d6d48277 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py @@ -13,18 +13,16 @@ # limitations under the License. import os +import paddle.fluid.io as io +from paddle.fluid.communicator import Communicator from paddle.fluid.framework import default_startup_program - from paddle.fluid.optimizer import Optimizer - -import paddle.fluid.io as io - -from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig +from ...base.fleet_base import DistributedOptimizer from ...base.fleet_base import Fleet from ...base.fleet_base import Mode -from ...base.fleet_base import DistributedOptimizer class DistributedTranspiler(Fleet): @@ -34,9 +32,11 @@ class DistributedTranspiler(Fleet): def __init__(self): super(DistributedTranspiler, self).__init__(Mode.TRANSPILER) - self._transpiler = OriginTranspiler() - self._startup_program = None - self._main_program = None + self._transpile_config = None + self._transpiler = None + self.startup_program = None + self.main_program = None + self._communicator = None def init_worker(self): """ @@ -48,10 +48,9 @@ class DistributedTranspiler(Fleet): Returns: None """ - pass - - def run_worker(self, main_programs=None, scopes=None): - pass + if not self._transpile_config.sync_mode: + self._communicator = Communicator(self.main_program) + self._communicator.start() def init_server(self, model_dir=None): """ @@ -65,19 +64,19 @@ class DistributedTranspiler(Fleet): Returns: None """ - if not self._startup_program: + if not self.startup_program: raise ValueError( "startup_program is None, need invoke DistributedOptimizer.minimize first" ) - self._executor.run(self._startup_program) + self._executor.run(self.startup_program) if model_dir: if not os.path.isdir(model_dir): raise ValueError("There is no directory named '%s'", model_dir) io.load_persistables(self._executor, model_dir, - self._startup_program) + self.startup_program) def run_server(self): """ @@ -86,17 +85,14 @@ class DistributedTranspiler(Fleet): Returns: None """ - if not self._main_program: + if not self.main_program: raise ValueError( "main_program is None, need invoke DistributedOptimizer.minimize first" ) - self._executor.run(self._main_program) + self._executor.run(self.main_program) def stop_worker(self): - pass - - def stop(self): """ Close this executor. 
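Taken together, the hunks in this file give the transpiler-based fleet an explicit async lifecycle: when the strategy has sync_mode off, init_worker() wraps the trainer's main_program in a Communicator and starts it, and stop_worker() stops that communicator before closing the executor. A minimal sketch of the trainer-side flow this enables (model definition and role-maker setup elided; the complete version is the fleet_deep_ctr.py example added later in this patch):

    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler import fleet
    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig

    strategy = DistributeTranspilerConfig()
    strategy.sync_mode = False     # _transpile() below also turns on runtime_split_send_recv

    optimizer = fleet.distributed_optimizer(
        fluid.optimizer.SGD(learning_rate=0.0001), strategy)
    optimizer.minimize(avg_cost)   # transpiles; fills fleet.main_program / fleet.startup_program

    fleet.init_worker()            # async mode: Communicator(fleet.main_program).start()
    # ... training loop ...
    fleet.stop_worker()            # async mode: communicator.stop(), then executor.close()
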
@@ -106,6 +102,8 @@ class DistributedTranspiler(Fleet): Returns: None """ + if not self._transpile_config.sync_mode: + self._communicator.stop() self._executor.close() def distributed_optimizer(self, optimizer, strategy=None): @@ -129,6 +127,7 @@ class DistributedTranspiler(Fleet): return self._optimizer def save_inference_model(self, + executor, dirname, feeded_var_names, target_vars, @@ -139,10 +138,10 @@ class DistributedTranspiler(Fleet): and then save it and all related parameters to given `dirname` by the `executor`. """ io.save_inference_model(dirname, feeded_var_names, target_vars, - self._executor, main_program, None, None, + executor, main_program, None, None, export_for_deployment) - def save_persistables(self, dirname, main_program=None): + def save_persistables(self, executor, dirname, main_program=None): """ This function filters out all variables with `persistable==True` from the give `main_program` and then saves these variables to the folder `dirname` @@ -153,21 +152,30 @@ class DistributedTranspiler(Fleet): files, set `filename` None; if you would like to save all variables in a single file, use `filename` to specify the file name. """ - io.save_persistables(self._executor, dirname, main_program, None) + io.save_persistables(executor, dirname, main_program, None) def _transpile(self, config): + if not isinstance(config, DistributeTranspilerConfig): + raise ValueError( + "config must be an instance of DistributeTranspilerConfig") + + if not config.sync_mode: + config.runtime_split_send_recv = True + + self._transpile_config = config self._transpiler = OriginTranspiler(config) self._transpiler.transpile( trainer_id=fleet.worker_index(), pservers=fleet.server_endpoints(to_string=True), - trainers=fleet.worker_num()) + trainers=fleet.worker_num(), + sync_mode=config.sync_mode) if self.is_worker(): - self._main_program = self._transpiler.get_trainer_program() - self._startup_program = default_startup_program() + self.main_program = self._transpiler.get_trainer_program() + self.startup_program = default_startup_program() else: - self._main_program, self._startup_program = \ - self._transpiler.get_pserver_programs(self.server_endpoints(self.server_index())) + self.main_program, self.startup_program = \ + self._transpiler.get_pserver_programs(self.server_endpoints()[self.server_index()]) fleet = DistributedTranspiler() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 9684a087a40..c16906dc9a4 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -33,8 +33,8 @@ class PSLib(Fleet): self._main_programs = [] self._scopes = [] - def init(self, executor, role_maker=None): - super(PSLib, self).init(executor, MPISymetricRoleMaker()) + def init(self, role_maker=None): + super(PSLib, self).init(MPISymetricRoleMaker()) self._fleet_ptr = fluid.core.Fleet() def init_worker(self): @@ -169,23 +169,12 @@ class PSLib(Fleet): self._role_maker._barrier_all() self._role_maker._finalize() - def stop(self): - """ - stop(): will be called after a user finishes his/her training task. Fleet instance will be - destroyed when stop() is called. 
- """ - self._role_maker._barrier_worker() - if self._role_maker.is_first_worker(): - self._fleet_ptr.stop_server() - self._role_maker._barrier_worker() - self._role_maker._barrier_all() - self._role_maker._finalize() - def distributed_optimizer(self, optimizer, strategy={}): self._optimizer = DownpourOptimizer(optimizer, strategy) return self._optimizer def save_inference_model(self, + executor, dirname, feeded_var_names=None, target_vars=None, @@ -196,7 +185,7 @@ class PSLib(Fleet): """ self._fleet_ptr.save_model(dirname) - def save_persistables(self, dirname, main_program=None, **kwargs): + def save_persistables(self, executor, dirname, main_program=None, **kwargs): """ save presistable parameters, when using fleet, it will save sparse and dense feature diff --git a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh new file mode 100644 index 00000000000..1df6b0618de --- /dev/null +++ b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# start pserver0 +python fleet_deep_ctr.py \ + --role pserver \ + --endpoints 127.0.0.1:7000,127.0.0.1:7001 \ + --current_endpoint 127.0.0.1:7000 \ + --trainers 2 \ + > pserver0.log 2>&1 & + +# start pserver1 +python fleet_deep_ctr.py \ + --role pserver \ + --endpoints 127.0.0.1:7000,127.0.0.1:7001 \ + --current_endpoint 127.0.0.1:7001 \ + --trainers 2 \ + > pserver1.log 2>&1 & + +# start trainer0 +python fleet_deep_ctr.py \ + --role trainer \ + --endpoints 127.0.0.1:7000,127.0.0.1:7001 \ + --trainers 2 \ + --trainer_id 0 \ + > trainer0.log 2>&1 & + +# start trainer1 +python fleet_deep_ctr.py \ + --role trainer \ + --endpoints 127.0.0.1:7000,127.0.0.1:7001 \ + --trainers 2 \ + --trainer_id 1 \ + > trainer1.log 2>&1 & diff --git a/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py b/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py new file mode 100644 index 00000000000..ace4b01144b --- /dev/null +++ b/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py @@ -0,0 +1,100 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
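One note on cluster_train.sh above: it brings up two pservers and two trainers on localhost, and its flags map onto the UserDefinedRoleMaker arguments that fleet_deep_ctr.py (below) builds in train(). A sketch of that mapping, with the values taken from the script:

    # values from cluster_train.sh; mapping as written in fleet_deep_ctr.py's train()
    endpoints = ["127.0.0.1:7000", "127.0.0.1:7001"]
    # pservers: role=Role.SERVER, current_id = endpoints.index(args.current_endpoint)
    # trainers: role=Role.WORKER, current_id = 0 (--trainer_id is parsed but, in this
    #           version, not forwarded to the role maker)
    # both:     worker_num = args.trainers (2), server_endpoints = endpoints
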
+ +from __future__ import print_function + +import logging +import tarfile +import os + +import paddle +import paddle.fluid.incubate.data_generator as data_generator + +logging.basicConfig() +logger = logging.getLogger("paddle") +logger.setLevel(logging.INFO) + +DATA_URL = "http://paddle-ctr-data.bj.bcebos.com/avazu_ctr_data.tgz" +DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e" +""" +avazu_ctr_data/train.txt +avazu_ctr_data/infer.txt +avazu_ctr_data/test.txt +avazu_ctr_data/data.meta.txt +""" + + +def download_file(): + file_name = "avazu_ctr_data" + path = paddle.dataset.common.download(DATA_URL, file_name, DATA_MD5) + + dir_name = os.path.dirname(path) + text_file_dir_name = os.path.join(dir_name, file_name) + + if not os.path.exists(text_file_dir_name): + tar = tarfile.open(path, "r:gz") + tar.extractall(dir_name) + return text_file_dir_name + + +def load_dnn_input_record(sent): + return list(map(int, sent.split())) + + +def load_lr_input_record(sent): + res = [] + for _ in [x.split(':') for x in sent.split()]: + res.append(int(_[0])) + return res + + +class DatasetCtrReader(data_generator.MultiSlotDataGenerator): + def generate_sample(self, line): + def iter(): + fs = line.strip().split('\t') + dnn_input = load_dnn_input_record(fs[0]) + lr_input = load_lr_input_record(fs[1]) + click = [int(fs[2])] + yield ("dnn_data", dnn_input), \ + ("lr_data", lr_input), \ + ("click", click) + + return iter + + +def prepare_data(): + """ + load data meta info from path, return (dnn_input_dim, lr_input_dim) + """ + file_dir_name = download_file() + meta_file_path = os.path.join(file_dir_name, 'data.meta.txt') + train_file_path = os.path.join(file_dir_name, 'train.txt') + with open(meta_file_path, "r") as f: + lines = f.readlines() + err_info = "wrong meta format" + assert len(lines) == 2, err_info + assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[ + 1], err_info + res = map(int, [_.split(':')[1] for _ in lines]) + res = list(res) + dnn_input_dim = res[0] + lr_input_dim = res[1] + logger.info('dnn input dim: %d' % dnn_input_dim) + logger.info('lr input dim: %d' % lr_input_dim) + return dnn_input_dim, lr_input_dim, train_file_path + + +if __name__ == "__main__": + pairwise_reader = DatasetCtrReader() + pairwise_reader.run_from_stdin() diff --git a/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py b/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py new file mode 100644 index 00000000000..ab57137e117 --- /dev/null +++ b/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py @@ -0,0 +1,204 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
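Before the trainer code, it is worth seeing what DatasetCtrReader above emits: generate_sample() splits one tab-separated record into the three slots ("dnn_data", "lr_data", "click") that the trainer's data layers consume through the dataset pipe command. A worked sketch on a made-up record:

    line = "5 24 982\t17:1 42:1\t1"   # dnn ids <tab> lr "feasign:value" pairs <tab> click
    # load_dnn_input_record(fs[0]) -> [5, 24, 982]
    # load_lr_input_record(fs[1])  -> [17, 42]    (keeps only the feasign before ':')
    # generate_sample(line)() yields one sample:
    #   ("dnn_data", [5, 24, 982]), ("lr_data", [17, 42]), ("click", [1])
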
+ +import argparse +import logging +import time + +import paddle.fluid as fluid +import paddle.fluid.incubate.fleet.base.role_maker as role_maker +from paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler import fleet +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig + +import ctr_dataset_reader + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger("fluid") +logger.setLevel(logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser(description="PaddlePaddle Fleet ctr") + + # the following arguments is used for distributed train, if is_local == false, then you should set them + parser.add_argument( + '--role', + type=str, + default='pserver', # trainer or pserver + help='The path for model to store (default: models)') + parser.add_argument( + '--endpoints', + type=str, + default='127.0.0.1:6000', + help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001') + parser.add_argument( + '--current_endpoint', + type=str, + default='127.0.0.1:6000', + help='The path for model to store (default: 127.0.0.1:6000)') + parser.add_argument( + '--trainer_id', + type=int, + default=0, + help='The path for model to store (default: models)') + parser.add_argument( + '--trainers', + type=int, + default=1, + help='The num of trainers, (default: 1)') + + return parser.parse_args() + + +def model(): + dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data( + ) + """ network definition """ + dnn_data = fluid.layers.data( + name="dnn_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + lr_data = fluid.layers.data( + name="lr_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + label = fluid.layers.data( + name="click", + shape=[-1, 1], + dtype="int64", + lod_level=0, + append_batch_size=False) + + datas = [dnn_data, lr_data, label] + + # build dnn model + dnn_layer_dims = [128, 64, 32, 1] + dnn_embedding = fluid.layers.embedding( + is_distributed=False, + input=dnn_data, + size=[dnn_input_dim, dnn_layer_dims[0]], + param_attr=fluid.ParamAttr( + name="deep_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=True) + dnn_pool = fluid.layers.sequence_pool(input=dnn_embedding, pool_type="sum") + dnn_out = dnn_pool + for i, dim in enumerate(dnn_layer_dims[1:]): + fc = fluid.layers.fc( + input=dnn_out, + size=dim, + act="relu", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01)), + name='dnn-fc-%d' % i) + dnn_out = fc + + # build lr model + lr_embbding = fluid.layers.embedding( + is_distributed=False, + input=lr_data, + size=[lr_input_dim, 1], + param_attr=fluid.ParamAttr( + name="wide_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=True) + lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum") + + merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1) + + predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax') + acc = fluid.layers.accuracy(input=predict, label=label) + auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict, + label=label) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + return datas, avg_cost, predict, train_file_path + + +def train(args): + datas, avg_cost, predict, train_file_path = model() + + endpoints = args.endpoints.split(",") + if args.role.upper() == "PSERVER": + current_id = 
endpoints.index(args.current_endpoint) + else: + current_id = 0 + role = role_maker.UserDefinedRoleMaker( + current_id=current_id, + role=role_maker.Role.WORKER + if args.role.upper() == "TRAINER" else role_maker.Role.SERVER, + worker_num=args.trainers, + server_endpoints=endpoints) + + exe = fluid.Executor(fluid.CPUPlace()) + fleet.init(role) + + strategy = DistributeTranspilerConfig() + strategy.sync_mode = False + + optimizer = fluid.optimizer.SGD(learning_rate=0.0001) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + optimizer.minimize(avg_cost) + + if fleet.is_server(): + logger.info("run pserver") + + fleet.init_server() + fleet.run_server() + elif fleet.is_worker(): + logger.info("run trainer") + + fleet.init_worker() + exe.run(fleet.startup_program) + + thread_num = 2 + filelist = [] + for _ in range(thread_num): + filelist.append(train_file_path) + + # config dataset + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_batch_size(128) + dataset.set_use_var(datas) + pipe_command = 'python ctr_dataset_reader.py' + dataset.set_pipe_command(pipe_command) + + dataset.set_filelist(filelist) + dataset.set_thread(thread_num) + + for epoch_id in range(10): + logger.info("epoch {} start".format(epoch_id)) + pass_start = time.time() + dataset.set_filelist(filelist) + exe.train_from_dataset( + program=fleet.main_program, + dataset=dataset, + fetch_list=[avg_cost], + fetch_info=["cost"], + print_period=100, + debug=False) + pass_time = time.time() - pass_start + logger.info("epoch {} finished, pass_time {}".format(epoch_id, + pass_time)) + fleet.stop_worker() + + +if __name__ == "__main__": + args = parse_args() + train(args) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index f494ab92664..a6408322604 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -15,27 +15,26 @@ from __future__ import print_function from collections import defaultdict -from .wrapped_decorator import signature_safe_contextmanager +from functools import reduce -from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program +from paddle.fluid import core from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table +from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program +from paddle.fluid.layers import tensor from . import framework from . import layers from . 
import unique_name from .backward import append_backward from .clip import append_gradient_clip_ops, error_clip_callback +from .dygraph import base as imperative_base +from .dygraph.learning_rate_scheduler import LearningRateDecay from .framework import program_guard from .initializer import Constant from .layer_helper import LayerHelper from .layers import ops from .regularizer import append_regularization_ops -from .dygraph import base as imperative_base -from .dygraph.learning_rate_scheduler import LearningRateDecay -from paddle.fluid import core -from paddle.fluid.layers import tensor -from functools import reduce -import copy +from .wrapped_decorator import signature_safe_contextmanager __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', diff --git a/python/paddle/fluid/tests/CMakeLists.txt b/python/paddle/fluid/tests/CMakeLists.txt index d24417bbacb..2d81fd43171 100644 --- a/python/paddle/fluid/tests/CMakeLists.txt +++ b/python/paddle/fluid/tests/CMakeLists.txt @@ -1,6 +1,10 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +if(NOT WITH_DISTRIBUTE) + list(REMOVE_ITEM TEST_OPS test_communicator) +endif(NOT WITH_DISTRIBUTE) + foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() diff --git a/python/paddle/fluid/tests/test_communicator.py b/python/paddle/fluid/tests/test_communicator.py new file mode 100644 index 00000000000..24c8c4887ec --- /dev/null +++ b/python/paddle/fluid/tests/test_communicator.py @@ -0,0 +1,32 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.fluid as fluid +from paddle.fluid.communicator import Communicator + + +class TestCommunicator(unittest.TestCase): + def test_communicator_init_and_start(self): + prog = fluid.Program() + comm = Communicator(prog) + comm.start() + comm.stop() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 60f74bb6264..82ac3da2905 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -158,7 +158,7 @@ class DistributeTranspilerConfig(object): wait_port = True # split the send recv var in runtime runtime_split_send_recv = False - sync_mode = None + sync_mode = True class DistributeTranspiler(object): @@ -330,7 +330,7 @@ class DistributeTranspiler(object): return self.trainer_num = trainers - self.sync_mode = self.config.sync_mode if self.config.sync_mode else sync_mode + self.sync_mode = sync_mode self.trainer_id = trainer_id pserver_endpoints = pservers.split(",") self.pserver_endpoints = pserver_endpoints -- GitLab
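
To close, the smallest end-to-end view of the new Python surface: the Communicator wrapper in python/paddle/fluid/communicator.py marks every recv op in the program as do_not_run, hands the program desc plus the global scope to the C++ singleton, and exposes start()/stop(). A sketch of direct use — essentially what test_communicator.py exercises — assuming a WITH_DISTRIBUTE build, since the DistCommunicator binding is only compiled then:

    import paddle.fluid as fluid
    from paddle.fluid.communicator import Communicator

    prog = fluid.Program()        # normally the transpiled trainer program
    comm = Communicator(prog)     # sets do_not_run on recv ops, Init()s the C++ singleton
    comm.start()                  # spawns the send thread (plus a recv thread when
                                  # FLAGS_communicator_independent_recv_thread is on);
                                  # the C++ side warns and does nothing if the singleton
                                  # was never initialized
    # ... async training: send ops enqueue gradients per variable and the
    #     send thread ships them to the pservers ...
    comm.stop()                   # joins both threads; also safe when not initialized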