Commit c1650120 authored by: D dongdaxiang

refine device_worker and trainer code

test=develop
Parent 8a335b50
......@@ -203,7 +203,7 @@ if(WITH_PSLIB)
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table
trainer_desc_proto glog lod_rank_table fleet_wrapper
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer)
else()
......@@ -212,7 +212,7 @@ else()
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table
trainer_desc_proto glog lod_rank_table fleet_wrapper
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer)
endif(WITH_PSLIB)
......
......@@ -26,7 +26,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
......@@ -161,7 +163,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
// initialize trainer
trainer->Initialize(trainer_desc);
// trainer->SetRootScope(root_scope_);
trainer->SetScope(root_scope_);
trainer->SetDebug(debug);
// prepare training environment and helper environment
trainer->InitTrainerEnv(main_program, place_);
......
......@@ -75,14 +75,6 @@ class AsyncExecutor {
void InitModel();
void SaveModel(const std::string& path);
private:
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index,
const bool debug);
public:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
Scope* root_scope_;
......
......@@ -39,12 +39,11 @@ namespace framework {
class PullDenseWorker {
public:
PullDenseWorker() {}
virtual ~PullDenseWorker() {}
virtual void Initialize(const TrainerDesc& param);
int Start();
void Stop();
void SetScope(Scope* scope) { root_scope_ = scope; }
void SetRootScope(Scope* scope) { root_scope_ = scope; }
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
......@@ -57,6 +56,7 @@ class PullDenseWorker {
}
private:
PullDenseWorker() : root_scope_(NULL) {}
void Run();
bool CheckUpdateParam(uint64_t table_id);
......@@ -137,20 +137,18 @@ class HogwildWorker : public CPUWorkerBase {
protected:
void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program);
std::shared_ptr<DataFeed> thread_reader_;
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_;
platform::Place place_;
};
class DownpourWorker : public HogwildWorker {
public:
DownpourWorker() {}
virtual ~DownpourWorker() {}
virtual void Initilize(const TrainerDesc& desc);
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
protected:
......@@ -163,7 +161,7 @@ class DownpourWorker : public HogwildWorker {
private:
DownpourWorkerParameter param_;
// just save the value in param_ for easy access
std::string label_var_name_;
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
......
......@@ -19,25 +19,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
std::string device_worker_types;
......@@ -59,7 +40,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
return g_device_worker_map[device_worker_class]();
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class DeviceWorkerFactory {
public:
static std::string DeviceWorkerTypeList();
static std::shared_ptr<DeviceWorker> CreateDeviceWorker(
std::string device_worker_class);
};
} // namespace framework
} // namespace paddle
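For orientation, here is a minimal sketch of how the registration macro and factory declared above are meant to fit together now that registration lives next to each worker implementation; the helper MakeWorker and its translation unit are illustrative assumptions, not part of this commit.

// Minimal sketch, assuming this header is linked together with worker .cc
// files that each register themselves (as the worker files do in this commit).
#include <memory>
#include <string>

#include "paddle/fluid/framework/device_worker_factory.h"

namespace paddle {
namespace framework {

// In a worker implementation file, registration is a single line:
//   REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);

// Illustrative helper: resolve a worker by the class name carried in TrainerDesc.
std::shared_ptr<DeviceWorker> MakeWorker(const std::string& device_worker_name) {
  return DeviceWorkerFactory::CreateDeviceWorker(device_worker_name);
}

}  // namespace framework
}  // namespace paddle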
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace paddle {
namespace framework {
......@@ -34,6 +35,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
workers_[i]->SetDeviceIndex(i);
readers_[i]->Init(trainer_desc.data_desc());
workers_[i]->SetDataFeed(readers_[i]);
workers_[i]->Initialize(trainer_desc);
}
std::vector<std::string> filelist_vec;
......@@ -41,13 +43,15 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
filelist_vec.push_back(trainer_desc.filelist(i));
}
readers_[0]->SetFileList(filelist_vec);
fleet_ptr_ = FleetWrapper::GetInstance();
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetScope(root_scope_);
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start();
}
......@@ -58,5 +62,6 @@ void DistMultiTrainer::Finalize() {
pull_dense_worker_->Stop();
}
REGISTER_TRAINER_CLASS(DistMultiTrainer);
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace paddle {
namespace framework {
void DownpourWorker::Initilize(const TrainerDesc& desc) {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
......@@ -37,6 +37,7 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
for (size_t j = 0; j < table.sparse_grad_name_size(); ++j) {
sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
}
label_var_name_[table_id] = table.label_var_name();
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
......@@ -56,15 +57,18 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
label_var_name_ = param_.label_var_name();
skip_ops_.resize(param_.skip_ops_size());
}
void DownpourWorker::CollectLabelInfo(size_t table_id) {
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
auto table = param_.sparse_table(table_idx);
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
auto& feature = features_[table_id];
auto& feature_label = feature_labels_[table_id];
feature_label.resize(feature.size());
Variable* var = thread_scope_->FindVar(label_var_name_);
Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label_ptr = tensor->data<int64_t>();
......@@ -75,13 +79,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_id) {
int64_t* ids = tensor->data<int64_t>();
int fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1
for (auto ins_idx = 0u; ins_idx < tensor->lod()[0].size() - 1; ++ins_idx) {
for (; fea_idx < tensor->lod()[0][ins_idx]; ++fea_idx) {
for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
// should be skipped feasign defined in protobuf
if (ids[fea_idx] == 0u) {
continue;
}
feature_label[global_index++] = static_cast<float>(label_ptr[ins_idx]);
feature_label[global_index++] =
static_cast<float>(label_ptr[lod_idx - 1]);
}
}
}
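The rewritten loop above walks the LoD offsets directly: lod()[0] holds batch_size + 1 offsets, so instance k owns the feature ids in [lod[k], lod[k+1]) and its label is label_ptr[k], i.e. label_ptr[lod_idx - 1]. Below is a self-contained sketch of that indexing with plain vectors standing in for the tensors; the concrete numbers are made up.

// Self-contained sketch of the LoD-offset walk used in CollectLabelInfo.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // lod[0] has batch_size + 1 entries; instance k owns ids [lod[k], lod[k+1]).
  std::vector<size_t> lod = {0, 2, 5};           // 2 instances: 2 and 3 feasigns
  std::vector<int64_t> ids = {7, 0, 9, 11, 13};  // feasign 0 is skipped
  std::vector<int64_t> labels = {1, 0};          // one label per instance

  std::vector<float> feature_label;
  size_t fea_idx = 0;
  for (size_t lod_idx = 1; lod_idx < lod.size(); ++lod_idx) {
    for (; fea_idx < lod[lod_idx]; ++fea_idx) {
      if (ids[fea_idx] == 0) {  // padding feasign, no label collected
        continue;
      }
      feature_label.push_back(static_cast<float>(labels[lod_idx - 1]));
    }
  }
  std::cout << feature_label.size() << " labels collected\n";  // prints 4
  return 0;
}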
......@@ -128,10 +133,10 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
void DownpourWorker::TrainFiles() {
platform::SetNumThreads(1);
thread_reader_->Start();
device_reader_->Start();
int batch_cnt = 0;
int cur_batch;
while ((cur_batch = thread_reader_->Next()) > 0) {
while ((cur_batch = device_reader_->Next()) > 0) {
// pull sparse here
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
......@@ -144,8 +149,17 @@ void DownpourWorker::TrainFiles() {
// do computation here
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
// push gradients here
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
......@@ -198,10 +212,12 @@ void DownpourWorker::TrainFiles() {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
thread_scope_->DropKids();
++batch_cnt;
}
}
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // end namespace framework
} // end namespace paddle
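The skip_ops filter added to TrainFiles() above matches by substring on the operator type. A small standalone sketch of that check; the op type names here are made up.

// Self-contained sketch of the substring-based skip_ops filter.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> op_types = {"lookup_table", "mul", "push_sparse"};
  std::vector<std::string> skip_ops = {"push_sparse", "push_dense"};

  for (const auto& type : op_types) {
    bool need_skip = false;
    for (const auto& skip : skip_ops) {
      // Mirrors op->Type().find(skip_ops_[t]) != std::string::npos.
      if (type.find(skip) != std::string::npos) {
        need_skip = true;
        break;
      }
    }
    if (!need_skip) {
      std::cout << "run " << type << "\n";  // runs lookup_table and mul only
    }
  }
  return 0;
}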
cc_library(fleet_wrapper SRCS fleet_wrapper.cc)
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS pslib_brpc pslib)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -19,10 +33,16 @@ namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
#endif
void FleetWrapper::InitServer(const std::string& dist_desc, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
LOG(WARNING) << "Going to init server";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_server(dist_desc, index);
......@@ -38,6 +58,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
LOG(WARNING) << "Going to init server";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc,
......@@ -52,12 +73,14 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
void FleetWrapper::StopServer() {
#ifdef PADDLE_WITH_PSLIB
LOG(WARNING) << "Going to stop server";
pslib_ptr_->stop_server();
#endif
}
uint64_t FleetWrapper::RunServer() {
#ifdef PADDLE_WITH_PSLIB
LOG(WARNING) << "Going to run server";
return pslib_ptr_->run_server();
#else
return 0;
......@@ -67,6 +90,7 @@ uint64_t FleetWrapper::RunServer() {
void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
#ifdef PADDLE_WITH_PSLIB
LOG(WARNING) << "Going to gather server ips";
pslib_ptr_->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
node_num);
#endif
......@@ -122,13 +146,13 @@ void FleetWrapper::PullDenseVarsAsync(
std::vector<::std::future<int32_t>>* pull_dense_status) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
regions.reserve(var_names.size());
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
regions.resize(var_names.size());
for (auto i = 0u; i < var_names.size(); ++i) {
Variable* var = scope.FindVar(var_names[i]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
regions[i] = std::move(reg);
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
......@@ -186,7 +210,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
int offset = 2;
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_key_names[i]);
LOG(WARNING) << "sparse key names[" << i << "]: " << sparse_key_names[i];
LOG(WARNING) << "sparse grad names[" << i << "]: " << sparse_grad_names[i];
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
......@@ -201,16 +228,26 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
exit(-1);
}
int len = tensor->numel();
LOG(WARNING) << " tensor len: " << len;
int64_t* ids = tensor->data<int64_t>();
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += emb_dim;
continue;
}
LOG(WARNING) << "going to memcpy";
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
LOG(WARNING) << "show";
(*push_values)[fea_idx][0] = 1.0f;
LOG(WARNING) << "click";
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
LOG(WARNING) << "offset";
g += emb_dim;
fea_idx++;
}
......
......@@ -47,7 +47,6 @@ namespace framework {
class FleetWrapper {
public:
FleetWrapper() {}
virtual ~FleetWrapper() {}
// Pull sparse variables from server in Sync mode
......@@ -122,8 +121,11 @@ class FleetWrapper {
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif
private:
FleetWrapper() {}
protected:
bool is_initialized_;
static bool is_initialized_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
......
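FleetWrapper becomes a singleton here (private constructor, static is_initialized_, DISABLE_COPY_AND_ASSIGN), so callers go through GetInstance() as DistMultiTrainer does in this commit. A rough server-side sketch under that assumption; the header path and the surrounding helper are assumptions, and every call is a no-op (or returns 0) unless Paddle is built with PADDLE_WITH_PSLIB.

#include <cstdint>
#include <string>

#include "paddle/fluid/framework/fleet/fleet_wrapper.h"  // header path assumed

namespace paddle {
namespace framework {

// Illustrative helper: bring up the PS-lib server through the shared wrapper.
void RunParameterServer(const std::string& dist_desc, int index) {
  auto fleet = FleetWrapper::GetInstance();  // single shared instance
  fleet->InitServer(dist_desc, index);       // guarded by PADDLE_WITH_PSLIB
  uint64_t ip_sign = fleet->RunServer();     // returns 0 when PS-lib is absent
  (void)ip_sign;  // normally exchanged across nodes before GatherServers()
  // fleet->StopServer() shuts the server down at the end of training.
}

}  // namespace framework
}  // namespace paddle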
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace paddle {
......@@ -50,9 +51,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias();
device_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
device_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
......@@ -63,7 +64,7 @@ void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
void HogwildWorker::TrainFilesWithProfiler() {
platform::SetNumThreads(1);
thread_reader_->Start();
device_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
......@@ -79,7 +80,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
int cur_batch;
int batch_cnt = 0;
timeline.Start();
while ((cur_batch = thread_reader_->Next()) > 0) {
while ((cur_batch = device_reader_->Next()) > 0) {
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
......@@ -115,10 +116,10 @@ void HogwildWorker::TrainFiles() {
platform::SetNumThreads(1);
// how to accumulate fetched values here
thread_reader_->Start();
device_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
......@@ -128,5 +129,6 @@ void HogwildWorker::TrainFiles() {
}
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
} // end namespace framework
} // end namespace paddle
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace paddle {
namespace framework {
......@@ -65,5 +66,6 @@ void MultiTrainer::Finalize() {
}
}
REGISTER_TRAINER_CLASS(MultiTrainer);
} // end namespace framework
} // end namespace paddle
......@@ -20,24 +20,31 @@ namespace framework {
std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
void PullDenseWorker::Initialize(const TrainerDesc& param) {
LOG(WARNING) << "going to initialize pull dense worker";
running_ = false;
param_ = param.pull_dense_param();
threshold_ = param_.threshold();
thread_num_ = param_.device_num();
sleep_time_ms_ = param_.sleep_time_ms();
LOG(WARNING) << "dense table size: " << param_.dense_table_size();
LOG(WARNING) << "thread num: " << thread_num_;
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
// setup dense variables for each table
int var_num = param_.dense_table(i).dense_value_name_size();
LOG(WARNING) << "var num: " << var_num;
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
LOG(WARNING) << "dense value names " << j << " "
<< dense_value_names_[tid][j];
}
// setup training version for each table
training_versions_[tid].resize(thread_num_, 0);
last_versions_[tid] = 0;
current_version_[tid] = 0;
}
LOG(WARNING) << "initialize pull dense worker done.";
}
void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
......@@ -56,6 +63,7 @@ void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
<< " Times";
exit(-1);
}
status_vec->resize(0);
}
void PullDenseWorker::Stop() {
......@@ -90,7 +98,10 @@ void PullDenseWorker::Run() {
}
void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
LOG(WARNING) << "increase thread version input: " << thread_id << " table id "
<< table_id;
std::lock_guard<std::mutex> lock(mutex_for_version_);
LOG(WARNING) << "going to increase";
training_versions_[table_id][thread_id]++;
}
......
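PullDenseWorker keeps the same singleton pattern, and the trainer-side wiring added in dist_multi_trainer.cc reduces to the sequence sketched below. The helper name and standalone translation unit are illustrative; TrainerDesc and the root Scope come from the caller.

#include "paddle/fluid/framework/device_worker.h"  // declares PullDenseWorker

namespace paddle {
namespace framework {

// Illustrative helper mirroring DistMultiTrainer::Initialize / InitOtherEnv.
void StartDensePulling(const TrainerDesc& trainer_desc, Scope* root_scope) {
  auto worker = PullDenseWorker::GetInstance();  // shared across all threads
  worker->Initialize(trainer_desc);  // reads pull_dense_param: tables, device_num
  worker->SetRootScope(root_scope);  // dense parameters live in the root scope
  worker->Start();                   // keeps pulling until Stop() is called
}

}  // namespace framework
}  // namespace paddle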
......@@ -19,7 +19,5 @@ namespace framework {
void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }
void TrainerBase::Initialize(const TrainerDesc& trainer_desc) { return; }
} // end namespace framework
} // end namespace paddle
......@@ -39,8 +39,8 @@ class TrainerBase {
virtual ~TrainerBase() {}
// model memory are hosted in root_scope
void SetScope(Scope* root_scope);
void Initialize(const TrainerDesc& trainer_desc);
void SetDebug(const bool debug) { debug_ = debug; }
virtual void Initialize(const TrainerDesc& trainer_desc) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
......
......@@ -43,7 +43,6 @@ message DownpourWorkerParameter {
repeated TableParameter sparse_table = 1;
repeated TableParameter dense_table = 2;
repeated string skip_ops = 3;
optional string label_var_name = 4;
}
message PullDenseWorkerParameter {
......
......@@ -21,23 +21,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
trainerMap g_trainer_map;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std::string TrainerFactory::TrainerTypeList() {
std::string trainer_types;
......@@ -58,7 +41,5 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
return g_trainer_map[trainer_class]();
}
REGISTER_TRAINER_CLASS(MultiTrainer);
REGISTER_TRAINER_CLASS(DistMultiTrainer);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
trainerMap g_trainer_map;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class TrainerFactory {
public:
static std::string TrainerTypeList();
static std::shared_ptr<TrainerBase> CreateTrainer(std::string trainer_class);
};
} // namespace framework
} // namespace paddle
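With Initialize now pure virtual on TrainerBase and trainer registration moved next to each implementation, the bring-up path in async_executor.cc reduces to the steps sketched below; BuildTrainer is an illustrative name, and the includes are assumed to provide TrainerDesc, ProgramDesc, Scope and platform::Place.

#include <memory>

#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"

namespace paddle {
namespace framework {

// Illustrative helper mirroring AsyncExecutor::RunFromFile in this commit.
std::shared_ptr<TrainerBase> BuildTrainer(const TrainerDesc& trainer_desc,
                                          const ProgramDesc& main_program,
                                          const platform::Place& place,
                                          Scope* root_scope, bool debug) {
  // class_name() is "MultiTrainer" or "DistMultiTrainer", written by the
  // Python-side TrainerDesc.
  auto trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  trainer->Initialize(trainer_desc);  // per-trainer setup (readers, workers)
  trainer->SetScope(root_scope);      // model memory is hosted in root_scope
  trainer->SetDebug(debug);
  trainer->InitTrainerEnv(main_program, place);
  return trainer;
}

}  // namespace framework
}  // namespace paddle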
......@@ -110,15 +110,17 @@ class AsyncExecutor(object):
is_local = self.instance == None
trainer = None
if is_local:
trainer = MultiTrainer(data_feed=data_feed, worker="Hogwild")
trainer = MultiTrainer()
else:
trainer = DistMultiTrainer(
data_feed, worker="Downpour", fleet_desc=self.dist_desc)
# define a trainer and a device_worker here
trainer = DistMultiTrainer()
trainer.gen_trainer_desc(
dataset=data_feed, fleet_desc=self.dist_desc, worker="downpour")
trainer.set_thread(thread_num)
trainer.set_filelist(filelist)
trainer.set_data_feed(data_feed)
with open("trainer_desc.proto", "w") as fout:
fout.write(trainer._desc())
# define a trainer and a device_worker here
self.executor.run_from_files(program_desc, trainer._desc(), debug)
'''
......@@ -284,8 +286,9 @@ class AsyncExecutor(object):
raise ValueError(
'instance is None, please run config_distributed_nodes init instance'
)
self.init_desc = init_desc
self.executor.init_server(dist_desc, self.instance._rankid)
self.dist_desc_str = text_format.MessageToString(dist_desc)
self.dist_desc = dist_desc
self.executor.init_server(self.dist_desc_str, self.instance._rankid)
ip = self.executor.start_server()
self.instance.set_ip(ip)
self.instance.barrier_all() #wait all server start
......@@ -306,6 +309,7 @@ class AsyncExecutor(object):
'instance is None, please run config_distributed_nodes init instance'
)
self.dist_desc_str = text_format.MessageToString(dist_desc)
self.dist_desc = dist_desc
place = core.CPUPlace()
executor = Executor(place)
......@@ -313,7 +317,7 @@ class AsyncExecutor(object):
self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips,
self.executor.init_worker(self.dist_desc_str, ips,
self.instance.get_node_cnt(),
self.instance._rankid)
self.instance.barrier_all() #wait all worker start
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class DeviceWorker(object):
def __init__(self):
pass
def gen_worker_desc(self, trainer_desc, fleet_desc):
pass
class Hogwild(DeviceWorker):
def __init__(self):
super(Hogwild, self).__init__()
def gen_worker_desc(self, trainer_desc, fleet_desc):
trainer_desc.device_worker_name = "HogwildWorker"
class Downpour(DeviceWorker):
def __init__(self):
super(Downpour, self).__init__()
def gen_worker_desc(self, trainer_desc, fleet_desc):
trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(
fleet_desc.trainer_param.dense_table[0].dense_variable_name)
dense_table.table_id = \
fleet_desc.trainer_param.dense_table[0].table_id
downpour = trainer_desc.downpour_param
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
fleet_desc.trainer_param.sparse_table[0].table_id
sparse_table.sparse_key_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_key)
sparse_table.sparse_value_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_value)
sparse_table.sparse_grad_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
sparse_table.label_var_name = "click"
dense_table = downpour.dense_table.add()
dense_table.table_id = \
fleet_desc.trainer_param.dense_table[0].table_id
dense_table.dense_value_name.extend(
fleet_desc.trainer_param.dense_table[0].dense_variable_name)
dense_table.dense_grad_name.extend(fleet_desc.trainer_param.dense_table[
0].dense_gradient_variable_name)
downpour.skip_ops.extend(fleet_desc.trainer_param.skip_op)
class DeviceWorkerFactory(object):
def create_device_worker(self, worker_type):
classname = worker_type.capitalize()
print("------------")
print(classname)
return globals()[classname]()
......@@ -13,7 +13,8 @@
# limitations under the License.
from paddle.fluid.proto import trainer_desc_pb2
import ps_pb2 as pslib
from distributed import ps_pb2 as ps_pb2
from device_worker import DeviceWorkerFactory
from google.protobuf import text_format
__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
......@@ -28,16 +29,22 @@ class TrainerDesc(object):
text_format.Parse(f.read(), self.proto_desc)
'''
self.proto_desc = trainer_desc_pb2.TrainerDesc()
self.proto_desc.thread_num = 12
def set_thread(self, thread_num):
self.proto_desc.thread_num = thread_num
def set_filelist(self, filelist):
self.proto_desc.filelist.extend(filelist)
self.proto_desc.thread_num = min(
len(filelist), self.proto_desc.thread_num)
def set_data_feed(self, datafeed):
self.proto_desc.data_desc.CopyFrom(datafeed.proto_desc)
def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker=None):
pass
def _desc(self):
return text_format.MessageToString(self.proto_desc)
......@@ -52,41 +59,20 @@ class MultiTrainer(TrainerDesc):
raise ValueError('ValueError: DeviceWorker %s '
'is not supported in MultiTrainer' % worker)
def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker="Hogwild"):
super(MultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
class DistMultiTrainer(TrainerDesc):
def __init__(self, dataset=None, worker='Downpour', fleet_desc=None):
def __init__(self):
super(DistMultiTrainer, self).__init__()
if worker == "Downpour":
self.proto_desc.device_worker_name = worker + "Worker"
self.proto_desc.class_name = "DistMultiTrainer"
self.proto_desc.data_feed.CopyFrom(dataset)
downpour = self.proto_desc.downpour_param.add()
# sparse table should specify:
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
fleet_desc.trainer_param.sparse_table.table_id
sparse_table.sparse_key_name.CopyFrom(fleet_desc.trainer_param()
.sparse_table().slot_key())
sparse_table.sparse_value_name.CopyFrom(fleet_desc.trainer_param(
).sparse_table().slot_value())
sparse_table.sparse_grad_name.CopyFrom(fleet_desc.trainer_param(
).sparse_table().slot_gradient())
sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param.accessor.fea_dim - 2
sparse_table.fea_dim = downpour.emb_dim + 2
sparse_table.label_var_name = "click"
pass
# dense table should specify:
dense_table = downpour.dense_table.add()
dense_table.table_id = \
fleet_desc.trainer_param.dense_table.table_id
# dense_value_name
dense_table.dense_value_name.CopyFrom(fleet_desc.trainer_param(
).dense_table().dense_variable_name)
# dense_grad_name
dense_table.dense_grad_name.CopyFrom(fleet_desc.trainer_param(
).dense_table().dense_gradient_name)
downpour.skipped_ops.extend(fleet_desc.trainer_param.skip_op)
print(str(self.proto_desc))
else:
raise ValueError('ValueError: DeviceWorker %s '
'is not supported in DistMultiTrainer' % worker)
def gen_trainer_desc(self, dataset=None, fleet_desc=None,
worker="Downpour"):
super(DistMultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
self.proto_desc.class_name = "DistMultiTrainer"
self.proto_desc.data_desc.CopyFrom(dataset.proto_desc)
worker_builder = DeviceWorkerFactory()
device_worker = worker_builder.create_device_worker("Downpour")
device_worker.gen_worker_desc(self.proto_desc, fleet_desc)