Commit c1650120 authored by D dongdaxiang

refine device_worker and trainer code

test=develop
Parent 8a335b50
@@ -203,7 +203,7 @@ if(WITH_PSLIB)
     trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
     downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
     DEPS op_registry device_context scope framework_proto
-    trainer_desc_proto glog lod_rank_table
+    trainer_desc_proto glog lod_rank_table fleet_wrapper
     feed_fetch_method graph_to_program_pass async_executor_proto
     variable_helper pslib_brpc pslib timer)
 else()
@@ -212,7 +212,7 @@ else()
     trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
     downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
     DEPS op_registry device_context scope framework_proto
-    trainer_desc_proto glog lod_rank_table
+    trainer_desc_proto glog lod_rank_table fleet_wrapper
     feed_fetch_method graph_to_program_pass async_executor_proto
     variable_helper timer)
 endif(WITH_PSLIB)
......
@@ -26,7 +26,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
+#include "paddle/fluid/framework/trainer.h"
 #include "paddle/fluid/framework/trainer_desc.pb.h"
+#include "paddle/fluid/framework/trainer_factory.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/pybind/pybind.h"
@@ -161,7 +163,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
   // initialize trainer
   trainer->Initialize(trainer_desc);
-  // trainer->SetRootScope(root_scope_);
+  trainer->SetScope(root_scope_);
   trainer->SetDebug(debug);
   // prepare training environment and helper environment
   trainer->InitTrainerEnv(main_program, place_);
......
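AsyncExecutor::RunFromFile now hands the whole run to a trainer created by name through TrainerFactory and drives it through a small lifecycle, as the hunk above shows. Below is a minimal sketch of that lifecycle against a stub interface, not Paddle's real classes: ProgramDesc, Scope, TrainerDesc and Place are placeholder structs, and Finalize() is included only because both trainers later in this diff implement it; the Run step is omitted since its call site is not shown here.

// lifecycle_sketch.cc -- illustrative only, stub types stand in for Paddle's.
#include <iostream>
#include <memory>
#include <string>

struct ProgramDesc {};
struct Scope {};
struct Place {};
struct TrainerDesc { std::string class_name; };

class TrainerBase {
 public:
  virtual ~TrainerBase() {}
  void SetScope(Scope* root_scope) { root_scope_ = root_scope; }
  void SetDebug(const bool debug) { debug_ = debug; }
  virtual void Initialize(const TrainerDesc& desc) = 0;
  virtual void InitTrainerEnv(const ProgramDesc& main_program,
                              const Place& place) = 0;
  virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
  virtual void Finalize() = 0;

 protected:
  Scope* root_scope_ = nullptr;
  bool debug_ = false;
};

class NoopTrainer : public TrainerBase {
 public:
  void Initialize(const TrainerDesc&) override {}
  void InitTrainerEnv(const ProgramDesc&, const Place&) override {}
  void InitOtherEnv(const ProgramDesc&) override {}
  void Finalize() override { std::cout << "trainer finalized\n"; }
};

int main() {
  ProgramDesc main_program;
  Scope root_scope;
  TrainerDesc desc{"NoopTrainer"};

  // Same ordering as the RunFromFile hunk above:
  // create -> Initialize -> SetScope -> SetDebug -> InitTrainerEnv -> ...
  std::shared_ptr<TrainerBase> trainer = std::make_shared<NoopTrainer>();
  trainer->Initialize(desc);
  trainer->SetScope(&root_scope);
  trainer->SetDebug(false);
  trainer->InitTrainerEnv(main_program, Place());
  trainer->InitOtherEnv(main_program);
  trainer->Finalize();
}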
@@ -75,14 +75,6 @@ class AsyncExecutor {
   void InitModel();
   void SaveModel(const std::string& path);

- private:
-  void CreateThreads(ExecutorThreadWorker* worker,
-                     const ProgramDesc& main_program,
-                     const std::shared_ptr<DataFeed>& reader,
-                     const std::vector<std::string>& fetch_var_names,
-                     Scope* root_scope, const int thread_index,
-                     const bool debug);
  public:
   std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
   Scope* root_scope_;
......
@@ -39,12 +39,11 @@ namespace framework {
 class PullDenseWorker {
  public:
-  PullDenseWorker() {}
   virtual ~PullDenseWorker() {}
   virtual void Initialize(const TrainerDesc& param);
   int Start();
   void Stop();
-  void SetScope(Scope* scope) { root_scope_ = scope; }
+  void SetRootScope(Scope* scope) { root_scope_ = scope; }
   void IncreaseThreadVersion(int thread_id, uint64_t table_id);
   void ResetThreadVersion(uint64_t table_id);
   void Wait(std::vector<::std::future<int32_t>>* status_vec);
@@ -57,6 +56,7 @@ class PullDenseWorker {
   }

  private:
+  PullDenseWorker() : root_scope_(NULL) {}
   void Run();
   bool CheckUpdateParam(uint64_t table_id);
@@ -137,20 +137,18 @@ class HogwildWorker : public CPUWorkerBase {
  protected:
   void CreateThreadOperators(const ProgramDesc& program);
   void CreateThreadScope(const ProgramDesc& program);
-  std::shared_ptr<DataFeed> thread_reader_;
   std::vector<std::string> op_names_;
   std::vector<OperatorBase*> ops_;
   Scope* thread_scope_;
   std::vector<std::string> fetch_var_names_;
   std::vector<std::vector<float>> fetch_values_;
-  platform::Place place_;
 };

 class DownpourWorker : public HogwildWorker {
  public:
   DownpourWorker() {}
   virtual ~DownpourWorker() {}
-  virtual void Initilize(const TrainerDesc& desc);
+  virtual void Initialize(const TrainerDesc& desc);
   virtual void TrainFiles();

  protected:
@@ -163,7 +161,7 @@ class DownpourWorker : public HogwildWorker {
  private:
   DownpourWorkerParameter param_;
   // just save the value in param_ for easy access
-  std::string label_var_name_;
+  std::map<uint64_t, std::string> label_var_name_;
   std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
   std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
   std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
......
@@ -19,25 +19,6 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
-typedef std::unordered_map<std::string, Createdevice_workerFunction>
-    device_workerMap;
-device_workerMap g_device_worker_map;
-
-#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class)          \
-  namespace {                                                      \
-  std::shared_ptr<DeviceWorker> Creator_##device_worker_class() {  \
-    return std::shared_ptr<DeviceWorker>(new device_worker_class); \
-  }                                                                \
-  class __Registerer_##device_worker_class {                       \
-   public:                                                         \
-    __Registerer_##device_worker_class() {                         \
-      g_device_worker_map[#device_worker_class] =                  \
-          &Creator_##device_worker_class;                          \
-    }                                                              \
-  };                                                               \
-  __Registerer_##device_worker_class g_registerer_##device_worker_class; \
-  }  // namespace
 std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
   std::string device_worker_types;
@@ -59,7 +40,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
   return g_device_worker_map[device_worker_class]();
 }

-REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
-REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
 }  // namespace framework
 }  // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class DeviceWorkerFactory {
public:
static std::string DeviceWorkerTypeList();
static std::shared_ptr<DeviceWorker> CreateDeviceWorker(
std::string device_worker_class);
};
} // namespace framework
} // namespace paddle
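The new header keeps the registration machinery in one place: REGISTER_DEVICE_WORKER_CLASS drops a static registerer object into an anonymous namespace, and its constructor inserts a creator function into g_device_worker_map, so each worker source file in this commit can register itself next to its own definition. The following is a self-contained sketch of the same idea with illustrative names only; the registry is kept behind a function here purely so the example fits in a single file.

// registry_sketch.cc -- illustrative stand-in for the factory pattern above.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Worker {
  virtual ~Worker() = default;
  virtual void TrainFiles() = 0;
};

using WorkerCreator = std::shared_ptr<Worker> (*)();

// Map from class name to creator, analogous to g_device_worker_map.
std::unordered_map<std::string, WorkerCreator>& Registry() {
  static std::unordered_map<std::string, WorkerCreator> m;
  return m;
}

#define REGISTER_WORKER(cls)                                      \
  namespace {                                                     \
  std::shared_ptr<Worker> Creator_##cls() {                       \
    return std::shared_ptr<Worker>(new cls);                      \
  }                                                               \
  struct Registerer_##cls {                                       \
    Registerer_##cls() { Registry()[#cls] = &Creator_##cls; }     \
  } g_registerer_##cls;                                           \
  }  // namespace

struct ToyWorker : public Worker {
  void TrainFiles() override { std::cout << "training\n"; }
};
REGISTER_WORKER(ToyWorker);

int main() {
  // Look the class up by its string name, which is exactly what
  // DeviceWorkerFactory::CreateDeviceWorker does with g_device_worker_map.
  std::shared_ptr<Worker> w = Registry()["ToyWorker"]();
  w->TrainFiles();
}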
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
+#include "paddle/fluid/framework/trainer_factory.h"

 namespace paddle {
 namespace framework {
@@ -34,6 +35,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
     workers_[i]->SetDeviceIndex(i);
     readers_[i]->Init(trainer_desc.data_desc());
     workers_[i]->SetDataFeed(readers_[i]);
+    workers_[i]->Initialize(trainer_desc);
   }

   std::vector<std::string> filelist_vec;
@@ -41,13 +43,15 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
     filelist_vec.push_back(trainer_desc.filelist(i));
   }

+  readers_[0]->SetFileList(filelist_vec);
+
   fleet_ptr_ = FleetWrapper::GetInstance();
   pull_dense_worker_ = PullDenseWorker::GetInstance();
   pull_dense_worker_->Initialize(trainer_desc);
 }

 void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
-  pull_dense_worker_->SetScope(root_scope_);
+  pull_dense_worker_->SetRootScope(root_scope_);
   pull_dense_worker_->Start();
 }

@@ -58,5 +62,6 @@ void DistMultiTrainer::Finalize() {
   pull_dense_worker_->Stop();
 }

+REGISTER_TRAINER_CLASS(DistMultiTrainer);
 }  // end namespace framework
 }  // end namespace paddle
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/device_worker.h"
+#include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/platform/cpu_helper.h"

 namespace paddle {
 namespace framework {

-void DownpourWorker::Initilize(const TrainerDesc& desc) {
+void DownpourWorker::Initialize(const TrainerDesc& desc) {
   param_ = desc.downpour_param();
   for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
     uint64_t table_id =
         static_cast<uint64_t>(param_.sparse_table(i).table_id());
@@ -37,6 +37,7 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
     for (size_t j = 0; j < table.sparse_grad_name_size(); ++j) {
       sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
     }
+    label_var_name_[table_id] = table.label_var_name();
   }

   for (size_t i = 0; i < param_.dense_table_size(); ++i) {
@@ -56,15 +57,18 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
   for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
     skip_ops_[i] = param_.skip_ops(i);
   }
+  skip_ops_.resize(param_.skip_ops_size());
-  label_var_name_ = param_.label_var_name();
 }
-void DownpourWorker::CollectLabelInfo(size_t table_id) {
+void DownpourWorker::CollectLabelInfo(size_t table_idx) {
+  auto table = param_.sparse_table(table_idx);
+  uint64_t table_id =
+      static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
   auto& feature = features_[table_id];
   auto& feature_label = feature_labels_[table_id];
   feature_label.resize(feature.size());
-  Variable* var = thread_scope_->FindVar(label_var_name_);
+  Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
   LoDTensor* tensor = var->GetMutable<LoDTensor>();
   int64_t* label_ptr = tensor->data<int64_t>();
@@ -75,13 +79,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_id) {
     int64_t* ids = tensor->data<int64_t>();
     int fea_idx = 0;
     // tensor->lod()[0].size() == batch_size + 1
-    for (auto ins_idx = 0u; ins_idx < tensor->lod()[0].size() - 1; ++ins_idx) {
-      for (; fea_idx < tensor->lod()[0][ins_idx]; ++fea_idx) {
+    for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
+      for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
         // should be skipped feasign defined in protobuf
         if (ids[fea_idx] == 0u) {
           continue;
         }
-        feature_label[global_index++] = static_cast<float>(label_ptr[ins_idx]);
+        feature_label[global_index++] =
+            static_cast<float>(label_ptr[lod_idx - 1]);
       }
     }
   }
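The rewritten loop walks the level-0 LoD of the slot tensor: lod[i-1]..lod[i] is the feature-id range of instance i-1, so lod.size() equals batch_size + 1 (as the comment in the hunk says), and every non-zero feature id inherits the label of its instance while zero ids are treated as padding feasigns. A toy sketch with plain vectors, illustrative values only:

// lod_label_sketch.cc -- toy stand-in for the tensors used above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Two instances: the first owns ids[0..2), the second ids[2..5).
  std::vector<size_t> lod = {0, 2, 5};
  std::vector<int64_t> ids = {7, 0, 3, 9, 0};  // 0 is a padding feasign
  std::vector<int64_t> labels = {1, 0};        // one click label per instance

  std::vector<float> feature_label;
  size_t fea_idx = 0;
  for (size_t lod_idx = 1; lod_idx < lod.size(); ++lod_idx) {
    for (; fea_idx < lod[lod_idx]; ++fea_idx) {
      if (ids[fea_idx] == 0) continue;  // skip padding, as in the diff
      feature_label.push_back(static_cast<float>(labels[lod_idx - 1]));
    }
  }
  for (float v : feature_label) std::cout << v << " ";  // prints: 1 0 0
  std::cout << "\n";
}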
@@ -128,10 +133,10 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {

 void DownpourWorker::TrainFiles() {
   platform::SetNumThreads(1);
-  thread_reader_->Start();
+  device_reader_->Start();
   int batch_cnt = 0;
   int cur_batch;
-  while ((cur_batch = thread_reader_->Next()) > 0) {
+  while ((cur_batch = device_reader_->Next()) > 0) {
     // pull sparse here
     for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
       uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
@@ -144,7 +149,16 @@ void DownpourWorker::TrainFiles() {

     // do computation here
     for (auto& op : ops_) {
-      op->Run(*thread_scope_, place_);
+      bool need_skip = false;
+      for (auto t = 0u; t < skip_ops_.size(); ++t) {
+        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+          need_skip = true;
+          break;
+        }
+      }
+      if (!need_skip) {
+        op->Run(*thread_scope_, place_);
+      }
     }

     // push gradients here
@@ -198,10 +212,12 @@ void DownpourWorker::TrainFiles() {
       uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
       pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
     }

     thread_scope_->DropKids();
     ++batch_cnt;
   }
 }

+REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
 }  // end namespace framework
 }  // end namespace paddle
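The new need_skip loop matches each operator's type against the configured skip_ops_ entries with a substring test, so one configured name can suppress a whole family of ops. A small sketch of that predicate; the op names and skip list below are placeholders, not Paddle operator types:

// skip_ops_sketch.cc -- illustrative substring-based skip filter.
#include <iostream>
#include <string>
#include <vector>

// Mirrors the need_skip loop in DownpourWorker::TrainFiles above.
bool ShouldSkip(const std::string& op_type,
                const std::vector<std::string>& skip_ops) {
  for (const auto& s : skip_ops) {
    if (op_type.find(s) != std::string::npos) return true;
  }
  return false;
}

int main() {
  std::vector<std::string> skip_ops = {"feed", "fetch"};
  std::vector<std::string> op_types = {"mul", "feed", "fetch_barrier"};
  for (const auto& op : op_types) {
    std::cout << op << ": " << (ShouldSkip(op, skip_ops) ? "skip" : "run")
              << "\n";
  }
}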
-cc_library(fleet_wrapper SRCS fleet_wrapper.cc)
+cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS pslib_brpc pslib)
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,10 +33,16 @@ namespace framework {

 const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
 std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
+bool FleetWrapper::is_initialized_ = false;
+
+#ifdef PADDLE_WITH_PSLIB
+std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
+#endif
 void FleetWrapper::InitServer(const std::string& dist_desc, int index) {
 #ifdef PADDLE_WITH_PSLIB
   if (!is_initialized_) {
+    LOG(WARNING) << "Going to init server";
     pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
         new paddle::distributed::PSlib());
     pslib_ptr_->init_server(dist_desc, index);
@@ -38,6 +58,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
                               int node_num, int index) {
 #ifdef PADDLE_WITH_PSLIB
   if (!is_initialized_) {
+    LOG(WARNING) << "Going to init server";
     pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
         new paddle::distributed::PSlib());
     pslib_ptr_->init_worker(dist_desc,
@@ -52,12 +73,14 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,

 void FleetWrapper::StopServer() {
 #ifdef PADDLE_WITH_PSLIB
+  LOG(WARNING) << "Going to stop server";
   pslib_ptr_->stop_server();
 #endif
 }

 uint64_t FleetWrapper::RunServer() {
 #ifdef PADDLE_WITH_PSLIB
+  LOG(WARNING) << "Going to run server";
   return pslib_ptr_->run_server();
 #else
   return 0;
@@ -67,6 +90,7 @@ uint64_t FleetWrapper::RunServer() {
 void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
                                  int node_num) {
 #ifdef PADDLE_WITH_PSLIB
+  LOG(WARNING) << "Going to gather server ips";
   pslib_ptr_->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
                              node_num);
 #endif
@@ -122,13 +146,13 @@ void FleetWrapper::PullDenseVarsAsync(
     std::vector<::std::future<int32_t>>* pull_dense_status) {
 #ifdef PADDLE_WITH_PSLIB
   std::vector<paddle::ps::Region> regions;
-  regions.reserve(var_names.size());
-  for (auto& t : var_names) {
-    Variable* var = scope.FindVar(t);
+  regions.resize(var_names.size());
+  for (auto i = 0u; i < var_names.size(); ++i) {
+    Variable* var = scope.FindVar(var_names[i]);
     LoDTensor* tensor = var->GetMutable<LoDTensor>();
     float* w = tensor->data<float>();
     paddle::ps::Region reg(w, tensor->numel());
-    regions.emplace_back(std::move(reg));
+    regions[i] = std::move(reg);
   }
   auto status =
       pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
@@ -186,7 +210,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
   int offset = 2;
   uint64_t fea_idx = 0u;
   for (size_t i = 0; i < sparse_key_names.size(); ++i) {
-    Variable* g_var = scope.FindVar(sparse_key_names[i]);
+    LOG(WARNING) << "sparse key names[" << i << "]: " << sparse_key_names[i];
+    LOG(WARNING) << "sparse grad names[" << i << "]: " << sparse_grad_names[i];
+    Variable* g_var = scope.FindVar(sparse_grad_names[i]);
+    CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
     LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
     if (g_tensor == NULL) {
       LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
@@ -201,16 +228,26 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
       exit(-1);
     }
     int len = tensor->numel();
+    LOG(WARNING) << " tensor len: " << len;
     int64_t* ids = tensor->data<int64_t>();
+    push_values->resize(fea_keys.size() + 1);
+    for (auto& t : *push_values) {
+      t.resize(emb_dim + offset);
+    }
     for (auto id_idx = 0u; id_idx < len; ++id_idx) {
       if (ids[id_idx] == 0) {
         g += emb_dim;
         continue;
       }
+      LOG(WARNING) << "going to memcpy";
       memcpy((*push_values)[fea_idx].data() + offset, g,
              sizeof(float) * emb_dim);
+      LOG(WARNING) << "show";
       (*push_values)[fea_idx][0] = 1.0f;
+      LOG(WARNING) << "click";
       (*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
+      LOG(WARNING) << "offset";
       g += emb_dim;
       fea_idx++;
     }
......
@@ -47,7 +47,6 @@ namespace framework {

 class FleetWrapper {
  public:
-  FleetWrapper() {}
   virtual ~FleetWrapper() {}

   // Pull sparse variables from server in Sync mode
@@ -122,8 +121,11 @@ class FleetWrapper {
   static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
 #endif

+ private:
+  FleetWrapper() {}
+
  protected:
-  bool is_initialized_;
+  static bool is_initialized_;
   DISABLE_COPY_AND_ASSIGN(FleetWrapper);
 };
......
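With the constructor moved to private and the bookkeeping members made static, FleetWrapper is reachable only through the shared instance returned by GetInstance() (which DistMultiTrainer::Initialize uses above). The body of GetInstance() is not part of this diff, so the lazy construction in the sketch below is an assumption; the out-of-class static definitions mirror the ones added in the fleet_wrapper.cc hunk.

// singleton_sketch.cc -- illustrative names; not Paddle's FleetWrapper.
#include <memory>

class Wrapper {
 public:
  virtual ~Wrapper() {}

  // Every call site shares one instance; the constructor is private,
  // so this accessor is the only way to obtain a Wrapper.
  static std::shared_ptr<Wrapper> GetInstance() {
    if (s_instance_ == nullptr) {
      s_instance_.reset(new Wrapper());
    }
    return s_instance_;
  }

 private:
  Wrapper() {}
  Wrapper(const Wrapper&) = delete;
  Wrapper& operator=(const Wrapper&) = delete;

  static std::shared_ptr<Wrapper> s_instance_;
  static bool is_initialized_;
};

// Static members need exactly one out-of-class definition, which is what
// the fleet_wrapper.cc hunk above adds for s_instance_, is_initialized_
// and pslib_ptr_.
std::shared_ptr<Wrapper> Wrapper::s_instance_ = nullptr;
bool Wrapper::is_initialized_ = false;

int main() {
  auto a = Wrapper::GetInstance();
  auto b = Wrapper::GetInstance();
  return a == b ? 0 : 1;  // same object, returns 0
}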
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/device_worker.h"
+#include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/platform/cpu_helper.h"

 namespace paddle {
@@ -50,9 +51,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {

 void HogwildWorker::BindingDataFeedMemory() {
   const std::vector<std::string>& input_feed =
-      thread_reader_->GetUseSlotAlias();
+      device_reader_->GetUseSlotAlias();
   for (auto name : input_feed) {
-    thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
+    device_reader_->AddFeedVar(thread_scope_->Var(name), name);
   }
 }

@@ -63,7 +64,7 @@ void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {

 void HogwildWorker::TrainFilesWithProfiler() {
   platform::SetNumThreads(1);
-  thread_reader_->Start();
+  device_reader_->Start();
   std::vector<double> op_total_time;
   std::vector<std::string> op_name;
   for (auto& op : ops_) {
@@ -79,7 +80,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
   int cur_batch;
   int batch_cnt = 0;
   timeline.Start();
-  while ((cur_batch = thread_reader_->Next()) > 0) {
+  while ((cur_batch = device_reader_->Next()) > 0) {
     timeline.Pause();
     read_time += timeline.ElapsedSec();
     total_time += timeline.ElapsedSec();
@@ -115,10 +116,10 @@ void HogwildWorker::TrainFiles() {
   platform::SetNumThreads(1);

   // how to accumulate fetched values here
-  thread_reader_->Start();
+  device_reader_->Start();
   int cur_batch;
   int batch_cnt = 0;
-  while ((cur_batch = thread_reader_->Next()) > 0) {
+  while ((cur_batch = device_reader_->Next()) > 0) {
     for (auto& op : ops_) {
       op->Run(*thread_scope_, place_);
     }
@@ -128,5 +129,6 @@ void HogwildWorker::TrainFiles() {
   }
 }

+REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
 }  // end namespace framework
 }  // end namespace paddle
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
+#include "paddle/fluid/framework/trainer_factory.h"

 namespace paddle {
 namespace framework {
@@ -65,5 +66,6 @@ void MultiTrainer::Finalize() {
   }
 }

+REGISTER_TRAINER_CLASS(MultiTrainer);
 }  // end namespace framework
 }  // end namespace paddle
@@ -20,24 +20,31 @@ namespace framework {
 std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;

 void PullDenseWorker::Initialize(const TrainerDesc& param) {
+  LOG(WARNING) << "going to initialize pull dense worker";
   running_ = false;
   param_ = param.pull_dense_param();
   threshold_ = param_.threshold();
   thread_num_ = param_.device_num();
   sleep_time_ms_ = param_.sleep_time_ms();
+  LOG(WARNING) << "dense table size: " << param_.dense_table_size();
+  LOG(WARNING) << "thread num: " << thread_num_;
   for (size_t i = 0; i < param_.dense_table_size(); ++i) {
     // setup dense variables for each table
     int var_num = param_.dense_table(i).dense_value_name_size();
+    LOG(WARNING) << "var num: " << var_num;
     uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
     dense_value_names_[tid].resize(var_num);
     for (int j = 0; j < var_num; ++j) {
       dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
+      LOG(WARNING) << "dense value names " << j << " "
+                   << dense_value_names_[tid][j];
     }
     // setup training version for each table
     training_versions_[tid].resize(thread_num_, 0);
     last_versions_[tid] = 0;
     current_version_[tid] = 0;
   }
+  LOG(WARNING) << "initialize pull dense worker done.";
 }

 void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
@@ -56,6 +63,7 @@ void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
                << " Times";
     exit(-1);
   }
+  status_vec->resize(0);
 }

 void PullDenseWorker::Stop() {
@@ -90,7 +98,10 @@ void PullDenseWorker::Run() {
 }

 void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
+  LOG(WARNING) << "increase thread version input: " << thread_id << " table id "
+               << table_id;
   std::lock_guard<std::mutex> lock(mutex_for_version_);
+  LOG(WARNING) << "going to increase";
   training_versions_[table_id][thread_id]++;
 }
......
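Each training thread calls IncreaseThreadVersion() for every dense table once it finishes a batch (as the DownpourWorker::TrainFiles hunk above shows), and the puller decides when to refresh dense parameters by comparing these per-thread versions against the last pulled version. CheckUpdateParam() itself is not part of this diff, so the trigger rule in the sketch below (count the threads that advanced past last_version and compare against the threshold) is an assumption used only to illustrate the bookkeeping that Initialize() sets up.

// pull_dense_version_sketch.cc -- illustrative-only version bookkeeping.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

struct DensePullState {
  int thread_num;
  int threshold;
  // Per table: one version counter per training thread, plus the version
  // at which the dense parameters were last pulled.
  std::map<uint64_t, std::vector<uint64_t>> training_versions;
  std::map<uint64_t, uint64_t> last_versions;

  void IncreaseThreadVersion(int thread_id, uint64_t table_id) {
    training_versions[table_id][thread_id]++;
  }

  // Assumed rule: pull once enough threads have moved past last_version.
  bool ShouldPull(uint64_t table_id) {
    int advanced = 0;
    for (uint64_t v : training_versions[table_id]) {
      if (v > last_versions[table_id]) ++advanced;
    }
    return advanced >= threshold;
  }
};

int main() {
  DensePullState state{/*thread_num=*/2, /*threshold=*/2, {}, {}};
  state.training_versions[0].resize(2, 0);
  state.last_versions[0] = 0;

  state.IncreaseThreadVersion(0, 0);
  std::cout << state.ShouldPull(0) << "\n";  // 0: only one thread advanced
  state.IncreaseThreadVersion(1, 0);
  std::cout << state.ShouldPull(0) << "\n";  // 1: both threads advanced
}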
@@ -19,7 +19,5 @@ namespace framework {

 void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }

-void TrainerBase::Initialize(const TrainerDesc& trainer_desc) { return; }
-
 }  // end namespace framework
 }  // end namespace paddle
@@ -39,8 +39,8 @@ class TrainerBase {
   virtual ~TrainerBase() {}
   // model memory are hosted in root_scope
   void SetScope(Scope* root_scope);
-  void Initialize(const TrainerDesc& trainer_desc);
   void SetDebug(const bool debug) { debug_ = debug; }
+  virtual void Initialize(const TrainerDesc& trainer_desc) = 0;
   virtual void InitTrainerEnv(const ProgramDesc& main_program,
                               const platform::Place& place) = 0;
   virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
......
@@ -43,7 +43,6 @@ message DownpourWorkerParameter {
   repeated TableParameter sparse_table = 1;
   repeated TableParameter dense_table = 2;
   repeated string skip_ops = 3;
-  optional string label_var_name = 4;
 }

 message PullDenseWorkerParameter {
......
@@ -21,23 +21,6 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
-typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
-trainerMap g_trainer_map;
-
-#define REGISTER_TRAINER_CLASS(trainer_class)                      \
-  namespace {                                                      \
-  std::shared_ptr<TrainerBase> Creator_##trainer_class() {         \
-    return std::shared_ptr<TrainerBase>(new trainer_class);        \
-  }                                                                \
-  class __Registerer_##trainer_class {                             \
-   public:                                                         \
-    __Registerer_##trainer_class() {                               \
-      g_trainer_map[#trainer_class] = &Creator_##trainer_class;    \
-    }                                                              \
-  };                                                               \
-  __Registerer_##trainer_class g_registerer_##trainer_class;       \
-  }  // namespace
 std::string TrainerFactory::TrainerTypeList() {
   std::string trainer_types;
@@ -58,7 +41,5 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
   return g_trainer_map[trainer_class]();
 }

-REGISTER_TRAINER_CLASS(MultiTrainer);
-REGISTER_TRAINER_CLASS(DistMultiTrainer);
 }  // namespace framework
 }  // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
trainerMap g_trainer_map;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class TrainerFactory {
public:
static std::string TrainerTypeList();
static std::shared_ptr<TrainerBase> CreateTrainer(std::string trainer_class);
};
} // namespace framework
} // namespace paddle
@@ -110,15 +110,17 @@ class AsyncExecutor(object):
         is_local = self.instance == None
         trainer = None
         if is_local:
-            trainer = MultiTrainer(data_feed=data_feed, worker="Hogwild")
+            trainer = MultiTrainer()
         else:
-            trainer = DistMultiTrainer(
-                data_feed, worker="Downpour", fleet_desc=self.dist_desc)
+            trainer = DistMultiTrainer()
+            trainer.gen_trainer_desc(
+                dataset=data_feed, fleet_desc=self.dist_desc, worker="downpour")
-        # define a trainer and a device_worker here
         trainer.set_thread(thread_num)
         trainer.set_filelist(filelist)
         trainer.set_data_feed(data_feed)
+        with open("trainer_desc.proto", "w") as fout:
+            fout.write(trainer._desc())
+        # define a trainer and a device_worker here
         self.executor.run_from_files(program_desc, trainer._desc(), debug)

         '''
@@ -284,8 +286,9 @@ class AsyncExecutor(object):
             raise ValueError(
                 'instance is None, please run config_distributed_nodes init instance'
             )
-        self.init_desc = init_desc
-        self.executor.init_server(dist_desc, self.instance._rankid)
+        self.dist_desc_str = text_format.MessageToString(dist_desc)
+        self.dist_desc = dist_desc
+        self.executor.init_server(self.dist_desc_str, self.instance._rankid)
         ip = self.executor.start_server()
         self.instance.set_ip(ip)
         self.instance.barrier_all()  #wait all server start
@@ -306,6 +309,7 @@ class AsyncExecutor(object):
             'instance is None, please run config_distributed_nodes init instance'
         )
+        self.dist_desc_str = text_format.MessageToString(dist_desc)
         self.dist_desc = dist_desc
         place = core.CPUPlace()
         executor = Executor(place)
@@ -313,7 +317,7 @@ class AsyncExecutor(object):
         self.instance.barrier_all()  #wait all server start
         ips = self.instance.gather_ips()
-        self.executor.init_worker(dist_desc, ips,
+        self.executor.init_worker(self.dist_desc_str, ips,
                                   self.instance.get_node_cnt(),
                                   self.instance._rankid)
         self.instance.barrier_all()  #wait all worker start
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class DeviceWorker(object):
def __init__(self):
pass
def gen_worker_desc(self, trainer_desc, fleet_desc):
pass
class Hogwild(DeviceWorker):
def __init__(self):
super(Hogwild, self).__init__()
def gen_worker_desc(self, trainer_desc, fleet_desc):
trainer_desc.device_worker_name = "HogwildWorker"
class Downpour(DeviceWorker):
def __init__(self):
super(Downpour, self).__init__()
def gen_worker_desc(self, trainer_desc, fleet_desc):
trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(
fleet_desc.trainer_param.dense_table[0].dense_variable_name)
dense_table.table_id = \
fleet_desc.trainer_param.dense_table[0].table_id
downpour = trainer_desc.downpour_param
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
fleet_desc.trainer_param.sparse_table[0].table_id
sparse_table.sparse_key_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_key)
sparse_table.sparse_value_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_value)
sparse_table.sparse_grad_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
sparse_table.label_var_name = "click"
dense_table = downpour.dense_table.add()
dense_table.table_id = \
fleet_desc.trainer_param.dense_table[0].table_id
dense_table.dense_value_name.extend(
fleet_desc.trainer_param.dense_table[0].dense_variable_name)
dense_table.dense_grad_name.extend(fleet_desc.trainer_param.dense_table[
0].dense_gradient_variable_name)
downpour.skip_ops.extend(fleet_desc.trainer_param.skip_op)
class DeviceWorkerFactory(object):
def create_device_worker(self, worker_type):
classname = worker_type.capitalize()
print("------------")
print(classname)
return globals()[classname]()
@@ -13,7 +13,8 @@
 # limitations under the License.

 from paddle.fluid.proto import trainer_desc_pb2
-import ps_pb2 as pslib
+from distributed import ps_pb2 as ps_pb2
+from device_worker import DeviceWorkerFactory
 from google.protobuf import text_format

 __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
@@ -28,16 +29,22 @@ class TrainerDesc(object):
             text_format.Parse(f.read(), self.proto_desc)
         '''
         self.proto_desc = trainer_desc_pb2.TrainerDesc()
+        self.proto_desc.thread_num = 12

     def set_thread(self, thread_num):
         self.proto_desc.thread_num = thread_num

     def set_filelist(self, filelist):
         self.proto_desc.filelist.extend(filelist)
+        self.proto_desc.thread_num = min(
+            len(filelist), self.proto_desc.thread_num)

     def set_data_feed(self, datafeed):
         self.proto_desc.data_desc.CopyFrom(datafeed.proto_desc)

+    def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker=None):
+        pass
+
     def _desc(self):
         return text_format.MessageToString(self.proto_desc)
@@ -52,41 +59,20 @@ class MultiTrainer(TrainerDesc):
             raise ValueError('ValueError: DeviceWorker %s '
                              'is not supported in MultiTrainer' % worker)

+    def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker="Hogwild"):
+        super(MultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
+

 class DistMultiTrainer(TrainerDesc):
-    def __init__(self, dataset=None, worker='Downpour', fleet_desc=None):
+    def __init__(self):
         super(DistMultiTrainer, self).__init__()
-        if worker == "Downpour":
-            self.proto_desc.device_worker_name = worker + "Worker"
-            self.proto_desc.class_name = "DistMultiTrainer"
-            self.proto_desc.data_feed.CopyFrom(dataset)
-            downpour = self.proto_desc.downpour_param.add()
-            # sparse table should specify:
-            sparse_table = downpour.sparse_table.add()
-            sparse_table.table_id = \
-                fleet_desc.trainer_param.sparse_table.table_id
-            sparse_table.sparse_key_name.CopyFrom(fleet_desc.trainer_param()
-                                                  .sparse_table().slot_key())
-            sparse_table.sparse_value_name.CopyFrom(fleet_desc.trainer_param(
-            ).sparse_table().slot_value())
-            sparse_table.sparse_grad_name.CopyFrom(fleet_desc.trainer_param(
-            ).sparse_table().slot_gradient())
-            sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param.accessor.fea_dim - 2
-            sparse_table.fea_dim = downpour.emb_dim + 2
-            sparse_table.label_var_name = "click"
-            # dense table should specify:
-            dense_table = downpour.dense_table.add()
-            dense_table.table_id = \
-                fleet_desc.trainer_param.dense_table.table_id
-            # dense_value_name
-            dense_table.dense_value_name.CopyFrom(fleet_desc.trainer_param(
-            ).dense_table().dense_variable_name)
-            # dense_grad_name
-            dense_table.dense_grad_name.CopyFrom(fleet_desc.trainer_param(
-            ).dense_table().dense_gradient_name)
-            downpour.skipped_ops.extend(fleet_desc.trainer_param.skip_op)
-            print(str(self.proto_desc))
-        else:
-            raise ValueError('ValueError: DeviceWorker %s '
-                             'is not supported in DistMultiTrainer' % worker)
+        pass
+
+    def gen_trainer_desc(self, dataset=None, fleet_desc=None,
+                         worker="Downpour"):
+        super(DistMultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
+        self.proto_desc.class_name = "DistMultiTrainer"
+        self.proto_desc.data_desc.CopyFrom(dataset.proto_desc)
+        worker_builder = DeviceWorkerFactory()
+        device_worker = worker_builder.create_device_worker("Downpour")
+        device_worker.gen_worker_desc(self.proto_desc, fleet_desc)