Unverified commit 7fb817d4, authored by 123malin, committed by GitHub

add distributed_strategy (#21710)

* add distributed_strategy
Parent: ad8a9cb8
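For orientation, a minimal sketch of how the strategy API added by this commit is meant to be used, pieced together from the test changes further down in the diff (the UserDefinedRoleMaker arguments and the toy model are placeholders, not part of the commit):

import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

# One of the four strategies introduced below: sync, async, half_async, geo.
strategy = StrategyFactory.create_async_strategy()

role = role_maker.UserDefinedRoleMaker(
    current_id=0,
    role=role_maker.Role.WORKER,
    worker_num=2,
    server_endpoints=["127.0.0.1:6170", "127.0.0.1:6171"])
fleet.init(role)

# Toy model; any network that produces a loss works here.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

optimizer = fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)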
......@@ -63,6 +63,43 @@ inline void VSUB(int n, const T *x, const T *y, T *z) {
}
}
void Communicator::SetEnvFlagsDefault() {
env_flags_dict.clear();
env_flags_dict.insert(std::pair<std::string, int>(
"independent_recv_thread", FLAGS_communicator_independent_recv_thread));
env_flags_dict.insert(std::pair<std::string, int>(
"send_queue_size", FLAGS_communicator_send_queue_size));
env_flags_dict.insert(std::pair<std::string, int>(
"min_send_grad_num_before_recv",
FLAGS_communicator_min_send_grad_num_before_recv));
env_flags_dict.insert(std::pair<std::string, int>(
"thread_pool_size", FLAGS_communicator_thread_pool_size));
env_flags_dict.insert(std::pair<std::string, int>(
"send_wait_times", FLAGS_communicator_send_wait_times));
env_flags_dict.insert(std::pair<std::string, int>(
"max_merge_var_num", FLAGS_communicator_max_merge_var_num));
env_flags_dict.insert(
std::pair<std::string, int>("fake_rpc", FLAGS_communicator_fake_rpc));
env_flags_dict.insert(std::pair<std::string, int>(
"merge_sparse_grad", FLAGS_communicator_merge_sparse_grad));
env_flags_dict.insert(std::pair<std::string, int>(
"is_sgd_optimizer", FLAGS_communicator_is_sgd_optimizer));
return;
}
Communicator::Communicator() { SetEnvFlagsDefault(); }
Communicator::Communicator(const std::map<std::string, int> &env_flags) {
SetEnvFlagsDefault();
for (auto &iter : env_flags) {
std::string flag_name = iter.first;
int val_ = iter.second;
env_flags_dict.at(flag_name) = val_;
}
return;
}
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
......@@ -73,25 +110,6 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
// get all send information from graph, build vars_to_send
VLOG(0) << "communicator_independent_recv_thread: "
<< FLAGS_communicator_independent_recv_thread;
VLOG(0) << "communicator_send_queue_size: "
<< FLAGS_communicator_send_queue_size;
VLOG(0) << "communicator_min_send_grad_num_before_recv: "
<< FLAGS_communicator_min_send_grad_num_before_recv;
VLOG(0) << "communicator_thread_pool_size: "
<< FLAGS_communicator_thread_pool_size;
VLOG(0) << "communicator_send_wait_times: "
<< FLAGS_communicator_send_wait_times;
VLOG(0) << "communicator_max_merge_var_num: "
<< FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
VLOG(0) << "communicator_merge_sparse_grad: "
<< FLAGS_communicator_merge_sparse_grad;
VLOG(0) << "communicator_is_sgd_optimizer: "
<< FLAGS_communicator_is_sgd_optimizer;
if (send_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be send, will not start send_thread";
} else {
......@@ -99,17 +117,17 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
FLAGS_communicator_send_queue_size);
env_flags_dict["send_queue_size"]);
}
send_threadpool_.reset(
new ::ThreadPool(FLAGS_communicator_thread_pool_size));
new ::ThreadPool(env_flags_dict["thread_pool_size"]));
}
if (recv_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be received, will not start recv_thread";
} else {
recv_threadpool_.reset(
new ::ThreadPool(FLAGS_communicator_thread_pool_size));
new ::ThreadPool(env_flags_dict["thread_pool_size"]));
}
}
......@@ -132,7 +150,7 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
auto merge_add = boost::get<bool>(op->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
merge_add = static_cast<bool>(env_flags_dict["is_sgd_optimizer"]);
}
auto use_send_handler =
boost::get<bool>(op->GetNullableAttr("use_send_handler"));
......@@ -194,10 +212,10 @@ void AsyncCommunicator::SendThread() {
std::vector<std::shared_ptr<Variable>> vars;
int merged_var_num = 0;
int wait_times = 0;
while (merged_var_num < FLAGS_communicator_max_merge_var_num) {
while (merged_var_num < env_flags_dict["max_merge_var_num"]) {
if (var_queue->Size() == 0) {
VLOG(4) << "wait_times -> " << wait_times;
if (wait_times >= FLAGS_communicator_send_wait_times) {
if (wait_times >= env_flags_dict["send_wait_times"]) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
......@@ -226,7 +244,7 @@ void AsyncCommunicator::SendThread() {
VLOG(4) << "merge " << merged_var_num << " " << var_name
<< " use time " << after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
if (!FLAGS_communicator_fake_rpc) {
if (!env_flags_dict["fake_rpc"]) {
send_functor(ctx, *send_scope_, true, 1);
}
auto after_send = GetCurrentUS();
......@@ -255,7 +273,7 @@ void AsyncCommunicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) {
int grad_num = grad_num_.load();
if (grad_num > FLAGS_communicator_min_send_grad_num_before_recv) {
if (grad_num > env_flags_dict["min_send_grad_num_before_recv"]) {
VLOG(1) << "current grad num " << grad_num;
RecvAll();
grad_num_.store(0);
......@@ -273,10 +291,10 @@ void AsyncCommunicator::Send(const std::string &var_name,
auto *grad_var = scope.FindVar(var_name);
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
if (grad_var->IsType<framework::SelectedRows>() &&
!FLAGS_communicator_merge_sparse_grad) {
!env_flags_dict["merge_sparse_grad"]) {
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
if (!env_flags_dict["fake_rpc"]) {
send_functor(ctx, scope, true, 1);
}
} else {
......@@ -289,7 +307,7 @@ void AsyncCommunicator::Send(const std::string &var_name,
}
void AsyncCommunicator::Recv() {
if (FLAGS_communicator_independent_recv_thread) {
if (env_flags_dict["independent_recv_thread"]) {
return;
}
......@@ -313,7 +331,7 @@ void AsyncCommunicator::RecvAll() {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
if (!env_flags_dict["fake_rpc"]) {
recv_functor(iter.second, *recv_scope_);
}
};
......@@ -336,7 +354,7 @@ void AsyncCommunicator::Start() {
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::SendThread, this)));
if (FLAGS_communicator_independent_recv_thread) {
if (env_flags_dict["independent_recv_thread"]) {
recv_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
}
......@@ -396,25 +414,8 @@ void GeoSgdCommunicator::InitImpl(
geo_need_push_nums_ = std::move(geo_need_push_nums);
// get all send information from graph, build vars_to_send
VLOG(0) << "communicator_independent_recv_thread: "
<< FLAGS_communicator_independent_recv_thread;
VLOG(0) << "communicator_send_queue_size: "
<< FLAGS_communicator_send_queue_size;
VLOG(0) << "communicator_min_send_grad_num_before_recv: "
<< FLAGS_communicator_min_send_grad_num_before_recv;
VLOG(0) << "communicator_thread_pool_size: "
<< FLAGS_communicator_thread_pool_size;
VLOG(0) << "communicator_send_wait_times: "
<< FLAGS_communicator_send_wait_times;
VLOG(0) << "communicator_max_merge_var_num: "
<< FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
VLOG(0) << "communicator_merge_sparse_grad: "
<< FLAGS_communicator_merge_sparse_grad;
VLOG(0) << "Trainer nums: " << trainer_nums_;
VLOG(0) << "geo_sgd_push_before_local_train_nums: " << geo_need_push_nums_;
VLOG(0) << "communicator_merge_sparse_bucket "
<< FLAGS_communicator_merge_sparse_bucket;
// process var info from transpiler
for (auto &iter : vars_info) {
......@@ -461,7 +462,7 @@ void GeoSgdCommunicator::InitImpl(
LOG(WARNING) << "no var need to send and recv!!";
}
send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
send_threadpool_.reset(new ::ThreadPool(env_flags_dict["thread_pool_size"]));
need_push_queue_ =
std::make_shared<BlockingQueue<std::shared_ptr<SparseIdsMap>>>(
geo_need_push_nums);
......@@ -570,7 +571,7 @@ void GeoSgdCommunicator::SendThread() {
VLOG(4) << "ids_send_vec_ pushed";
} else if (need_push_queue_->Size() == 0) {
VLOG(4) << "wait_times -> " << wait_times;
if (wait_times >= FLAGS_communicator_send_wait_times) {
if (wait_times >= env_flags_dict["send_wait_times"]) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
......
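For reference, the gflags that used to be read directly in this file are now collected into env_flags_dict; the dictionary handed down from Python uses the shortened key names of SetEnvFlagsDefault() above, with the defaults that TrainerRuntimeConfig (introduced later in this diff) reads from the corresponding FLAGS_communicator_* environment variables:

communicator_flags = {
    "independent_recv_thread": 1,
    "send_queue_size": 20,
    "min_send_grad_num_before_recv": 20,
    "thread_pool_size": 5,
    "send_wait_times": 5,
    "max_merge_var_num": 20,
    "fake_rpc": 0,
    "merge_sparse_grad": 1,
    "is_sgd_optimizer": 1,
}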
......@@ -174,9 +174,12 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator {
public:
Communicator() {}
Communicator();
explicit Communicator(const std::map<std::string, int>& env_flags);
virtual ~Communicator() {}
virtual void SetEnvFlagsDefault();
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
......@@ -221,9 +224,10 @@ class Communicator {
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope) {
const paddle::framework::ProgramDesc& program, Scope* recv_scope,
const std::map<std::string, int>& env_flags) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope);
recv_scope, std::ref(env_flags));
return communicator_.get();
}
......@@ -232,10 +236,12 @@ class Communicator {
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) {
const int& trainers, const int& geo_need_push_nums,
const std::map<std::string, int>& env_flags) {
std::call_once(init_flag_, &Communicator::InitWithTranspilerInfo<T>,
program, training_scope, std::ref(vars_info),
std::ref(trainers), std::ref(geo_need_push_nums));
std::ref(trainers), std::ref(geo_need_push_nums),
std::ref(env_flags));
return communicator_.get();
}
......@@ -253,9 +259,10 @@ class Communicator {
template <typename T>
static void InitWithProgram(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) {
Scope* recv_scope,
const std::map<std::string, int>& env_flags) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_.reset(new T(std::ref(env_flags)));
communicator_->InitImpl(program, recv_scope);
}
}
......@@ -265,9 +272,10 @@ class Communicator {
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) {
const int& trainers, const int& geo_need_push_nums,
const std::map<std::string, int>& env_flags) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_.reset(new T(std::ref(env_flags)));
communicator_->InitImpl(program, training_scope, std::ref(vars_info),
std::ref(trainers), std::ref(geo_need_push_nums));
}
......@@ -277,6 +285,7 @@ class Communicator {
bool running_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
std::unordered_map<std::string, int> env_flags_dict;
};
using SparseIdsMap =
......@@ -284,7 +293,9 @@ using SparseIdsMap =
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() {}
AsyncCommunicator() : Communicator() {}
explicit AsyncCommunicator(const std::map<std::string, int>& env_flags)
: Communicator(env_flags) {}
~AsyncCommunicator();
void Start() override;
void Stop() override;
......@@ -331,7 +342,9 @@ class AsyncCommunicator : public Communicator {
class GeoSgdCommunicator : public Communicator {
public:
GeoSgdCommunicator() {}
GeoSgdCommunicator() : Communicator() {}
explicit GeoSgdCommunicator(const std::map<std::string, int>& env_flags)
: Communicator(env_flags) {}
~GeoSgdCommunicator();
void InitImpl(
const paddle::framework::ProgramDesc& program, Scope* training_scope,
......
......@@ -356,6 +356,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
auto rpc_get_thread_num = Attr<int>("rpc_get_thread_num");
auto rpc_send_thread_num = Attr<int>("rpc_send_thread_num");
auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");
request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, dc_sgd));
request_get_handler_.reset(
......@@ -370,21 +374,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
FLAGS_rpc_send_thread_num);
request_send_handler_.get(), rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get(),
FLAGS_rpc_get_thread_num);
request_get_handler_.get(), rpc_get_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get(),
FLAGS_rpc_prefetch_thread_num);
rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(),
FLAGS_rpc_send_thread_num);
request_notify_handler_.get(), rpc_send_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -549,6 +550,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(-1);
AddAttr<int>(kLRDecayBlockId, "BlockID to run lr decay on pserver.")
.SetDefault(-1);
AddAttr<int>("rpc_get_thread_num", "pserver get thread num.").SetDefault(1);
AddAttr<int>("rpc_send_thread_num", "pserver send thread num.")
.SetDefault(1);
AddAttr<int>("rpc_prefetch_thread_num", "pserver prefetch thread num.")
.SetDefault(1);
}
};
......
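The listen_and_serv changes above move the RPC thread counts from gflags to op attributes; on the Python side they are filled in from ServerRuntimeConfig (added later in this diff), which can be attached to a strategy. A small sketch, assuming a half-async job:

from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

strategy = StrategyFactory.create_half_async_strategy()
# Keys mirror the attributes of ServerRuntimeConfig; values override the
# FLAGS_rpc_*_thread_num environment variables (default 12).
strategy.set_server_runtime_config({
    "_rpc_send_thread_num": 24,
    "_rpc_get_thread_num": 24,
    "_rpc_prefetch_thread_num": 12,
})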
......@@ -39,19 +39,23 @@ void BindCommunicator(py::module* m) {
// Communicator is already used by nccl, change to DistCommunicator
py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
"DistCommunicator")
.def(py::init([](const ProgramDesc& program, Scope* param_scope) {
.def(py::init([](const ProgramDesc& program, Scope* param_scope,
std::map<std::string, int>& env_flags) {
VLOG(0) << "using communicator";
Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
Communicator::InitInstance<AsyncCommunicator>(program, param_scope,
env_flags);
return Communicator::GetInstantcePtr();
}))
.def(py::init([](
const ProgramDesc& program, Scope* training_scope,
std::map<std::string,
std::map<std::string, std::vector<std::string>>>& vars_info,
int& trainers, int& geo_need_push_nums) {
int& trainers, int& geo_need_push_nums,
std::map<std::string, int>& env_flags) {
VLOG(0) << "using geo sgd communicator";
Communicator::InitInstance<GeoSgdCommunicator>(
program, training_scope, vars_info, trainers, geo_need_push_nums);
program, training_scope, vars_info, trainers, geo_need_push_nums,
env_flags);
return Communicator::GetInstantcePtr();
}))
.def("stop", &Communicator::Stop)
......
......@@ -28,7 +28,8 @@ class Communicator(object):
program,
vars_info=None,
trainers=None,
geo_sgd_need_push_nums=None):
geo_sgd_need_push_nums=None,
env_flags=None):
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
......@@ -56,14 +57,19 @@ class Communicator(object):
if op.type == "recv":
op._set_attr('do_not_run', True)
# Todo: Add check
if env_flags is None:
env_flags = {}
if vars_info and trainers and geo_sgd_need_push_nums:
# for geo sgd
self.communicator_ = core.DistCommunicator(
program.desc,
global_scope(), vars_info, trainers, geo_sgd_need_push_nums)
global_scope(), vars_info, trainers, geo_sgd_need_push_nums,
env_flags)
else:
self.communicator_ = core.DistCommunicator(program.desc,
global_scope())
global_scope(),
env_flags)
def start(self):
"""
......
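A minimal sketch of the extended Python wrapper above, as it is used by the fleet code later in this diff; constructing the Communicator by hand like this (with a hypothetical trainer_program) is normally left to the fleet API:

from paddle.fluid.communicator import Communicator
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig

flags = TrainerRuntimeConfig().get_communicator_flags()
# trainer_program: the transpiled trainer program (placeholder here)
comm = Communicator(trainer_program, env_flags=flags)
if not comm.is_running():
    comm.start()
# ... training loop ...
comm.stop()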
......@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
"""
Convert the fluid program to distributed data-parallelism programs.
"""
from .distributed_strategy import *
import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator
from paddle.fluid.framework import default_main_program
......@@ -26,8 +28,7 @@ from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.optimizer import Optimizer
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
......@@ -66,15 +67,24 @@ class DistributedTranspiler(Fleet):
from paddle.fluid.transpiler.details.checkport import wait_server_ready
wait_server_ready(fleet.server_endpoints(to_string=False))
if not self._transpile_config.sync_mode:
if self._transpile_config.geo_sgd_mode:
program_config = self._transpile_config.get_program_config()
trainer_communicator_config = self._transpile_config.get_trainer_runtime_config(
)
print(trainer_communicator_config)
need_communicator_flag = False
if isinstance(self._transpile_config, GeoStrategy):
need_communicator_flag = True
self._communicator = Communicator(
self.main_program, self.vars_info,
fleet.worker_num(),
self._transpile_config.geo_sgd_need_push_nums)
else:
self._communicator = Communicator(self.main_program)
fleet.worker_num(), program_config.geo_sgd_need_push_nums,
trainer_communicator_config.get_communicator_flags())
elif isinstance(self._transpile_config, AsyncStrategy):
need_communicator_flag = True
self._communicator = Communicator(
self.main_program,
env_flags=trainer_communicator_config.get_communicator_flags())
if need_communicator_flag:
if not self._communicator.is_running():
self._communicator.start()
else:
......@@ -129,7 +139,8 @@ class DistributedTranspiler(Fleet):
Returns:
None
"""
if not self._transpile_config.sync_mode:
if isinstance(self._transpile_config, GeoStrategy) or isinstance(
self._transpile_config, AsyncStrategy):
self._communicator.stop()
self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker):
......@@ -239,36 +250,44 @@ class DistributedTranspiler(Fleet):
io.save_persistables(executor, dirname, main_program, None)
def _transpile(self, config):
if not isinstance(config, DistributeTranspilerConfig):
if isinstance(config, DistributeTranspilerConfig):
self._transpile_config = DistributedStrategy()
self._transpile_config.set_program_config(config)
elif isinstance(config, DistributedStrategy):
self._transpile_config = config
else:
raise TypeError(
"config must be an instance of DistributeTranspilerConfig")
"config must be an instance of DistributeTranspilerConfig or DistributedStrategy"
)
if not config.sync_mode:
config.runtime_split_send_recv = True
program_config = self._transpile_config.get_program_config()
# _origin_program is a deep copy for default_main_program, for inference
self._origin_program = default_main_program().clone(for_test=False)
self._transpile_config = config
if config.geo_sgd_mode:
self._transpiler = GeoSgdTranspiler(config)
if program_config.geo_sgd_mode:
from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
self._transpiler = GeoSgdTranspiler(program_config)
else:
self._transpiler = OriginTranspiler(config)
self._transpiler = OriginTranspiler(program_config)
self._transpiler._set_server_config(
self._transpile_config.get_server_runtime_config())
if self.is_worker():
self._transpiler.transpile(
trainer_id=fleet.worker_index(),
pservers=fleet.server_endpoints(to_string=True),
trainers=fleet.worker_num(),
sync_mode=config.sync_mode)
sync_mode=program_config.sync_mode)
if isinstance(self._role_maker, MPISymetricRoleMaker):
config.wait_port = False
program_config.wait_port = False
self._transpile_config.set_program_config(program_config)
self.main_program = self._transpiler.get_trainer_program(
wait_port=config.wait_port)
wait_port=program_config.wait_port)
self.startup_program = default_startup_program()
if self._transpile_config.geo_sgd_mode:
if program_config.geo_sgd_mode:
self.vars_info = self._transpiler._get_vars_info()
self.startup_program = self._transpiler.trainer_startup_program
else:
......@@ -276,7 +295,7 @@ class DistributedTranspiler(Fleet):
trainer_id=fleet.worker_index(),
pservers=fleet.server_endpoints(to_string=True),
trainers=fleet.worker_num(),
sync_mode=config.sync_mode,
sync_mode=program_config.sync_mode,
current_endpoint=self.server_endpoints()[self.server_index()])
self.main_program, self.startup_program = \
self._transpiler.get_pserver_programs(
......@@ -308,14 +327,17 @@ class TranspilerOptimizer(DistributedOptimizer):
super(TranspilerOptimizer, self).__init__(optimizer, strategy)
if strategy:
if not isinstance(strategy, DistributeTranspilerConfig):
if isinstance(strategy, DistributedStrategy):
self._strategy = strategy
elif isinstance(strategy, DistributeTranspilerConfig):
self._strategy = DistributedStrategy()
self._strategy.set_program_config(strategy)
else:
raise TypeError(
"In {} mode, strategy must be an instance of DistributeTranspilerConfig".
"In {} mode, strategy must be an instance of DistributeTranspilerConfig or DistributedStrategy".
format(fleet._mode))
else:
self._strategy = strategy
else:
self._strategy = DistributeTranspilerConfig()
self._strategy = DistributedStrategy()
def backward(self,
loss,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"TrainerRuntimeConfig", "DistributedStrategy", "SyncStrategy",
"AsyncStrategy", "HalfAsyncStrategy", "GeoStrategy", "StrategyFactory"
]
import os
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
class TrainerRuntimeConfig(object):
def __init__(self):
self.max_merge_var_num = int(
os.getenv("FLAGS_communicator_max_merge_var_num", "20"))
self.send_queue_size = int(
os.getenv("FLAGS_communicator_send_queue_size", "20"))
self.independent_recv_thread = int(
os.getenv("FLAGS_communicator_independent_recv_thread", "1"))
self.min_send_grad_num_before_recv = int(
os.getenv("FLAGS_communicator_min_send_grad_num_before_recv", "20"))
self.thread_pool_size = int(
os.getenv("FLAGS_communicator_thread_pool_size", "5"))
self.send_wait_times = int(
os.getenv("FLAGS_communicator_send_wait_times", "5"))
self.fake_rpc = int(os.getenv("FLAGS_communicator_fake_rpc", "0"))
self.merge_sparse_grad = int(
os.getenv("FLAGS_communicator_merge_sparse_grad", "1"))
self.is_sgd_optimizer = int(
os.getenv("FLAGS_communicator_is_sgd_optimizer", "1"))
# not used
self._rpc_deadline = int(os.getenv("FLAGS_rpc_deadline", "180000"))
self._rpc_retry_times = int(os.getenv("FLAGS_rpc_retry_times", "3"))
def get_communicator_flags(self):
_communicator_flags = dict()
_communicator_flags["max_merge_var_num"] = self.max_merge_var_num
_communicator_flags["send_queue_size"] = self.send_queue_size
_communicator_flags[
"independent_recv_thread"] = self.independent_recv_thread
_communicator_flags[
"min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv
_communicator_flags["thread_pool_size"] = self.thread_pool_size
_communicator_flags["send_wait_times"] = self.send_wait_times
_communicator_flags["fake_rpc"] = self.fake_rpc
_communicator_flags["merge_sparse_grad"] = self.merge_sparse_grad
_communicator_flags["is_sgd_optimizer"] = self.is_sgd_optimizer
return _communicator_flags
def __repr__(self):
_str = "please check that TrainerRuntimeConfig is as expected:\n"
_communicator_flags = self.get_communicator_flags()
for key in _communicator_flags:
_str += "communicator_{}: {}\n".format(key,
_communicator_flags[key])
return _str
class DistributedStrategy(object):
def __init__(self):
self._program_config = DistributeTranspilerConfig()
self._trainer_runtime_config = TrainerRuntimeConfig()
self._server_runtime_config = ServerRuntimeConfig()
self._execute_strategy = fluid.ExecutionStrategy()
self._build_strategy = fluid.BuildStrategy()
num_threads = int(os.getenv("CPU_NUM", "1"))
self._execute_strategy.num_threads = num_threads
if num_threads > 1:
self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
def get_program_config(self):
return self._program_config
def set_program_config(self, config):
if isinstance(config, DistributeTranspilerConfig):
self._program_config = config
elif isinstance(config, dict):
for key in config:
if hasattr(self._program_config, key):
setattr(self._program_config, key, config[key])
else:
raise ValueError(
"DistributeTranspilerConfig doesn't have key: {}".
format(key))
else:
raise TypeError(
"program_config only accept input type: dict or DistributeTranspilerConfig"
)
def get_trainer_runtime_config(self):
return self._trainer_runtime_config
def set_trainer_runtime_config(self, config):
if isinstance(config, TrainerRuntimeConfig):
self._trainer_runtime_config = config
elif isinstance(config, dict):
for key in config:
if hasattr(self._trainer_runtime_config, key):
setattr(self._trainer_runtime_config, key, config[key])
else:
raise ValueError(
"TrainerRuntimeConfig doesn't have key: {}".format(key))
else:
raise TypeError(
"trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
)
def get_server_runtime_config(self):
return self._server_runtime_config
def set_server_runtime_config(self, config):
if isinstance(config, ServerRuntimeConfig):
self._server_runtime_config = config
elif isinstance(config, dict):
for key in config:
if hasattr(self._server_runtime_config, key):
setattr(self._server_runtime_config, key, config[key])
else:
raise ValueError(
"ServerRuntimeConfig doesn't have key: {}".format(key))
else:
raise TypeError(
"server_runtime_config only accept input type: dict or ServerRuntimeConfig"
)
def get_execute_strategy(self):
return self._execute_strategy
def set_execute_strategy(self, config):
if isinstance(config, fluid.ExecutionStrategy):
self._execute_strategy = config
elif isinstance(config, dict):
for key in config:
if hasattr(self._execute_strategy, key):
setattr(self._execute_strategy, key, config[key])
else:
raise ValueError(
"ExecutionStrategy doesn't have key: {}".format(key))
else:
raise TypeError(
"execute_strategy only accept input type: dict or ExecutionStrategy"
)
def get_build_strategy(self):
return self._build_strategy
def set_build_strategy(self, config):
if isinstance(config, fluid.BuildStrategy):
self._build_strategy = config
elif isinstance(config, dict):
for key in config:
if hasattr(self._build_strategy, key):
setattr(self._build_strategy, key, config[key])
else:
raise ValueError(
"BuildStrategy doesn't have key: {}".format(key))
else:
raise TypeError(
"build_strategy only accept input type: dict or BuildStrategy")
class SyncStrategy(DistributedStrategy):
def __init__(self):
super(SyncStrategy, self).__init__()
self._program_config.sync_mode = True
self._program_config.runtime_split_send_recv = False
self._build_strategy.async_mode = False
class AsyncStrategy(DistributedStrategy):
def __init__(self):
super(AsyncStrategy, self).__init__()
self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
class HalfAsyncStrategy(DistributedStrategy):
def __init__(self):
super(HalfAsyncStrategy, self).__init__()
self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = False
self._build_strategy.async_mode = False
class GeoStrategy(DistributedStrategy):
def __init__(self, update_frequency=100):
super(GeoStrategy, self).__init__()
self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True
self._program_config.geo_sgd_mode = True
self._program_config.geo_sgd_need_push_nums = update_frequency
self._build_strategy.async_mode = True
class StrategyFactory(object):
def __init__(self):
pass
@staticmethod
def create_sync_strategy():
return SyncStrategy()
@staticmethod
def create_half_async_strategy():
return HalfAsyncStrategy()
@staticmethod
def create_async_strategy():
return AsyncStrategy()
@staticmethod
def create_geo_strategy(update_frequency=100):
return GeoStrategy(update_frequency)
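The factory-built strategies can also be tuned after creation; a short sketch based on the unit tests added later in this commit (each setter accepts the corresponding config object instead of a dict as well):

from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

strategy = StrategyFactory.create_geo_strategy(update_frequency=5)
# Transpiler-level knobs (DistributeTranspilerConfig fields).
strategy.set_program_config({"min_block_size": 8192})
# Communicator flags (TrainerRuntimeConfig fields).
strategy.set_trainer_runtime_config({"send_queue_size": 100})
# Pserver RPC thread counts (ServerRuntimeConfig fields).
strategy.set_server_runtime_config({"_rpc_send_thread_num": 20})
# Forwarded to fluid.ExecutionStrategy / fluid.BuildStrategy.
strategy.set_execute_strategy({"num_threads": 8})
strategy.set_build_strategy({"memory_optimize": True})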
......@@ -61,6 +61,24 @@ def load_lr_input_record(sent):
return res
class CtrReader(object):
def __init__(self):
pass
def _reader_creator(self, filelist):
def reader():
for file in filelist:
with open(file, 'r') as f:
for line in f:
fs = line.strip().split('\t')
dnn_input = load_dnn_input_record(fs[0])
lr_input = load_lr_input_record(fs[1])
click = [int(fs[2])]
yield [dnn_input] + [lr_input] + [click]
return reader
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
def generate_sample(self, line):
def get_rand(low=0.0, high=1.0):
......
......@@ -21,8 +21,10 @@ import shutil
import tempfile
import time
import paddle
import paddle.fluid as fluid
import os
import numpy as np
import ctr_dataset_reader
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
......@@ -131,7 +133,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
wn.write(str(program))
def do_training(self, fleet):
def do_pyreader_training(self, fleet):
"""
do training using dataset, using fetch handler to catch variable
Args:
......@@ -146,13 +148,63 @@ class TestDistCTR2x2(FleetDistRunnerBase):
exe.run(fleet.startup_program)
thread_num = 2
batch_size = 128
filelist = []
for _ in range(thread_num):
filelist.append(train_file_path)
train_reader = paddle.batch(
paddle.reader.shuffle(
ctr_dataset_reader.CtrReader()._reader_creator(filelist),
buf_size=batch_size * 100),
batch_size=batch_size)
self.reader.decorate_sample_list_generator(train_reader)
compiled_prog = fluid.compiler.CompiledProgram(
fleet.main_program).with_data_parallel(
loss_name=self.avg_cost.name,
build_strategy=self.strategy.get_build_strategy(),
exec_strategy=self.strategy.get_execute_strategy())
for epoch_id in range(1):
self.reader.start()
try:
pass_start = time.time()
while True:
loss_val = exe.run(program=compiled_prog,
fetch_list=[self.avg_cost.name])
loss_val = np.mean(loss_val)
print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
loss_val))
pass_time = time.time() - pass_start
except fluid.core.EOFException:
self.reader.reset()
model_dir = tempfile.mkdtemp()
fleet.save_inference_model(
exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
self.check_model_right(model_dir)
shutil.rmtree(model_dir)
fleet.stop_worker()
def do_dataset_training(self, fleet):
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
)
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fleet.startup_program)
thread_num = 2
batch_size = 128
filelist = []
for _ in range(thread_num):
filelist.append(train_file_path)
# config dataset
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(128)
dataset.set_batch_size(batch_size)
dataset.set_use_var(self.feeds)
pipe_command = 'python ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
......@@ -172,11 +224,14 @@ class TestDistCTR2x2(FleetDistRunnerBase):
debug=False)
pass_time = time.time() - pass_start
res_dict = dict()
res_dict['loss'] = self.avg_cost
class FH(fluid.executor.FetchHandler):
def handler(self, fetch_target_vars):
for i in range(len(fetch_target_vars)):
print("{}: \n {}\n".format(self.fetch_target_names[0],
fetch_target_vars[0]))
def handle(self, res_dict):
for key in res_dict:
v = res_dict[key]
print("{}: \n {}\n".format(key, v))
for epoch_id in range(1):
pass_start = time.time()
......@@ -184,7 +239,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
exe.train_from_dataset(
program=fleet.main_program,
dataset=dataset,
fetch_handler=FH([self.avg_cost.name], period_secs=2),
fetch_handler=FH(var_dict=res_dict, period_secs=2),
debug=False)
pass_time = time.time() - pass_start
......
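The dataset-based test above also follows an interface change in fluid.executor.FetchHandler: the subclass now overrides handle(res_dict), and the handler is constructed from a var_dict mapping names to Variables instead of a list of fetch names. A condensed sketch of the pattern (exe, dataset, fleet.main_program and avg_cost are assumed to come from the surrounding test):

import paddle.fluid as fluid

class LossPrinter(fluid.executor.FetchHandler):
    def handle(self, res_dict):
        for key in res_dict:
            print("{}: \n {}\n".format(key, res_dict[key]))

exe.train_from_dataset(
    program=fleet.main_program,
    dataset=dataset,
    fetch_handler=LossPrinter(var_dict={"loss": avg_cost}, period_secs=2),
    debug=False)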
......@@ -37,6 +37,7 @@ import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
RUN_STEP = 5
LEARNING_RATE = 0.01
......@@ -50,6 +51,19 @@ class FleetDistRunnerBase(object):
do training : exe run program
"""
def generate_strategy(self, args):
self.strategy = None
if args.mode == "async":
self.strategy = StrategyFactory.create_async_strategy()
elif args.mode == "sync":
self.strategy = StrategyFactory.create_sync_strategy()
elif args.mode == "half_async":
self.strategy = StrategyFactory.create_half_async_strategy()
elif args.mode == "geo":
self.strategy = StrategyFactory.create_geo_strategy(
args.geo_sgd_need_push_nums)
return self.strategy
def run_pserver(self, args):
if args.role.upper() != "PSERVER":
raise ValueError("args role must be PSERVER")
......@@ -62,10 +76,7 @@ class FleetDistRunnerBase(object):
fleet.init(role)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = args.sync_mode
strategy.geo_sgd_mode = args.geo_sgd_mode
strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
strategy = self.generate_strategy(args)
avg_cost = self.net()
......@@ -76,7 +87,28 @@ class FleetDistRunnerBase(object):
fleet.init_server()
fleet.run_server()
def run_trainer(self, args):
def run_dataset_trainer(self, args):
if args.role.upper() != "TRAINER":
raise ValueError("args role must be TRAINER")
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.WORKER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
fleet.init(role)
strategy = self.generate_strategy(args)
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
out = self.do_dataset_training(fleet)
def run_pyreader_trainer(self, args):
if args.role.upper() != "TRAINER":
raise ValueError("args role must be TRAINER")
......@@ -88,26 +120,33 @@ class FleetDistRunnerBase(object):
fleet.init(role)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = args.sync_mode
strategy.geo_sgd_mode = args.geo_sgd_mode
strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
strategy = self.generate_strategy(args)
avg_cost = self.net()
self.reader = fluid.io.PyReader(
feed_list=self.feeds,
capacity=64,
iterable=False,
use_double_buffer=False)
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
out = self.do_training(fleet)
out = self.do_pyreader_training(fleet)
def net(self, batch_size=4, lr=0.01):
raise NotImplementedError(
"get_model should be implemented by child classes.")
def do_training(self, fleet):
def do_dataset_training(self, fleet):
raise NotImplementedError(
"do_training should be implemented by child classes.")
"do_dataset_training should be implemented by child classes.")
def do_pyreader_training(self, fleet):
raise NotImplementedError(
"do_pyreader_training should be implemented by child classes.")
class TestFleetBase(unittest.TestCase):
......@@ -120,7 +159,8 @@ class TestFleetBase(unittest.TestCase):
raise NotImplementedError("tests should have _setup_config implemented")
def setUp(self):
self._sync_mode = True
self._mode = "sync"
self._reader = "pyreader"
self._trainers = 2
self._pservers = 2
self._port_set = set()
......@@ -139,7 +179,6 @@ class TestFleetBase(unittest.TestCase):
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self._geo_sgd = False
self._geo_sgd_need_push_nums = 5
self._setup_config()
......@@ -203,21 +242,13 @@ class TestFleetBase(unittest.TestCase):
envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
python_path += " -m coverage run --branch -p"
tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}".format(
python_path, model, self._ps_endpoints, self._trainers)
ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}".format(
python_path, model, self._ps_endpoints, self._trainers)
tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format(
python_path, model, self._ps_endpoints, self._trainers, self._mode,
self._geo_sgd_need_push_nums, self._reader)
if self._sync_mode:
tr_cmd += " --sync_mode"
ps_cmd += " --sync_mode"
if self._geo_sgd:
tr_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
self._geo_sgd, self._geo_sgd_need_push_nums)
ps_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
self._geo_sgd, self._geo_sgd_need_push_nums)
ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format(
python_path, model, self._ps_endpoints, self._trainers, self._mode,
self._geo_sgd_need_push_nums, self._reader)
# Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
......@@ -301,15 +332,17 @@ def runtime_main(test_class):
parser.add_argument('--endpoints', type=str, required=False, default="")
parser.add_argument('--current_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument(
'--geo_sgd_mode', type=bool, required=False, default=False)
parser.add_argument('--mode', type=str, required=False, default='geo')
parser.add_argument(
'--geo_sgd_need_push_nums', type=int, required=False, default=2)
parser.add_argument('--reader', type=str, required=False, default='dataset')
args = parser.parse_args()
model = test_class()
if args.role == "pserver":
model.run_pserver(args)
else:
model.run_trainer(args)
if args.reader == "dataset":
model.run_dataset_trainer(args)
else:
model.run_pyreader_trainer(args)
......@@ -19,9 +19,103 @@ import unittest
from test_dist_fleet_base import TestFleetBase
class TestDistMnist2x2(TestFleetBase):
class TestDistMnistSync2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
self._mode = "sync"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistHalfAsync2x2(TestFleetBase):
def _setup_config(self):
self._mode = "half_async"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAsync2x2(TestFleetBase):
def _setup_config(self):
self._mode = "async"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAsyncDataset2x2(TestFleetBase):
def _setup_config(self):
self._mode = "async"
self._reader = "dataset"
def check_with_place(self,
model_file,
......
......@@ -19,15 +19,16 @@ import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
from test_dist_fleet_base import TestFleetBase
from dist_simnet_bow import train_network
class TestDistGeoCtr_2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
self._geo_sgd = True
self._mode = "geo"
self._reader = "dataset"
self._geo_sgd_need_push_nums = 5
def check_with_place(self,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory
import os
class TestStrategyFactor(unittest.TestCase):
def test_sync_strategy(self):
os.environ['CPU_NUM'] = "2"
strategy = StrategyFactory.create_sync_strategy()
self.assertEqual(strategy._program_config.sync_mode, True)
self.assertEqual(strategy._program_config.runtime_split_send_recv,
False)
self.assertEqual(strategy._build_strategy.async_mode, False)
self.assertEqual(strategy._execute_strategy.num_threads, 2)
# test set_program_config using DistributeTranspilerConfig()
program_config_class = DistributeTranspilerConfig()
program_config_class.min_block_size = 81920
strategy.set_program_config(program_config_class)
program_config = strategy.get_program_config()
self.assertEqual(program_config.min_block_size, 81920)
# test set_program_config using dict
program_config_dict = dict()
program_config_dict['min_block_size'] = 8192
strategy.set_program_config(program_config_dict)
program_config = strategy.get_program_config()
self.assertEqual(program_config.min_block_size, 8192)
# test set_program_config exception
program_config_dict['unknown'] = None
self.assertRaises(Exception, strategy.set_program_config,
program_config_dict)
program_config_illegal = None
self.assertRaises(Exception, strategy.set_program_config,
program_config_illegal)
def test_geo_strategy(self):
strategy = StrategyFactory.create_geo_strategy(5)
self.assertEqual(strategy._program_config.sync_mode, False)
self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
self.assertEqual(strategy._program_config.geo_sgd_mode, True)
self.assertEqual(strategy._program_config.geo_sgd_need_push_nums, 5)
self.assertEqual(strategy._build_strategy.async_mode, True)
# test set_build_strategy using fluid.BuildStrategy
build_strategy_class = fluid.BuildStrategy()
build_strategy_class.memory_optimize = False
strategy.set_build_strategy(build_strategy_class)
build_strategy = strategy.get_build_strategy()
self.assertEqual(build_strategy.memory_optimize, False)
# test set_build_strategy using dict
build_strategy_dict = dict()
build_strategy_dict['memory_optimize'] = True
strategy.set_build_strategy(build_strategy_dict)
build_strategy = strategy.get_build_strategy()
self.assertEqual(build_strategy.memory_optimize, True)
# test set_build_strategy exception
build_strategy_dict['unknown'] = None
self.assertRaises(Exception, strategy.set_build_strategy,
build_strategy_dict)
build_strategy_illegal = None
self.assertRaises(Exception, strategy.set_build_strategy,
build_strategy_illegal)
def test_async_strategy(self):
strategy = StrategyFactory.create_async_strategy()
self.assertEqual(strategy._program_config.sync_mode, False)
self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
self.assertEqual(strategy._build_strategy.async_mode, True)
# test set_trainer_runtime_config using TrainerRuntimeConfig
trainer_runtime_config_class = TrainerRuntimeConfig()
trainer_runtime_config_class.send_queue_size = 50
print(trainer_runtime_config_class)
strategy.set_trainer_runtime_config(trainer_runtime_config_class)
trainer_runtime_config = strategy.get_trainer_runtime_config()
self.assertEqual(trainer_runtime_config.send_queue_size, 50)
# test set_trainer_runtime_config using dict
trainer_runtime_config_dict = dict()
trainer_runtime_config_dict['send_queue_size'] = 100
strategy.set_trainer_runtime_config(trainer_runtime_config_dict)
trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_communicator_flags = trainer_runtime_config.get_communicator_flags(
)
self.assertIn('send_queue_size', trainer_communicator_flags)
self.assertEqual(trainer_communicator_flags['send_queue_size'], 100)
# test set_trainer_runtime_config exception
trainer_runtime_config_dict['unknown'] = None
self.assertRaises(Exception, strategy.set_trainer_runtime_config,
trainer_runtime_config_dict)
trainer_runtime_config_illegal = None
self.assertRaises(Exception, strategy.set_trainer_runtime_config,
trainer_runtime_config_illegal)
# test set_execute_strategy using fluid.ExecutionStrategy
exec_strategy_class = fluid.ExecutionStrategy()
exec_strategy_class.num_threads = 4
strategy.set_execute_strategy(exec_strategy_class)
exec_strategy = strategy.get_execute_strategy()
self.assertEqual(exec_strategy.num_threads, 4)
# test set_execute_strategy using dict
exec_strategy_dict = dict()
exec_strategy_dict['num_threads'] = 8
strategy.set_execute_strategy(exec_strategy_dict)
exec_strategy = strategy.get_execute_strategy()
self.assertEqual(exec_strategy.num_threads, 8)
# test set_execute_strategy exception
exec_strategy_dict['unknown'] = None
self.assertRaises(Exception, strategy.set_execute_strategy,
exec_strategy_dict)
exec_strategy_illegal = None
self.assertRaises(Exception, strategy.set_execute_strategy,
exec_strategy_illegal)
def test_half_async_strategy(self):
strategy = StrategyFactory.create_half_async_strategy()
self.assertEqual(strategy._program_config.sync_mode, False)
self.assertEqual(strategy._program_config.runtime_split_send_recv,
False)
self.assertEqual(strategy._build_strategy.async_mode, False)
# test set_server_runtime_config using ServerRuntimeConfig
server_runtime_config_class = ServerRuntimeConfig()
server_runtime_config_class._rpc_send_thread_num = 24
strategy.set_server_runtime_config(server_runtime_config_class)
server_runtime_config = strategy.get_server_runtime_config()
self.assertEqual(server_runtime_config._rpc_send_thread_num, 24)
# test set_server_runtime_config using dict
server_runtime_config_dict = dict()
server_runtime_config_dict['_rpc_send_thread_num'] = 20
strategy.set_server_runtime_config(server_runtime_config_dict)
server_runtime_config = strategy.get_server_runtime_config()
self.assertEqual(server_runtime_config._rpc_send_thread_num, 20)
# test set_server_runtime_config exception
server_runtime_config_dict['unknown'] = None
self.assertRaises(Exception, strategy.set_server_runtime_config,
server_runtime_config_dict)
server_runtime_config_illegal = None
self.assertRaises(Exception, strategy.set_server_runtime_config,
server_runtime_config_illegal)
if __name__ == '__main__':
unittest.main()
......@@ -30,6 +30,7 @@ Steps to transpile pserver:
5. add listen_and_serv op
"""
import os
import sys
import math
from functools import reduce
......@@ -177,8 +178,8 @@ class DistributeTranspilerConfig(object):
print_log = False
wait_port = True
# split the send recv var in runtime
_runtime_split_send_recv = False
_sync_mode = True
__runtime_split_send_recv = False
__sync_mode = True
# Geo-sgd algorithm
geo_sgd_mode = False
......@@ -200,31 +201,41 @@ class DistributeTranspilerConfig(object):
@property
def runtime_split_send_recv(self):
return self._runtime_split_send_recv
return self.__runtime_split_send_recv
@runtime_split_send_recv.setter
def runtime_split_send_recv(self, value):
if value is None:
raise ValueError("runtime_split_send_recv can't be None")
if value and self._sync_mode:
if value and self.__sync_mode:
raise ValueError(
"if you want to set runtime_split_send_recv to be true, make ensure config.sync_mode is false at first"
)
self._runtime_split_send_recv = value
self.__runtime_split_send_recv = value
@property
def sync_mode(self):
return self._sync_mode
return self.__sync_mode
@sync_mode.setter
def sync_mode(self, value):
if value is None:
raise ValueError("sync_mode can't be None")
if value and self._runtime_split_send_recv:
if value and self.__runtime_split_send_recv:
raise ValueError(
"if you want to set sync_mode to be true, make ensure config.runtime_split_send_recv is false at first"
)
self._sync_mode = value
self.__sync_mode = value
class ServerRuntimeConfig(object):
def __init__(self):
self._rpc_send_thread_num = int(
os.getenv("FLAGS_rpc_send_thread_num", "12"))
self._rpc_get_thread_num = int(
os.getenv("FLAGS_rpc_get_thread_num", "12"))
self._rpc_prefetch_thread_num = int(
os.getenv("FLAGS_rpc_prefetch_thread_num", "12"))
class DistributeTranspiler(object):
......@@ -295,6 +306,7 @@ class DistributeTranspiler(object):
self.config = config
else:
self.config = DistributeTranspilerConfig()
self._set_server_config()
if self.config.split_method is None:
self.config.split_method = RoundRobin
......@@ -306,6 +318,16 @@ class DistributeTranspiler(object):
assert (self.config.split_method.__bases__[0] == PSDispatcher)
self.counter_var = None
def _set_server_config(self, server_config=None):
if server_config is None:
self.server_config = ServerRuntimeConfig()
elif isinstance(server_config, ServerRuntimeConfig):
self.server_config = server_config
else:
raise TypeError(
"In DistributeTranspiler, server_config must be an instance of ServerRuntimeConfig"
)
def _transpile_nccl2(self,
trainer_id,
trainers,
......@@ -1313,6 +1335,10 @@ class DistributeTranspiler(object):
"grad_to_block_id": grad_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param,
"lr_decay_block_id": lr_decay_block_id,
"rpc_get_thread_num": self.server_config._rpc_get_thread_num,
"rpc_send_thread_num": self.server_config._rpc_send_thread_num,
"rpc_prefetch_thread_num":
self.server_config._rpc_prefetch_thread_num
}
if self.has_distributed_lookup_table:
......
......@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
from .details import wait_server_ready, VarsDistributed
from .details import delete_ops
from ..distribute_lookup_table import find_distributed_lookup_table
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
......@@ -51,6 +51,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
self.config = config
else:
self.config = DistributeTranspilerConfig()
self._set_server_config()
if self.config.split_method is None:
self.config.split_method = RoundRobin
......@@ -241,7 +242,11 @@ class GeoSgdTranspiler(DistributeTranspiler):
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": param_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param
"sparse_grad_to_param": sparse_grad_to_param,
"rpc_get_thread_num": self.server_config._rpc_get_thread_num,
"rpc_send_thread_num": self.server_config._rpc_send_thread_num,
"rpc_prefetch_thread_num":
self.server_config._rpc_prefetch_thread_num
}
# step5 append the listen_and_serv op
......