Commit 39014b9f, authored by dongdaxiang

fix class register problem

Parent commit: f0dd1201
......@@ -64,7 +64,6 @@ void AsyncExecutor::InitModel() {}
void AsyncExecutor::SaveModel(const std::string& path) {}
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
<<<<<<< HEAD
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
const int thread_num,
......@@ -153,25 +152,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
_pull_dense_thread->stop();
}
#endif
=======
const std::string& trainer_desc_str,
const bool debug) {
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
std::shared_ptr<TrainerBase> trainer;
trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
// initialize trainer
trainer->Initialize(trainer_desc);
trainer->SetScope(root_scope_);
trainer->SetDebug(debug);
// prepare training environment and helper environment
trainer->InitTrainerEnv(main_program, place_);
trainer->InitOtherEnv(main_program);
// training and finalize training
trainer->Run();
trainer->Finalize();
>>>>>>> add dist_multi_trainer for distributed training, add trainer_factory and device_worker_factory so that we can easily extend new training mode, add pull dense worker which is a singleton for parameter fetching
VLOG(3) << "start to run from files in async_executor";
VLOG(3) << "Drop current scope kids";
root_scope_->DropKids();
return;
......
......@@ -20,6 +20,25 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Factory-function type: constructs one concrete DeviceWorker and returns
// ownership via shared_ptr.
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
// Global registry mapping a stringized worker class name to its creator.
// Populated during static initialization by REGISTER_DEVICE_WORKER_CLASS.
device_workerMap g_device_worker_map;
// Registers `device_worker_class` in g_device_worker_map at program start-up:
// the macro emits a creator function plus a file-local registerer object whose
// constructor inserts the creator under the class-name key.  The anonymous
// namespace keeps these helpers private to this translation unit.
// NOTE(review): this commit moves the macro (and its expansions) from the
// header into this .cc so every registration touches the single map defined
// here -- presumably the "class register problem" being fixed; confirm
// against the commit history.
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
std::string device_worker_types;
for (auto iter = g_device_worker_map.begin();
......@@ -40,5 +59,7 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
return g_device_worker_map[device_worker_class]();
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // namespace framework
} // namespace paddle
......@@ -21,25 +21,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Factory-function type: constructs one concrete DeviceWorker and returns
// ownership via shared_ptr.
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
// NOTE(review): defining (rather than extern-declaring) this global in a
// header gives every including translation unit its own copy, so worker
// classes registered in one TU are invisible to lookups in another -- the
// likely cause of the register bug.  This hunk removes the definition from
// the header; the single definition now lives in the .cc file.
device_workerMap g_device_worker_map;
// Header-side copy of the registration macro (also removed by this commit in
// favor of the .cc-side definition).  Expanding it from a header made each
// registerer act on the including TU's private map copy.
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class DeviceWorkerFactory {
public:
static std::string DeviceWorkerTypeList();
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace paddle {
namespace framework {
......@@ -48,11 +47,13 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
fleet_ptr_ = FleetWrapper::GetInstance();
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
VLOG(3) << "initialize pull dense worker";
}
// Prepares the auxiliary (non-graph) training environment: attaches the
// dense-parameter pull worker to the root scope and starts it.  Order
// matters -- the scope must be set before Start() launches the worker.
// `main_program` is currently unused here; it is part of the TrainerBase
// interface shared with InitTrainerEnv.
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start();
VLOG(3) << "init other env done.";
}
void DistMultiTrainer::Finalize() {
......@@ -62,6 +63,5 @@ void DistMultiTrainer::Finalize() {
pull_dense_worker_->Stop();
}
REGISTER_TRAINER_CLASS(DistMultiTrainer);
} // end namespace framework
} // end namespace paddle
......@@ -134,6 +134,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
}
void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
device_reader_->Start();
int batch_cnt = 0;
......@@ -148,6 +149,7 @@ void DownpourWorker::TrainFiles() {
CollectLabelInfo(i);
FillSparseValue(i);
}
VLOG(3) << "fill sparse value for all sparse table done.";
// do computation here
for (auto& op : ops_) {
......@@ -179,6 +181,7 @@ void DownpourWorker::TrainFiles() {
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
VLOG(3) << "push sparse and dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
......@@ -210,16 +213,17 @@ void DownpourWorker::TrainFiles() {
push_sparse_status_.resize(0);
}
/*
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
*/
thread_scope_->DropKids();
++batch_cnt;
}
}
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // end namespace framework
} // end namespace paddle
......@@ -129,6 +129,5 @@ void HogwildWorker::TrainFiles() {
}
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
} // end namespace framework
} // end namespace paddle
......@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace paddle {
namespace framework {
......@@ -66,6 +65,5 @@ void MultiTrainer::Finalize() {
}
}
REGISTER_TRAINER_CLASS(MultiTrainer);
} // end namespace framework
} // end namespace paddle
......@@ -22,8 +22,24 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Factory-function type: constructs one concrete TrainerBase subclass and
// returns ownership via shared_ptr.
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
// Global registry mapping a stringized trainer class name to its creator.
// Filled at static-initialization time by REGISTER_TRAINER_CLASS below.
trainerMap g_trainer_map;
// Registers `trainer_class` in g_trainer_map at program start-up via a
// file-local static registerer object.  Moved from the header into this .cc
// by this commit so that all registrations share the one map defined above.
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std::string TrainerFactory::TrainerTypeList() {
std::string trainer_types;
for (auto iter = g_trainer_map.begin(); iter != g_trainer_map.end(); ++iter) {
......@@ -38,10 +54,14 @@ std::string TrainerFactory::TrainerTypeList() {
// Creates the trainer registered under `trainer_class` (e.g. "MultiTrainer",
// "DistMultiTrainer") by invoking its registered creator function.
// On an unknown name it logs the list of registered trainer types and
// terminates the process with exit(-1), preserving the original fail-fast
// behavior.
std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
    std::string trainer_class) {
  // Single hash lookup via find() instead of count() followed by operator[],
  // which performed the lookup twice.
  auto iter = g_trainer_map.find(trainer_class);
  if (iter == g_trainer_map.end()) {
    LOG(WARNING) << "Trainer class: " << trainer_class << " not defined";
    LOG(WARNING) << TrainerTypeList();
    exit(-1);
  }
  return iter->second();
}
REGISTER_TRAINER_CLASS(MultiTrainer);
REGISTER_TRAINER_CLASS(DistMultiTrainer);
} // namespace framework
} // namespace paddle
......@@ -20,23 +20,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Factory-function type: constructs one concrete TrainerBase subclass and
// returns ownership via shared_ptr.
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
// Declaration of the registry defined in trainer_factory.cc.  Unlike the
// device-worker header, this one correctly uses `extern`; the hunk removes
// it anyway, together with the macro below, as registration now lives
// entirely in the .cc file.
extern trainerMap g_trainer_map;
// Header-side copy of the trainer registration macro (removed by this
// commit in favor of the .cc-side definition).
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class TrainerFactory {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册