未验证 提交 fa4e0e82 编写于 作者: T tangwei12 提交者: GitHub

integrated HALF_ASYNC to communicator (#21869) (#22343)

* add half_async in the communicator
* fix DistributedStrategy
上级 fa7bb0c7
...@@ -192,7 +192,7 @@ if(WITH_DISTRIBUTE) ...@@ -192,7 +192,7 @@ if(WITH_DISTRIBUTE)
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS} lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer) graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -48,7 +48,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -48,7 +48,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap; using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph"; VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx; RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto &node : graphs[0]->Nodes()) { for (auto &node : graphs[0]->Nodes()) {
VLOG(3) << "node name " << node->Name(); VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) { if (node && node->IsOp()) {
...@@ -74,30 +74,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -74,30 +74,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
merge_add, use_send_handler); merge_add, use_send_handler);
VLOG(3) << "find and init an send op: " VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name]; << send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
} }
} }
} }
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator"; auto *instance = operators::distributed::Communicator::GetInstance();
auto initialized = instance ? true : false;
auto *instance = operators::distributed::Communicator::InitInstance< PADDLE_ENFORCE_EQ(initialized, true,
operators::distributed::AsyncCommunicator>(send_varname_to_ctx, platform::errors::InvalidArgument(
recv_varname_to_ctx, scope); "Communicator is not Initialized, you may use "
if (!instance->IsRunning()) instance->Start(); "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"develop/markdown_doc/transpiler)"));
} }
#endif #endif
} }
......
...@@ -179,6 +179,7 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -179,6 +179,7 @@ class HogwildWorker : public CPUWorkerBase {
void CreateThreadScope(const ProgramDesc& program); void CreateThreadScope(const ProgramDesc& program);
std::vector<std::string> op_names_; std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_; std::vector<OperatorBase*> ops_;
bool thread_barrier_;
// Scope* thread_scope_; // Scope* thread_scope_;
HogwildWorkerParameter param_; HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_; std::vector<std::string> skip_ops_;
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h" #include "paddle/fluid/platform/lodtensor_printer.h"
...@@ -29,6 +30,7 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) { ...@@ -29,6 +30,7 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
skip_ops_[i] = param_.skip_ops(i); skip_ops_[i] = param_.skip_ops(i);
} }
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
thread_barrier_ = desc.thread_barrier();
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
...@@ -158,6 +160,12 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -158,6 +160,12 @@ void HogwildWorker::TrainFilesWithProfiler() {
thread_scope_->DropKids(); thread_scope_->DropKids();
timeline.Start(); timeline.Start();
} }
#ifdef PADDLE_WITH_DISTRIBUTE
if (thread_barrier_) {
operators::distributed::Communicator::GetInstance()
->BarrierTriggerDecrement();
}
#endif
} }
void HogwildWorker::TrainFiles() { void HogwildWorker::TrainFiles() {
...@@ -183,6 +191,12 @@ void HogwildWorker::TrainFiles() { ...@@ -183,6 +191,12 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars(); PrintFetchVars();
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
#ifdef PADDLE_WITH_DISTRIBUTE
if (thread_barrier_) {
operators::distributed::Communicator::GetInstance()
->BarrierTriggerDecrement();
}
#endif
} }
void HogwildWorker::PrintFetchVars() { void HogwildWorker::PrintFetchVars() {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h" #include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -38,6 +39,14 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -38,6 +39,14 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = readers.size(); thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_; VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_); workers_.resize(thread_num_);
#ifdef PADDLE_WITH_DISTRIBUTE
if (trainer_desc.thread_barrier()) {
operators::distributed::Communicator::GetInstance()->BarrierTriggerReset(
thread_num_);
}
#endif
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name()); trainer_desc.device_worker_name());
......
...@@ -47,6 +47,7 @@ message TrainerDesc { ...@@ -47,6 +47,7 @@ message TrainerDesc {
// adjust ins weight // adjust ins weight
optional AdjustInsWeightConfig adjust_ins_weight_config = 20; optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
optional bool no_cvm = 21 [ default = false ]; optional bool no_cvm = 21 [ default = false ];
optional bool thread_barrier = 22;
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
......
...@@ -175,117 +175,57 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>; ...@@ -175,117 +175,57 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator { class Communicator {
public: public:
Communicator(); Communicator();
explicit Communicator(const std::map<std::string, int>& env_flags); explicit Communicator(const std::map<std::string, std::string>& envs);
virtual ~Communicator() {} virtual ~Communicator() {}
virtual void SetEnvFlagsDefault();
virtual void Start() = 0; virtual void Start() = 0;
virtual void Stop() = 0; virtual void Stop() = 0;
virtual bool IsRunning() { return running_; } virtual bool IsRunning() { return running_; }
virtual void Send(const std::string& var_name, virtual void Send(const std::vector<std::string>& var_names,
const framework::Scope& scope) = 0; const std::vector<std::string>& var_tables,
virtual void Send(const std::vector<std::string>& sparse_var_names,
const std::vector<std::string>& sparse_var_tables,
const framework::Scope& scope) = 0; const framework::Scope& scope) = 0;
virtual void Recv() = 0; virtual void Recv() = 0;
virtual void Barrier() {}
virtual void BarrierTriggerDecrement() {}
virtual void BarrierTriggerReset(int init_counter) {}
virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx, virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) = 0; Scope* recv_scope) {}
virtual void InitImpl(const paddle::framework::ProgramDesc& program, virtual void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) = 0; Scope* recv_scope) = 0;
// for geo-sgd
virtual void InitImpl(
const paddle::framework::ProgramDesc& program, Scope* param_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) = 0;
static Communicator* GetInstance() { return communicator_.get(); } static Communicator* GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() { static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_; return communicator_;
} }
template <typename T>
static Communicator* InitInstance(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>,
send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
return communicator_.get();
}
template <typename T> template <typename T>
static Communicator* InitInstance( static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope, const paddle::framework::ProgramDesc& program, Scope* recv_scope,
const std::map<std::string, int>& env_flags) { const std::map<std::string, std::string>& envs) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program, std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope, std::ref(env_flags)); recv_scope, std::ref(envs));
return communicator_.get(); return communicator_.get();
} }
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums,
const std::map<std::string, int>& env_flags) {
std::call_once(init_flag_, &Communicator::InitWithTranspilerInfo<T>,
program, training_scope, std::ref(vars_info),
std::ref(trainers), std::ref(geo_need_push_nums),
std::ref(env_flags));
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(send_varname_to_ctx, recv_varname_to_ctx,
recv_scope);
}
}
template <typename T> template <typename T>
static void InitWithProgram(const paddle::framework::ProgramDesc& program, static void InitWithProgram(const paddle::framework::ProgramDesc& program,
Scope* recv_scope, Scope* recv_scope,
const std::map<std::string, int>& env_flags) { const std::map<std::string, std::string>& envs) {
if (communicator_.get() == nullptr) { if (communicator_.get() == nullptr) {
communicator_.reset(new T(std::ref(env_flags))); communicator_.reset(new T(std::ref(envs)));
communicator_->InitImpl(program, recv_scope); communicator_->InitImpl(program, recv_scope);
} }
} }
template <typename T>
static void InitWithTranspilerInfo(
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums,
const std::map<std::string, int>& env_flags) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T(std::ref(env_flags)));
communicator_->InitImpl(program, training_scope, std::ref(vars_info),
std::ref(trainers), std::ref(geo_need_push_nums));
}
}
protected: protected:
bool running_ = false; bool running_ = false;
static std::shared_ptr<Communicator> communicator_; static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_; static std::once_flag init_flag_;
std::unordered_map<std::string, int> env_flags_dict; std::unordered_map<std::string, std::string> envs;
}; };
using SparseIdsMap = using SparseIdsMap =
...@@ -294,14 +234,23 @@ using SparseIdsMap = ...@@ -294,14 +234,23 @@ using SparseIdsMap =
class AsyncCommunicator : public Communicator { class AsyncCommunicator : public Communicator {
public: public:
AsyncCommunicator() : Communicator() {} AsyncCommunicator() : Communicator() {}
explicit AsyncCommunicator(const std::map<std::string, int>& env_flags) explicit AsyncCommunicator(const std::map<std::string, std::string>& envs)
: Communicator(env_flags) {} : Communicator(envs) {
independent_recv_thread_ = static_cast<bool>(
std::stoi(envs.at("communicator_independent_recv_thread")));
min_send_grad_num_before_recv_ =
std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
is_sgd_optimizer_ =
static_cast<bool>(std::stoi(envs.at("communicator_is_sgd_optimizer")));
}
~AsyncCommunicator(); ~AsyncCommunicator();
void Start() override; void Start() override;
void Stop() override; void Stop() override;
void Send(const std::string& var_name,
const framework::Scope& scope) override;
void Recv() override; void Recv() override;
void RecvAll(); void RecvAll();
...@@ -315,15 +264,18 @@ class AsyncCommunicator : public Communicator { ...@@ -315,15 +264,18 @@ class AsyncCommunicator : public Communicator {
void SendThread(); void SendThread();
void RecvThread(); void RecvThread();
void Send(const std::vector<std::string>& sparse_var_names, void Send(const std::vector<std::string>& var_names,
const std::vector<std::string>& sparse_var_tables, const std::vector<std::string>& var_tables,
const framework::Scope& scope) override; const framework::Scope& scope) override;
void InitImpl( private:
const paddle::framework::ProgramDesc& program, Scope* param_scope, int min_send_grad_num_before_recv_;
std::map<std::string, std::map<std::string, std::vector<std::string>>>& int thread_pool_size_;
vars_info, int max_merge_var_num_;
const int& trainers, const int& geo_need_push_nums) override; int send_wait_times_;
int send_queue_size_;
bool independent_recv_thread_;
bool is_sgd_optimizer_;
private: private:
std::unordered_map<std::string, std::unordered_map<std::string,
...@@ -340,30 +292,32 @@ class AsyncCommunicator : public Communicator { ...@@ -340,30 +292,32 @@ class AsyncCommunicator : public Communicator {
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
}; };
class GeoSgdCommunicator : public Communicator { class HalfAsyncCommunicator : public Communicator {
public: public:
GeoSgdCommunicator() : Communicator() {} HalfAsyncCommunicator() {}
explicit GeoSgdCommunicator(const std::map<std::string, int>& env_flags) explicit HalfAsyncCommunicator(const std::map<std::string, std::string>& envs)
: Communicator(env_flags) {} : Communicator(envs) {
~GeoSgdCommunicator(); max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
void InitImpl( send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
const paddle::framework::ProgramDesc& program, Scope* training_scope, thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
std::map<std::string, std::map<std::string, std::vector<std::string>>>& send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
vars_info, }
const int& trainers, const int& geo_need_push_nums) override; ~HalfAsyncCommunicator();
void Start() override; void Start() override;
void Stop() override; void Stop() override;
void Send(const std::string& var_name, void Send(const std::vector<std::string>& var_names,
const framework::Scope& scope) override; const std::vector<std::string>& var_tables,
void Send(const std::vector<std::string>& sparse_var_names,
const std::vector<std::string>& sparse_var_tables,
const framework::Scope& scope) override; const framework::Scope& scope) override;
void Recv() override; void Recv() override;
void Barrier() override;
void BarrierWeakUp();
void BarrierTriggerDecrement() override;
void BarrierTriggerReset(int initial_val) override;
void InitImpl(const RpcCtxMap& send_varname_to_ctx, void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) override; Scope* recv_scope) override;
...@@ -371,6 +325,58 @@ class GeoSgdCommunicator : public Communicator { ...@@ -371,6 +325,58 @@ class GeoSgdCommunicator : public Communicator {
void InitImpl(const paddle::framework::ProgramDesc& program, void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) override; Scope* recv_scope) override;
void ConsumeThread();
private:
int max_merge_var_num_;
int send_wait_times_;
int thread_pool_size_;
int send_queue_size_;
private:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::unique_ptr<std::thread> consume_thread_{nullptr};
Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> consume_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
// mutex for Wait for barrier
std::mutex barrier_mutex_;
std::condition_variable barrier_cond_;
std::atomic<int64_t> barrier_trigger_{0};
std::atomic<int64_t> barrier_counter_{0};
};
class GeoSgdCommunicator : public Communicator {
public:
GeoSgdCommunicator() : Communicator() {}
explicit GeoSgdCommunicator(const std::map<std::string, std::string>& envs)
: Communicator(envs) {
geo_need_push_nums_ = std::stoi(envs.at("geo_need_push_nums"));
trainer_nums_ = std::stoi(envs.at("geo_trainer_nums"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
}
~GeoSgdCommunicator();
void Start() override;
void Stop() override;
void Send(const std::vector<std::string>& var_names,
const std::vector<std::string>& var_tables,
const framework::Scope& scope) override;
void Recv() override;
void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) override;
private: private:
void SendThread(); void SendThread();
std::unordered_set<int64_t> SparseIdsMerge( std::unordered_set<int64_t> SparseIdsMerge(
...@@ -379,6 +385,7 @@ class GeoSgdCommunicator : public Communicator { ...@@ -379,6 +385,7 @@ class GeoSgdCommunicator : public Communicator {
void SendUpdateDenseVars(const std::string& var_name, void SendUpdateDenseVars(const std::string& var_name,
const std::string& splited_var_name); const std::string& splited_var_name);
void SendUpdateSparseVars(const std::string& var_name, void SendUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name, const std::string& splited_var_name,
const std::unordered_set<int64_t>& ids_table); const std::unordered_set<int64_t>& ids_table);
...@@ -433,8 +440,11 @@ class GeoSgdCommunicator : public Communicator { ...@@ -433,8 +440,11 @@ class GeoSgdCommunicator : public Communicator {
private: private:
int trainer_nums_ = 1; int trainer_nums_ = 1;
size_t geo_need_push_nums_ = 100; int geo_need_push_nums_ = 100;
bool is_geo_sgd_ = false; int thread_pool_size_;
int send_wait_times_;
private:
int send_var_nums_ = 0; int send_var_nums_ = 0;
RpcCtxMap send_varname_to_ctx_; RpcCtxMap send_varname_to_ctx_;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_GRPC #ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
......
...@@ -36,6 +36,13 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -36,6 +36,13 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
auto is_half_async = Attr<bool>("half_async");
if (is_half_async) {
distributed::Communicator::GetInstance()->Barrier();
return;
}
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints"); std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
...@@ -76,6 +83,12 @@ the Parameter Server would knew all variables have been sent. ...@@ -76,6 +83,12 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.") "Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
AddAttr<bool>(
"half_async",
"(bool, default false)"
"half_async=True is for half_async mode, this will send signal "
"to HalfAsyncCommunicator Instance")
.SetDefault(false);
} }
}; };
......
...@@ -48,12 +48,7 @@ class SendOp : public framework::OperatorBase { ...@@ -48,12 +48,7 @@ class SendOp : public framework::OperatorBase {
auto use_send_handler = Attr<bool>("use_send_handler"); auto use_send_handler = Attr<bool>("use_send_handler");
if (send_varnames.size() > 0) { if (send_varnames.size() > 0) {
if (ins.size() > 1) { distributed::Communicator::GetInstance()->Send(ins, send_varnames, scope);
distributed::Communicator::GetInstance()->Send(ins, send_varnames,
scope);
} else {
distributed::Communicator::GetInstance()->Send(ins[0], scope);
}
} else { } else {
platform::DeviceContextPool& pool = platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
......
...@@ -27,10 +27,11 @@ limitations under the License. */ ...@@ -27,10 +27,11 @@ limitations under the License. */
namespace py = pybind11; namespace py = pybind11;
using paddle::framework::ProgramDesc; using paddle::framework::ProgramDesc;
using paddle::operators::distributed::Communicator; using paddle::framework::Scope;
using paddle::operators::distributed::AsyncCommunicator; using paddle::operators::distributed::AsyncCommunicator;
using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::GeoSgdCommunicator; using paddle::operators::distributed::GeoSgdCommunicator;
using paddle::framework::Scope; using paddle::operators::distributed::HalfAsyncCommunicator;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -39,29 +40,27 @@ void BindCommunicator(py::module* m) { ...@@ -39,29 +40,27 @@ void BindCommunicator(py::module* m) {
// Communicator is already used by nccl, change to DistCommunicator // Communicator is already used by nccl, change to DistCommunicator
py::class_<Communicator, std::shared_ptr<Communicator>>(*m, py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
"DistCommunicator") "DistCommunicator")
.def(py::init([](const ProgramDesc& program, Scope* param_scope, .def(py::init([](const std::string& mode, const ProgramDesc& program,
std::map<std::string, int>& env_flags) { Scope* param_scope,
VLOG(0) << "using communicator"; std::map<std::string, std::string>& envs) {
Communicator::InitInstance<AsyncCommunicator>(program, param_scope, if (mode == "HALF_ASYNC") {
env_flags); Communicator::InitInstance<HalfAsyncCommunicator>(program,
return Communicator::GetInstantcePtr(); param_scope, envs);
})) } else if (mode == "ASYNC") {
.def(py::init([]( Communicator::InitInstance<AsyncCommunicator>(program, param_scope,
const ProgramDesc& program, Scope* training_scope, envs);
std::map<std::string, } else if (mode == "GEO") {
std::map<std::string, std::vector<std::string>>>& vars_info, Communicator::InitInstance<GeoSgdCommunicator>(program, param_scope,
int& trainers, int& geo_need_push_nums, envs);
std::map<std::string, int>& env_flags) { } else {
VLOG(0) << "using geo sgd communicator"; PADDLE_THROW(platform::errors::InvalidArgument(
Communicator::InitInstance<GeoSgdCommunicator>( "unsuported communicator MODE"));
program, training_scope, vars_info, trainers, geo_need_push_nums, }
env_flags);
return Communicator::GetInstantcePtr(); return Communicator::GetInstantcePtr();
})) }))
.def("stop", &Communicator::Stop) .def("stop", &Communicator::Stop)
.def("start", &Communicator::Start) .def("start", &Communicator::Start)
.def("is_running", &Communicator::IsRunning); .def("is_running", &Communicator::IsRunning);
} }
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -199,17 +199,6 @@ def __bootstrap__(): ...@@ -199,17 +199,6 @@ def __bootstrap__():
read_env_flags.append('worker_update_interval_secs') read_env_flags.append('worker_update_interval_secs')
# env for communicator
read_env_flags.append('communicator_independent_recv_thread')
read_env_flags.append('communicator_send_queue_size')
read_env_flags.append('communicator_min_send_grad_num_before_recv')
read_env_flags.append('communicator_thread_pool_size')
read_env_flags.append('communicator_max_merge_var_num')
read_env_flags.append('communicator_merge_sparse_bucket')
read_env_flags.append('communicator_fake_rpc')
read_env_flags.append('communicator_send_wait_times')
read_env_flags.append('communicator_merge_sparse_grad')
read_env_flags.append('communicator_is_sgd_optimizer')
if core.is_compiled_with_brpc(): if core.is_compiled_with_brpc():
read_env_flags.append('max_body_size') read_env_flags.append('max_body_size')
#set brpc max body size #set brpc max body size
......
...@@ -19,17 +19,13 @@ It's a wrapper of a cpp class Communicator and should be used inside fleet API. ...@@ -19,17 +19,13 @@ It's a wrapper of a cpp class Communicator and should be used inside fleet API.
""" """
from . import core from . import core
from .framework import Program from .framework import Program
from .transpiler.distribute_transpiler import DistributedMode
__all__ = ['Communicator'] __all__ = ['Communicator']
class Communicator(object): class Communicator(object):
def __init__(self, def __init__(self, program, mode, kwargs=None, envs={}):
program,
vars_info=None,
trainers=None,
geo_sgd_need_push_nums=None,
env_flags=None):
""" """
Communicator is used for async distribute training in distribute_transpiler mode. Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API. It's a wrapper of a cpp class Communicator and should be used inside fleet API.
...@@ -56,20 +52,37 @@ class Communicator(object): ...@@ -56,20 +52,37 @@ class Communicator(object):
for op in program.block(0).ops: for op in program.block(0).ops:
if op.type == "recv": if op.type == "recv":
op._set_attr('do_not_run', True) op._set_attr('do_not_run', True)
# Todo: Add check
if env_flags is None: if mode == DistributedMode.GEO:
env_flags = {} push_vars = kwargs["push_vars"]
push_var_names = []
if vars_info and trainers and geo_sgd_need_push_nums:
# for geo sgd for k, vs in push_vars.items():
self.communicator_ = core.DistCommunicator( varnames = "&".join(vs["var_names"])
program.desc, sections = "&".join([str(v) for v in vs["sections"]])
global_scope(), vars_info, trainers, geo_sgd_need_push_nums, endpoints = "&".join(vs["epmap"])
env_flags) is_sparse = "1" if vs["is_sparse"] else "0"
else:
self.communicator_ = core.DistCommunicator(program.desc, push_var_names.append(k)
global_scope(), envs[k] = "#".join([varnames, sections, endpoints, is_sparse])
env_flags)
envs["geo_trainer_nums"] = str(kwargs["trainers"])
envs["geo_need_push_nums"] = str(kwargs["push_nums"])
envs["geo_send_varnames"] = '#'.join(push_var_names)
mode_str = None
if mode == DistributedMode.SYNC:
mode_str = "SYNC"
elif mode == DistributedMode.ASYNC:
mode_str = "ASYNC"
elif mode == DistributedMode.HALF_ASYNC:
mode_str = "HALF_ASYNC"
elif mode == DistributedMode.GEO:
mode_str = "GEO"
self.communicator_ = core.DistCommunicator(mode_str, program.desc,
global_scope(), envs)
def start(self): def start(self):
""" """
......
...@@ -963,6 +963,7 @@ class Executor(object): ...@@ -963,6 +963,7 @@ class Executor(object):
program._pipeline_opt) program._pipeline_opt)
else: else:
trainer = TrainerFactory()._create_trainer(program._fleet_opt) trainer = TrainerFactory()._create_trainer(program._fleet_opt)
trainer._set_thread_barrier(program._is_distributed)
trainer._set_program(program) trainer._set_program(program)
else: else:
if program._pipeline_opt: if program._pipeline_opt:
......
...@@ -17,7 +17,6 @@ import warnings ...@@ -17,7 +17,6 @@ import warnings
""" """
Convert the fluid program to distributed data-parallelism programs. Convert the fluid program to distributed data-parallelism programs.
""" """
from .distributed_strategy import *
import paddle.fluid.io as io import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator from paddle.fluid.communicator import Communicator
from paddle.fluid.framework import default_main_program from paddle.fluid.framework import default_main_program
...@@ -27,8 +26,11 @@ from paddle.fluid.compiler import CompiledProgram ...@@ -27,8 +26,11 @@ from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.optimizer import Optimizer from paddle.fluid.optimizer import Optimizer
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, DistributedStrategy, SyncStrategy, AsyncStrategy, HalfAsyncStrategy, GeoStrategy, StrategyFactory
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig, DistributedMode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
...@@ -70,25 +72,39 @@ class DistributedTranspiler(Fleet): ...@@ -70,25 +72,39 @@ class DistributedTranspiler(Fleet):
program_config = self._transpile_config.get_program_config() program_config = self._transpile_config.get_program_config()
trainer_communicator_config = self._transpile_config.get_trainer_runtime_config( trainer_communicator_config = self._transpile_config.get_trainer_runtime_config(
) )
if isinstance(self._transpile_config, SyncStrategy):
return
print(trainer_communicator_config) print(trainer_communicator_config)
need_communicator_flag = False
if isinstance(self._transpile_config, GeoStrategy): if isinstance(self._transpile_config, GeoStrategy):
need_communicator_flag = True kwargs = {}
kwargs["push_vars"] = self.vars_info
kwargs["trainers"] = fleet.worker_num()
kwargs["push_nums"] = self._transpile_config.get_program_config(
).geo_sgd_need_push_nums
self._communicator = Communicator( self._communicator = Communicator(
self.main_program, self.vars_info, self.main_program, DistributedMode.GEO, kwargs,
fleet.worker_num(), program_config.geo_sgd_need_push_nums,
trainer_communicator_config.get_communicator_flags()) trainer_communicator_config.get_communicator_flags())
elif isinstance(self._transpile_config, AsyncStrategy): elif isinstance(self._transpile_config, AsyncStrategy):
need_communicator_flag = True
self._communicator = Communicator( self._communicator = Communicator(
self.main_program, self.main_program, DistributedMode.ASYNC, None,
env_flags=trainer_communicator_config.get_communicator_flags()) trainer_communicator_config.get_communicator_flags())
if need_communicator_flag:
if not self._communicator.is_running(): elif isinstance(self._transpile_config, HalfAsyncStrategy):
self._communicator.start() self._communicator = Communicator(
else: self.main_program, DistributedMode.HALF_ASYNC, None,
warnings.warn("communicator has been initialized, skip") trainer_communicator_config.get_communicator_flags())
else:
raise TypeError("Training MODE do not supported")
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
def init_server(self, model_dir=None): def init_server(self, model_dir=None):
""" """
...@@ -139,12 +155,12 @@ class DistributedTranspiler(Fleet): ...@@ -139,12 +155,12 @@ class DistributedTranspiler(Fleet):
Returns: Returns:
None None
""" """
if isinstance(self._transpile_config, GeoStrategy) or isinstance(
self._transpile_config, AsyncStrategy): if not isinstance(self._transpile_config, SyncStrategy):
self._communicator.stop() self._communicator.stop()
self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker): if isinstance(self._role_maker, MPISymetricRoleMaker):
self._role_maker._finalize() self._role_maker._finalize()
self._executor.close()
def distributed_optimizer(self, optimizer, strategy=None): def distributed_optimizer(self, optimizer, strategy=None):
""" """
...@@ -250,14 +266,22 @@ class DistributedTranspiler(Fleet): ...@@ -250,14 +266,22 @@ class DistributedTranspiler(Fleet):
io.save_persistables(executor, dirname, main_program, None) io.save_persistables(executor, dirname, main_program, None)
def _transpile(self, config): def _transpile(self, config):
if isinstance(config, DistributeTranspilerConfig): if isinstance(config, DistributedStrategy):
self._transpile_config = DistributedStrategy()
self._transpile_config.set_program_config(config)
elif isinstance(config, DistributedStrategy):
self._transpile_config = config self._transpile_config = config
elif isinstance(config, DistributeTranspilerConfig):
if config.sync_mode:
self._transpile_config = SyncStrategy()
elif config.geo_sgd_mode:
self._transpile_config = GeoStrategy(
config.geo_sgd_need_push_nums)
elif config.runtime_split_send_recv and config.half_async:
self._transpile_config = HalfAsyncStrategy()
else:
self._transpile_config = AsyncStrategy()
self._transpile_config.set_program_config(config)
else: else:
raise TypeError( raise TypeError(
"config must be an instance of DistributeTranspilerConfig or DistributedStrategy" "config must be an instance of DistributeTranspilerConfig, SyncStrategy, HalfAsyncStrategy, AsyncStrategy or GeoStratey."
) )
program_config = self._transpile_config.get_program_config() program_config = self._transpile_config.get_program_config()
...@@ -327,14 +351,12 @@ class TranspilerOptimizer(DistributedOptimizer): ...@@ -327,14 +351,12 @@ class TranspilerOptimizer(DistributedOptimizer):
super(TranspilerOptimizer, self).__init__(optimizer, strategy) super(TranspilerOptimizer, self).__init__(optimizer, strategy)
if strategy: if strategy:
if isinstance(strategy, DistributedStrategy): if isinstance(strategy, DistributeTranspilerConfig) or isinstance(
strategy, DistributedStrategy):
self._strategy = strategy self._strategy = strategy
elif isinstance(strategy, DistributeTranspilerConfig):
self._strategy = DistributedStrategy()
self._strategy.set_program_config(strategy)
else: else:
raise TypeError( raise TypeError(
"In {} mode, strategy must be an instance of DistributeTranspilerConfig or DistributedStrategy". "In {} mode, strategy must be an instance of DistributeTranspilerConfig, SyncStrategy, HalfAsyncStrategy, AsyncStrategy, or GeoStrategy".
format(fleet._mode)) format(fleet._mode))
else: else:
self._strategy = DistributedStrategy() self._strategy = DistributedStrategy()
......
...@@ -24,49 +24,51 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo ...@@ -24,49 +24,51 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
class TrainerRuntimeConfig(object): class TrainerRuntimeConfig(object):
def __init__(self): def __init__(self):
self.max_merge_var_num = int( self.max_merge_var_num = os.getenv(
os.getenv("FLAGS_communicator_max_merge_var_num", "20")) "FLAGS_communicator_max_merge_var_num", "20")
self.send_queue_size = int( self.send_queue_size = os.getenv("FLAGS_communicator_send_queue_size",
os.getenv("FLAGS_communicator_send_queue_size", "20")) "20")
self.independent_recv_thread = int( self.independent_recv_thread = os.getenv(
os.getenv("FLAGS_communicator_independent_recv_thread", "1")) "FLAGS_communicator_independent_recv_thread", "1")
self.min_send_grad_num_before_recv = int( self.min_send_grad_num_before_recv = os.getenv(
os.getenv("FLAGS_communicator_min_send_grad_num_before_recv", "20")) "FLAGS_communicator_min_send_grad_num_before_recv", "20")
self.thread_pool_size = int( self.thread_pool_size = os.getenv("FLAGS_communicator_thread_pool_size",
os.getenv("FLAGS_communicator_thread_pool_size", "5")) "5")
self.send_wait_times = int( self.send_wait_times = os.getenv("FLAGS_communicator_send_wait_times",
os.getenv("FLAGS_communicator_send_wait_times", "5")) "5")
self.fake_rpc = int(os.getenv("FLAGS_communicator_fake_rpc", "0")) self.fake_rpc = os.getenv("FLAGS_communicator_fake_rpc", "0")
self.merge_sparse_grad = int( self.merge_sparse_grad = os.getenv(
os.getenv("FLAGS_communicator_merge_sparse_grad", "1")) "FLAGS_communicator_merge_sparse_grad", "1")
self.is_sgd_optimizer = int( self.is_sgd_optimizer = os.getenv("FLAGS_communicator_is_sgd_optimizer",
os.getenv("FLAGS_communicator_is_sgd_optimizer", "1")) "1")
# not used # not used
self._rpc_deadline = int(os.getenv("FLAGS_rpc_deadline", "180000")) self._rpc_deadline = os.getenv("FLAGS_rpc_deadline", "180000")
self._rpc_retry_times = int(os.getenv("FLAGS_rpc_retry_times", "3")) self._rpc_retry_times = os.getenv("FLAGS_rpc_retry_times", "3")
def get_communicator_flags(self): def get_communicator_flags(self):
_communicator_flags = dict() _communicator_flags = dict()
_communicator_flags["max_merge_var_num"] = self.max_merge_var_num
_communicator_flags["send_queue_size"] = self.send_queue_size
_communicator_flags[ _communicator_flags[
"independent_recv_thread"] = self.independent_recv_thread "communicator_max_merge_var_num"] = self.max_merge_var_num
_communicator_flags[ _communicator_flags[
"min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv "communicator_send_queue_size"] = self.send_queue_size
_communicator_flags["thread_pool_size"] = self.thread_pool_size _communicator_flags[
_communicator_flags["send_wait_times"] = self.send_wait_times "communicator_independent_recv_thread"] = self.independent_recv_thread
_communicator_flags["fake_rpc"] = self.fake_rpc _communicator_flags[
_communicator_flags["merge_sparse_grad"] = self.merge_sparse_grad "communicator_min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv
_communicator_flags["is_sgd_optimizer"] = self.is_sgd_optimizer _communicator_flags[
"communicator_thread_pool_size"] = self.thread_pool_size
_communicator_flags[
"communicator_send_wait_times"] = self.send_wait_times
_communicator_flags[
"communicator_is_sgd_optimizer"] = self.is_sgd_optimizer
return _communicator_flags return _communicator_flags
def __repr__(self): def __repr__(self):
_str = "please check that TrainerRuntimeConfig is as expected:\n" _str = "please check that TrainerRuntimeConfig is as expected:\n"
_communicator_flags = self.get_communicator_flags() _communicator_flags = self.get_communicator_flags()
for key in _communicator_flags: for key in _communicator_flags:
_str += "communicator_{}: {}\n".format(key, _str += "{}: {}\n".format(key, _communicator_flags[key])
_communicator_flags[key])
return _str return _str
...@@ -193,8 +195,9 @@ class HalfAsyncStrategy(DistributedStrategy): ...@@ -193,8 +195,9 @@ class HalfAsyncStrategy(DistributedStrategy):
def __init__(self): def __init__(self):
super(HalfAsyncStrategy, self).__init__() super(HalfAsyncStrategy, self).__init__()
self._program_config.sync_mode = False self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = False self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = False self._build_strategy.async_mode = True
self._program_config.half_async = True
class GeoStrategy(DistributedStrategy): class GeoStrategy(DistributedStrategy):
...@@ -202,9 +205,9 @@ class GeoStrategy(DistributedStrategy): ...@@ -202,9 +205,9 @@ class GeoStrategy(DistributedStrategy):
super(GeoStrategy, self).__init__() super(GeoStrategy, self).__init__()
self._program_config.sync_mode = False self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
self._program_config.geo_sgd_mode = True self._program_config.geo_sgd_mode = True
self._program_config.geo_sgd_need_push_nums = update_frequency self._program_config.geo_sgd_need_push_nums = update_frequency
self._build_strategy.async_mode = True
class StrategyFactory(object): class StrategyFactory(object):
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_communicator)
endif(NOT WITH_DISTRIBUTE)
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py) py_test(${src} SRCS ${src}.py)
endforeach() endforeach()
......
...@@ -20,6 +20,10 @@ list(APPEND MIXED_DIST_TEST_OPS test_transpiler_ops) ...@@ -20,6 +20,10 @@ list(APPEND MIXED_DIST_TEST_OPS test_transpiler_ops)
list(APPEND MIXED_DIST_TEST_OPS test_lookup_remote_table_op) list(APPEND MIXED_DIST_TEST_OPS test_lookup_remote_table_op)
list(APPEND MIXED_DIST_TEST_OPS test_launch) list(APPEND MIXED_DIST_TEST_OPS test_launch)
list(APPEND MIXED_DIST_TEST_OPS test_launch_ps) list(APPEND MIXED_DIST_TEST_OPS test_launch_ps)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_async)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_geo)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_half_async)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP}) list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach() endforeach()
...@@ -268,6 +272,9 @@ if(WITH_DISTRIBUTE) ...@@ -268,6 +272,9 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS ${dist_ENVS}) py_test_modules(test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS ${dist_ENVS})
py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS}) py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS})
py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS}) py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS})
py_test_modules(test_communicator_async MODULES test_communicator_async ENVS ${dist_ENVS})
py_test_modules(test_communicator_geo MODULES test_communicator_geo ENVS ${dist_ENVS})
py_test_modules(test_communicator_half_async MODULES test_communicator_half_async ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1)
if(WITH_DGC) if(WITH_DGC)
# if with dgc, test all dgc tests. # if with dgc, test all dgc tests.
# NOTE. dist dgc tests is already in DIST_TEST_OPS # NOTE. dist dgc tests is already in DIST_TEST_OPS
......
...@@ -110,8 +110,10 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -110,8 +110,10 @@ class TestDistCTR2x2(FleetDistRunnerBase):
predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax') predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
acc = fluid.layers.accuracy(input=predict, label=label) acc = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict, auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
label=label) label=label)
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -242,11 +244,13 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -242,11 +244,13 @@ class TestDistCTR2x2(FleetDistRunnerBase):
debug=False) debug=False)
pass_time = time.time() - pass_start pass_time = time.time() - pass_start
model_dir = tempfile.mkdtemp() if os.getenv("SAVE_MODEL") == "1":
fleet.save_inference_model( model_dir = tempfile.mkdtemp()
exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) fleet.save_inference_model(exe, model_dir,
self.check_model_right(model_dir) [feed.name for feed in self.feeds],
shutil.rmtree(model_dir) self.avg_cost)
self.check_model_right(model_dir)
shutil.rmtree(model_dir)
fleet.stop_worker() fleet.stop_worker()
......
...@@ -16,7 +16,10 @@ from __future__ import print_function ...@@ -16,7 +16,10 @@ from __future__ import print_function
import unittest import unittest
import time import time
import threading
import numpy
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator from paddle.fluid.communicator import Communicator
...@@ -35,7 +38,7 @@ class TestCommunicator(unittest.TestCase): ...@@ -35,7 +38,7 @@ class TestCommunicator(unittest.TestCase):
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
return avg_cost return avg_cost
def test_communicator_init_and_start(self): def test_communicator_async(self):
role = role_maker.UserDefinedRoleMaker( role = role_maker.UserDefinedRoleMaker(
current_id=0, current_id=0,
role=role_maker.Role.WORKER, role=role_maker.Role.WORKER,
...@@ -48,23 +51,15 @@ class TestCommunicator(unittest.TestCase): ...@@ -48,23 +51,15 @@ class TestCommunicator(unittest.TestCase):
optimizer = fluid.optimizer.SGD(0.01) optimizer = fluid.optimizer.SGD(0.01)
strategy = DistributeTranspilerConfig() strategy = DistributeTranspilerConfig()
strategy.sync_mode = True strategy.sync_mode = False
strategy.runtime_split_send_recv = True
strategy.wait_port = False strategy.wait_port = False
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
comm = Communicator(fleet.main_program) fleet.init_worker()
comm.start()
time.sleep(10) time.sleep(10)
comm.stop() fleet.stop_worker()
class TestCommunicator2(unittest.TestCase):
def test_communicator_init_and_start(self):
prog = fluid.Program()
comm = Communicator(prog)
comm.start()
comm.stop()
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import time
import threading
import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
class TestCommunicator(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def test_communicator_geo(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(0.01)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = False
strategy.runtime_split_send_recv = True
strategy.geo_sgd_mode = True
strategy.wait_port = False
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
fleet.init_worker()
time.sleep(10)
fleet.stop_worker()
# class TestCommunicatorGEO(unittest.TestCase):
# def test_communicator_init_and_start(self):
# prog = fluid.Program()
# envs = {}
# envs["communicator_thread_pool_size"] = "5"
# envs["communicator_send_wait_times"] = "5"
# kwargs = {}
# kwargs["push_vars"] = {}
# kwargs["trainers"] = 10
# kwargs["push_nums"] = 10
# comm = Communicator(prog, DistributedMode.GEO, kwargs, envs)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import time
import threading
import subprocess
import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost, x, y
def fake_reader(self):
def reader():
for i in range(10000):
x = numpy.random.random((1, 13)).astype('float32')
y = numpy.random.randint(0, 2, (1, 1)).astype('int64')
yield x, y
return reader
def run_pserver(self, role, strategy):
fleet.init(role)
avg_cost, x, y = self.net()
optimizer = fluid.optimizer.SGD(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
fleet.init_server()
fleet.run_server()
def run_trainer(self, role, strategy):
place = fluid.core.CPUPlace()
exe = fluid.Executor(place)
fleet.init(role)
avg_cost, x, y = self.net()
optimizer = fluid.optimizer.SGD(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
exe.run(fleet.startup_program)
fleet.init_worker()
train_reader = paddle.batch(self.fake_reader(), batch_size=24)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
for batch_id, data in enumerate(train_reader()):
exe.run(fleet.main_program, feed=feeder.feed(data), fetch_list=[])
fleet.stop_worker()
def run_ut(self):
strategy = DistributeTranspilerConfig()
strategy.sync_mode = False
strategy.runtime_split_send_recv = True
strategy.half_async = True
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER
if training_role == "TRAINER" else role_maker.Role.SERVER,
worker_num=2,
server_endpoints=["127.0.0.1:6002"])
if training_role == "TRAINER":
self.run_trainer(role, strategy)
else:
self.run_pserver(role, strategy)
def test_communicator(self):
run_server_cmd = """
from __future__ import print_function
import sys
import os
import time
import threading
import subprocess
import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator
from paddle.fluid.communicator import DistributedMode
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from test_communicator_half_async import TestCommunicatorHalfAsyncEnd2End
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
class RunServer(TestCommunicatorHalfAsyncEnd2End):
def runTest(self):
pass
os.environ["TRAINING_ROLE"] = "PSERVER"
half_run_server = RunServer()
half_run_server.run_ut()
"""
server_file = "run_server_for_communicator_haflaysnc.py"
with open(server_file, "w") as wb:
wb.write(run_server_cmd)
os.environ["TRAINING_ROLE"] = "PSERVER"
_python = sys.executable
ps_cmd = "{} {}".format(_python, server_file)
ps_proc = subprocess.Popen(
ps_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["FLAGS_communicator_send_queue_size"] = "1"
os.environ["FLAGS_communicator_max_merge_var_num"] = "1"
self.run_ut()
ps_proc.kill()
if os.path.exists(server_file):
os.remove(server_file)
# class TestCommunicatorHalfAsync2(unittest.TestCase):
# def test_communicator_init_and_start(self):
# prog = fluid.Program()
# envs = {}
# envs["communicator_send_queue_size"] = "12"
# envs["communicator_max_merge_var_num"] = "12"
# envs["communicator_thread_pool_size"] = "5"
# envs["communicator_send_wait_times"] = "5"
# comm = Communicator(prog, DistributedMode.HALF_ASYNC, None, envs)
# comm.start()
# time.sleep(10)
# comm.stop()
if __name__ == '__main__':
unittest.main()
...@@ -32,7 +32,6 @@ class TestDistCTR2x2(TestDistBase): ...@@ -32,7 +32,6 @@ class TestDistCTR2x2(TestDistBase):
"dist_ctr.py", delta=1e-2, check_error_log=True, log_name=flag_name) "dist_ctr.py", delta=1e-2, check_error_log=True, log_name=flag_name)
@unittest.skip(reason="Skip unstable ci")
class TestDistCTRWithL2Decay2x2(TestDistBase): class TestDistCTRWithL2Decay2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
...@@ -48,6 +47,7 @@ class TestDistCTRWithL2Decay2x2(TestDistBase): ...@@ -48,6 +47,7 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
log_name=flag_name) log_name=flag_name)
@unittest.skip(reason="Skip unstable ci")
class TestDistCTR2x2_ASYNC(TestDistBase): class TestDistCTR2x2_ASYNC(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
...@@ -69,6 +69,7 @@ class TestDistCTR2x2_ASYNC(TestDistBase): ...@@ -69,6 +69,7 @@ class TestDistCTR2x2_ASYNC(TestDistBase):
log_name=flag_name) log_name=flag_name)
@unittest.skip(reason="Skip unstable ci")
class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase): class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
...@@ -91,6 +92,7 @@ class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase): ...@@ -91,6 +92,7 @@ class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase):
log_name=flag_name) log_name=flag_name)
@unittest.skip(reason="Skip unstable ci")
class TestDistCTR2x2_ASYNC2(TestDistBase): class TestDistCTR2x2_ASYNC2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
......
...@@ -53,7 +53,22 @@ class FleetDistRunnerBase(object): ...@@ -53,7 +53,22 @@ class FleetDistRunnerBase(object):
do training : exe run program do training : exe run program
""" """
def generate_strategy(self, args): def build_role(self, args):
if args.role.upper() == "PSERVER":
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.SERVER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
else:
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.WORKER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
return role
def build_strategy(self, args):
self.strategy = None self.strategy = None
if args.mode == "async": if args.mode == "async":
self.strategy = StrategyFactory.create_async_strategy() self.strategy = StrategyFactory.create_async_strategy()
...@@ -66,22 +81,7 @@ class FleetDistRunnerBase(object): ...@@ -66,22 +81,7 @@ class FleetDistRunnerBase(object):
args.geo_sgd_need_push_nums) args.geo_sgd_need_push_nums)
return self.strategy return self.strategy
def run_pserver(self, args): def build_optimizer(self, avg_cost, strategy):
if args.role.upper() != "PSERVER":
raise ValueError("args role must be PSERVER")
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.SERVER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
fleet.init(role)
strategy = self.generate_strategy(args)
avg_cost = self.net()
use_grad_clip = int(os.getenv('GRAD_CLIP', 0)) use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
if use_grad_clip: if use_grad_clip:
# 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm # 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
...@@ -99,70 +99,33 @@ class FleetDistRunnerBase(object): ...@@ -99,70 +99,33 @@ class FleetDistRunnerBase(object):
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
def run_pserver(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net()
self.build_optimizer(avg_cost, strategy)
fleet.init_server() fleet.init_server()
fleet.run_server() fleet.run_server()
def run_dataset_trainer(self, args): def run_dataset_trainer(self, args):
if args.role.upper() != "TRAINER": fleet.init(self.build_role(args))
raise ValueError("args role must be TRAINER") strategy = self.build_strategy(args)
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.WORKER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
fleet.init(role)
strategy = self.generate_strategy(args)
avg_cost = self.net() avg_cost = self.net()
self.build_optimizer(avg_cost, strategy)
use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
if use_grad_clip:
# 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
if use_grad_clip == 1:
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByValue(2.0))
elif use_grad_clip == 2:
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByNorm(2.0))
elif use_grad_clip == 3:
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(2.0))
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
out = self.do_dataset_training(fleet) out = self.do_dataset_training(fleet)
def run_pyreader_trainer(self, args): def run_pyreader_trainer(self, args):
if args.role.upper() != "TRAINER": fleet.init(self.build_role(args))
raise ValueError("args role must be TRAINER") strategy = self.build_strategy(args)
role = role_maker.UserDefinedRoleMaker(
current_id=args.current_id,
role=role_maker.Role.WORKER,
worker_num=args.trainers,
server_endpoints=args.endpoints.split(","))
fleet.init(role)
strategy = self.generate_strategy(args)
avg_cost = self.net() avg_cost = self.net()
self.reader = fluid.io.PyReader( self.reader = fluid.io.PyReader(
feed_list=self.feeds, feed_list=self.feeds,
capacity=64, capacity=64,
iterable=False, iterable=False,
use_double_buffer=False) use_double_buffer=False)
optimizer = fluid.optimizer.SGD(LEARNING_RATE) self.build_optimizer(avg_cost, strategy)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
out = self.do_pyreader_training(fleet) out = self.do_pyreader_training(fleet)
def net(self, batch_size=4, lr=0.01): def net(self, batch_size=4, lr=0.01):
...@@ -263,7 +226,7 @@ class TestFleetBase(unittest.TestCase): ...@@ -263,7 +226,7 @@ class TestFleetBase(unittest.TestCase):
return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe
def _run_cluster(self, model, envs): def _run_cluster(self, model, envs):
env = {'CPU_NUM': '1', 'GRAD_CLIP': str(self._grad_clip_mode)} env = {'GRAD_CLIP': str(self._grad_clip_mode)}
env.update(envs) env.update(envs)
python_path = self._python_interp python_path = self._python_interp
...@@ -307,29 +270,6 @@ class TestFleetBase(unittest.TestCase): ...@@ -307,29 +270,6 @@ class TestFleetBase(unittest.TestCase):
ps0.terminate() ps0.terminate()
ps1.terminate() ps1.terminate()
'''
with open("/tmp/tr0_out.log", "wb+") as wn:
wn.write(tr0_out)
with open("/tmp/tr1_out.log", "wb+") as wn:
wn.write(tr1_out)
# print server log
'''
# print server log
'''
with open("/tmp/ps0_err.log", "r") as fn:
sys.stderr.write("ps0 stderr: %s\n" % fn.read())
with open("/tmp/ps1_err.log", "r") as fn:
sys.stderr.write("ps1 stderr: %s\n" % fn.read())
'''
# print log
'''
with open("/tmp/tr0_err.log", "r") as fn:
sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
with open("/tmp/tr1_err.log", "r") as fn:
sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())
'''
return 0, 0 return 0, 0
......
...@@ -50,9 +50,9 @@ class TestDistMnistSync2x2(TestFleetBase): ...@@ -50,9 +50,9 @@ class TestDistMnistSync2x2(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True) "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistHalfAsync2x2(TestFleetBase): class TestDistMnistAsync2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "half_async" self._mode = "async"
self._reader = "pyreader" self._reader = "pyreader"
def check_with_place(self, def check_with_place(self,
...@@ -81,10 +81,10 @@ class TestDistMnistHalfAsync2x2(TestFleetBase): ...@@ -81,10 +81,10 @@ class TestDistMnistHalfAsync2x2(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True) "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAsync2x2(TestFleetBase): class TestDistMnistAsyncDataset2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "async" self._mode = "async"
self._reader = "pyreader" self._reader = "dataset"
def check_with_place(self, def check_with_place(self,
model_file, model_file,
...@@ -96,7 +96,8 @@ class TestDistMnistAsync2x2(TestFleetBase): ...@@ -96,7 +96,8 @@ class TestDistMnistAsync2x2(TestFleetBase):
"PYTHONPATH": os.getenv("PYTHONPATH", ""), "PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "" "http_proxy": "",
"SAVE_MODEL": "1"
} }
required_envs.update(need_envs) required_envs.update(need_envs)
...@@ -112,10 +113,10 @@ class TestDistMnistAsync2x2(TestFleetBase): ...@@ -112,10 +113,10 @@ class TestDistMnistAsync2x2(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True) "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAsyncDataset2x2(TestFleetBase): class TestDistCtrHalfAsync2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "async" self._mode = "half_async"
self._reader = "dataset" self._reader = "pyreader"
def check_with_place(self, def check_with_place(self,
model_file, model_file,
...@@ -126,8 +127,12 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase): ...@@ -126,8 +127,12 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase):
"PATH": os.getenv("PATH", ""), "PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""), "PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast "FLAGS_rpc_deadline": "30000", # 5sec to fail fast
"http_proxy": "" "http_proxy": "",
"FLAGS_communicator_send_queue_size": "1",
"FLAGS_communicator_max_merge_var_num": "1",
"CPU_NUM": "1",
"SAVE_MODEL": "0"
} }
required_envs.update(need_envs) required_envs.update(need_envs)
......
...@@ -102,8 +102,10 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -102,8 +102,10 @@ class TestStrategyFactor(unittest.TestCase):
trainer_runtime_config = strategy.get_trainer_runtime_config() trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_communicator_flags = trainer_runtime_config.get_communicator_flags( trainer_communicator_flags = trainer_runtime_config.get_communicator_flags(
) )
self.assertIn('send_queue_size', trainer_communicator_flags) self.assertIn('communicator_send_queue_size',
self.assertEqual(trainer_communicator_flags['send_queue_size'], 100) trainer_communicator_flags)
self.assertEqual(
trainer_communicator_flags['communicator_send_queue_size'], 100)
# test set_trainer_runtime_config exception # test set_trainer_runtime_config exception
trainer_runtime_config_dict['unknown'] = None trainer_runtime_config_dict['unknown'] = None
...@@ -138,9 +140,8 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -138,9 +140,8 @@ class TestStrategyFactor(unittest.TestCase):
def test_half_async_strategy(self): def test_half_async_strategy(self):
strategy = StrategyFactory.create_half_async_strategy() strategy = StrategyFactory.create_half_async_strategy()
self.assertEqual(strategy._program_config.sync_mode, False) self.assertEqual(strategy._program_config.sync_mode, False)
self.assertEqual(strategy._program_config.runtime_split_send_recv, self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
False) self.assertEqual(strategy._build_strategy.async_mode, True)
self.assertEqual(strategy._build_strategy.async_mode, False)
# test set_server_runtime_config using ServerRuntimeConfig # test set_server_runtime_config using ServerRuntimeConfig
server_runtime_config_class = ServerRuntimeConfig() server_runtime_config_class = ServerRuntimeConfig()
......
...@@ -100,9 +100,10 @@ class FleetTest(unittest.TestCase): ...@@ -100,9 +100,10 @@ class FleetTest(unittest.TestCase):
self.assertRaises(Exception, fleet._transpile, "config") self.assertRaises(Exception, fleet._transpile, "config")
def set_program(self, avg_cost, strategy): def set_program(self, avg_cost, strategy):
optimizer = fluid.optimizer.SGD(0.1) with fluid.scope_guard(fluid.Scope()):
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fluid.optimizer.SGD(0.1)
optimizer.minimize(avg_cost) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
def test_init_role(self): def test_init_role(self):
role = role_maker.UserDefinedRoleMaker( role = role_maker.UserDefinedRoleMaker(
...@@ -123,6 +124,27 @@ class FleetTest(unittest.TestCase): ...@@ -123,6 +124,27 @@ class FleetTest(unittest.TestCase):
self.assertRaises(Exception, self.set_program, avg_cost, strategy) self.assertRaises(Exception, self.set_program, avg_cost, strategy)
def test_transpile(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.SERVER,
worker_num=2,
server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
# for test optimizer without init(role)
fleet.init(role)
batch_size = 128
is_sparse = True
is_distribute = False
strategy = DistributeTranspilerConfig()
strategy.sync_mode = False
strategy.runtime_split_send_recv = True
avg_cost, _, _ = train_network(batch_size, is_distribute, is_sparse)
self.set_program(avg_cost, strategy)
strategy.runtime_split_send_recv = False
self.set_program(avg_cost, strategy)
class TranspilerOptimizerTest(unittest.TestCase): class TranspilerOptimizerTest(unittest.TestCase):
def testInvalidInputs(self): def testInvalidInputs(self):
......
...@@ -108,6 +108,9 @@ class TrainerDesc(object): ...@@ -108,6 +108,9 @@ class TrainerDesc(object):
for param in dump_param: for param in dump_param:
self.proto_desc.dump_param.append(param) self.proto_desc.dump_param.append(param)
def _set_thread_barrier(self, thread_barrier):
self.proto_desc.thread_barrier = thread_barrier
def _set_check_nan_var_names(self, check_nan_var_names): def _set_check_nan_var_names(self, check_nan_var_names):
for var in check_nan_var_names: for var in check_nan_var_names:
self.proto_desc.check_nan_var_names.append(var) self.proto_desc.check_nan_var_names.append(var)
......
...@@ -190,6 +190,9 @@ class DistributeTranspilerConfig(object): ...@@ -190,6 +190,9 @@ class DistributeTranspilerConfig(object):
__runtime_split_send_recv = False __runtime_split_send_recv = False
__sync_mode = True __sync_mode = True
# half_async
half_async = False
# Geo-sgd algorithm # Geo-sgd algorithm
geo_sgd_mode = False geo_sgd_mode = False
geo_sgd_need_push_nums = 100 geo_sgd_need_push_nums = 100
...@@ -744,27 +747,15 @@ class DistributeTranspiler(object): ...@@ -744,27 +747,15 @@ class DistributeTranspiler(object):
for _, var in enumerate(splited_vars): for _, var in enumerate(splited_vars):
send_vars.append(var) send_vars.append(var)
if self.sync_mode: send_barrier_out = program.global_block().create_var(
fetch_barrier_input = [] name=framework.generate_control_dev_var_name())
send_barrier_out = program.global_block().create_var( if self.has_distributed_lookup_table:
name=framework.generate_control_dev_var_name()) self.grad_name_to_send_dummy_out[
if self.has_distributed_lookup_table: self.table_name] = program.global_block().create_var(
self.grad_name_to_send_dummy_out[ name=framework.generate_control_dev_var_name())
self.table_name] = program.global_block().create_var( input_deps = list(self.grad_name_to_send_dummy_out.values())
name=framework.generate_control_dev_var_name())
input_deps = list(self.grad_name_to_send_dummy_out.values())
program.global_block().append_op( if not self.sync_mode:
type="send_barrier",
inputs={"X": list(input_deps)},
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input.append(send_barrier_out)
else:
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
if len(lr_ops) > 0 and self.counter_var: if len(lr_ops) > 0 and self.counter_var:
decay_dummy_output = program.global_block().create_var( decay_dummy_output = program.global_block().create_var(
...@@ -789,6 +780,35 @@ class DistributeTranspiler(object): ...@@ -789,6 +780,35 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME:
[self.counter_var.name, self.counter_var.name] [self.counter_var.name, self.counter_var.name]
}) })
input_deps.append(decay_dummy_output)
if self.sync_mode:
fetch_barrier_input = []
program.global_block().append_op(
type="send_barrier",
inputs={"X": list(input_deps)},
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"trainer_id": self.trainer_id,
"half_async": False,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input.append(send_barrier_out)
else:
if self.config.runtime_split_send_recv and self.config.half_async:
program.global_block().append_op(
type="send_barrier",
inputs={"X": list(input_deps)},
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"trainer_id": self.trainer_id,
"half_async": True,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
# step 3: insert recv op to receive parameters from parameter server # step 3: insert recv op to receive parameters from parameter server
recv_vars = [] recv_vars = []
...@@ -859,8 +879,6 @@ class DistributeTranspiler(object): ...@@ -859,8 +879,6 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name] [param_varname, recv_op_role_var_name]
}) })
if self.sync_mode:
fetch_barrier_input.extend(splited_var)
self._update_remote_sparse_update_op(program, need_sparse_update_params) self._update_remote_sparse_update_op(program, need_sparse_update_params)
...@@ -877,10 +895,11 @@ class DistributeTranspiler(object): ...@@ -877,10 +895,11 @@ class DistributeTranspiler(object):
}) })
for param_varname, splited_var in six.iteritems(self.param_var_mapping): for param_varname, splited_var in six.iteritems(self.param_var_mapping):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[param_varname] orig_param = program.global_block().vars[param_varname]
if param_varname not in self.sparse_param_to_height_sections: if param_varname not in self.sparse_param_to_height_sections:
if len(splited_var if not self.config.runtime_split_send_recv:
) > 1 and not self.config.runtime_split_send_recv:
program.global_block().append_op( program.global_block().append_op(
type="concat", type="concat",
inputs={"X": splited_var}, inputs={"X": splited_var},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册