Commit 855bf579 authored by D dongdaxiang

add dist_multi_trainer for distributed training, add trainer_factory and...

add dist_multi_trainer for distributed training, add trainer_factory and device_worker_factory so that we can easily extend new training mode, add pull dense worker which is a singleton for parameter fetching
Parent d4f63d82
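A minimal usage sketch of the factory path added by this commit, based only on the interfaces it introduces; the helper TrainWithDesc and the preparation of root_scope, main_program, and place are illustrative assumptions, not code from the commit:

#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"

void TrainWithDesc(const paddle::framework::TrainerDesc& desc,
                   paddle::framework::Scope* root_scope,
                   const paddle::framework::ProgramDesc& main_program,
                   const paddle::platform::Place& place) {
  // Resolve the trainer by name, e.g. "MultiTrainer" or "DistMultiTrainer".
  auto trainer =
      paddle::framework::TrainerFactory::CreateTrainer(desc.class_name());
  trainer->Initialize(desc);      // configure from the TrainerDesc
  trainer->SetScope(root_scope);  // model parameters live in the root scope
  trainer->SetDebug(false);
  // Bind workers to the place, root scope and program, then hook up readers.
  trainer->InitTrainerEnv(main_program, place);
  // DistMultiTrainer starts the PullDenseWorker singleton here.
  trainer->InitOtherEnv(main_program);
  trainer->Run();       // spawns one training thread per device worker
  trainer->Finalize();  // joins the threads
}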
...@@ -29,145 +29,31 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#endif
namespace paddle {
namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
    : root_scope_(scope), place_(place) {}
void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker, const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names, Scope* root_scope,
const int thread_index, const bool debug) {
worker->SetThreadId(thread_index);
worker->SetDebug(debug);
worker->SetRootScope(root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->SetParamConfig(&_param_config);
#endif
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
const int thread_num, const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist) {
readers.resize(thread_num);
for (size_t i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
readers[i]->Init(data_feed_desc); // set batch_size and queue_size here
}
readers[0]->SetFileList(filelist);
}
#ifdef PADDLE_WITH_PSLIB
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
  fleet_ptr_ = FleetWrapper::GetInstance();
  fleet_ptr_->InitServer(dist_desc, index);
}
void AsyncExecutor::InitWorker(const std::string& dist_desc,
                               const std::vector<uint64_t>& host_sign_list,
                               int node_num, int index) {
  fleet_ptr_ = FleetWrapper::GetInstance();
  fleet_ptr_->InitWorker(dist_desc, host_sign_list, node_num, index);
}
uint64_t AsyncExecutor::StartServer() { return fleet_ptr_->RunServer(); }
void AsyncExecutor::StopServer() { fleet_ptr_->StopServer(); }
void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
                                  int node_num) {
  fleet_ptr_->GatherServers(host_sign_list, node_num);
}
void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param_size();
++i) {
if (_pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.table_class()
.find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.accessor()
.fea_dim();
break;
}
}
_param_config.slot_dim = _param_config.fea_dim - 2;
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
++t) {
_param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
}
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
}
_param_config.slot_input_vec[table.table_id()] =
std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id());
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
std::vector<std::string> tmp_dense_variable_name;
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
tmp_dense_variable_name.push_back(table.dense_variable_name(i));
}
std::vector<std::string> tmp_dense_gradient_variable_name;
for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
tmp_dense_gradient_variable_name.push_back(
table.dense_gradient_variable_name(i));
}
_param_config.dense_variable_name[table.table_id()] =
std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim());
}
}
void AsyncExecutor::InitModel() {
...@@ -217,22 +103,6 @@ void AsyncExecutor::SaveModel(const std::string& path) {
}
}
void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;
param.threshold = 1;
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread =
std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
}
#endif
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                const std::string& data_feed_desc_str,
                                const std::vector<std::string>& filelist,
......
...@@ -67,7 +67,7 @@ class AsyncExecutor {
                   const int thread_num,
                   const std::vector<std::string>& fetch_names,
                   const std::string& mode, const bool debug = false);
#ifdef PADDLE_WITH_PSLIB
  void InitServer(const std::string& dist_desc, int index);
  void InitWorker(const std::string& dist_desc,
                  const std::vector<uint64_t>& host_sign_list, int node_num,
...@@ -77,8 +77,6 @@ class AsyncExecutor {
  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
  void InitModel();
  void SaveModel(const std::string& path);
void InitParamConfig();
#endif
 private:
  void CreateThreads(ExecutorThreadWorker* worker,
...@@ -87,21 +85,14 @@ class AsyncExecutor {
                     const std::vector<std::string>& fetch_var_names,
                     Scope* root_scope, const int thread_index,
                     const bool debug);
#ifdef PADDLE_WITH_PSLIB
void PrepareDenseThread(const std::string& mode);
#endif
 public:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
#ifdef PADDLE_WITH_PSLIB
  std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
  std::shared_ptr<DensePullThread> _pull_dense_thread;
  AsyncWorkerParamConfig _param_config;
#endif
  Scope* root_scope_;
  platform::Place place_;

 private:
  int actual_thread_num_;
};
}  // namespace framework
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
class PullDenseWorker {
public:
PullDenseWorker() {}
virtual ~PullDenseWorker() {}
virtual void Initialize(const TrainerDesc& param);
int Start();
void Stop();
void SetScope(Scope* scope) { root_scope_ = scope; }
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
static std::shared_ptr<PullDenseWorker> s_instance_;
static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker());
}
return s_instance_;
}
private:
void Run();
bool CheckUpdateParam(uint64_t table_id);
private:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_;
Scope* root_scope_;
bool running_;
std::map<uint64_t, uint64_t> last_versions_;
std::map<uint64_t, uint64_t> current_version_;
std::mutex mutex_for_version_;
std::map<uint64_t, std::vector<uint64_t>> training_versions_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::thread t_;
int thread_num_;
int sleep_time_ms_;
int threshold_;
std::vector<::std::future<int32_t>> pull_dense_status_;
uint32_t pull_dense_fail_times_ = 0;
std::vector<float> base_norm_param_;
std::vector<float> mean_;
std::vector<float> scale_;
float squared_sum_epsilon_ = 1e-4;
std::mutex mutex_for_mean_scale_;
float total_batch_num_ = 0;
};
// should incorporate different types of devices
class DeviceWorker {
public:
DeviceWorker() {}
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() = 0;
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
// will make this zero copy in the future
virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
protected:
Scope* root_scope_;
paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_;
};
class CPUWorkerBase : public DeviceWorker {
public:
CPUWorkerBase() {}
virtual ~CPUWorkerBase() {}
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
protected:
int thread_id_;
};
class HogwildWorker : public CPUWorkerBase {
public:
HogwildWorker() {}
virtual ~HogwildWorker() {}
virtual void Initialize(const TrainerDesc& desc) {}
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();
protected:
void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program);
std::shared_ptr<DataFeed> thread_reader_;
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_;
platform::Place place_;
};
class DownpourWorker : public HogwildWorker {
public:
DownpourWorker() {}
virtual ~DownpourWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
void FillSparseValue(size_t table_id);
void PushGradients();
void CollectLabelInfo(size_t table_id);
private:
DownpourWorkerParameter param_;
// just save the value in param_ for easy access
std::string label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
// feasign
std::map<uint64_t, std::vector<uint64_t>> features_;
// feasign stats
std::map<uint64_t, std::vector<float>> feature_labels_;
// feasign embedding
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
// feasign embedding gradient
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
// skipped ops
std::vector<std::string> skip_ops_;
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_;
};
} // namespace framework
} // namespace paddle
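A minimal sketch of how the PullDenseWorker singleton declared above is meant to be driven, following the calls that DistMultiTrainer and DownpourWorker make later in this commit; the wrapper function name and the concrete thread and table ids are illustrative:

#include "paddle/fluid/framework/device_worker.h"

void PullDenseLifecycleSketch(const paddle::framework::TrainerDesc& desc,
                              paddle::framework::Scope* root_scope) {
  auto pull_dense = paddle::framework::PullDenseWorker::GetInstance();
  pull_dense->Initialize(desc);  // reads pull_dense_param from the TrainerDesc
  pull_dense->SetScope(root_scope);
  pull_dense->Start();           // background thread pulls dense tables

  // Each training thread reports progress per dense table; once the slowest
  // thread has advanced by `threshold` versions, the table is pulled again.
  int thread_id = 0;      // illustrative thread index
  uint64_t table_id = 0;  // illustrative dense table id
  pull_dense->IncreaseThreadVersion(thread_id, table_id);

  pull_dense->Stop();  // joins the background thread
}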
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
std::string device_worker_types;
for (auto iter = g_device_worker_map.begin();
iter != g_device_worker_map.end(); ++iter) {
if (iter != g_device_worker_map.begin()) {
device_worker_types += ", ";
}
device_worker_types += iter->first;
}
return device_worker_types;
}
std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
std::string device_worker_class) {
if (g_device_worker_map.count(device_worker_class) < 1) {
exit(-1);
}
return g_device_worker_map[device_worker_class]();
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // namespace framework
} // namespace paddle
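A minimal sketch of extending the factory with a new training mode; the worker name MyCustomWorker is hypothetical, and since REGISTER_DEVICE_WORKER_CLASS and g_device_worker_map are local to this file, the registration is assumed to be added inside device_worker_factory.cc:

namespace paddle {
namespace framework {

// Hypothetical worker reusing HogwildWorker's scope/operator setup but with
// its own training loop.
class MyCustomWorker : public HogwildWorker {
 public:
  virtual ~MyCustomWorker() {}
  virtual void TrainFiles() {
    // custom per-thread training loop would go here
  }
};

REGISTER_DEVICE_WORKER_CLASS(MyCustomWorker);

}  // namespace framework
}  // namespace paddle

// A TrainerDesc whose device_worker_name is "MyCustomWorker" would then be
// resolved through DeviceWorkerFactory::CreateDeviceWorker("MyCustomWorker").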
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
readers_[i] =
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
workers_[i]->SetDeviceIndex(i);
readers_[i]->Init(trainer_desc.data_desc());
workers_[i]->SetDataFeed(readers_[i]);
}
std::vector<std::string> filelist_vec;
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
filelist_vec.push_back(trainer_desc.filelist(i));
}
fleet_ptr_ = FleetWrapper::GetInstance();
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetScope(root_scope_);
pull_dense_worker_->Start();
}
void DistMultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
pull_dense_worker_->Stop();
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
TableParameter table = param_.sparse_table(i);
sparse_key_names_[table_id].resize(table.sparse_key_name_size());
for (size_t j = 0; j < table.sparse_key_name_size(); ++j) {
sparse_key_names_[table_id][j] = table.sparse_key_name(j);
}
sparse_value_names_[table_id].resize(table.sparse_value_name_size());
for (size_t j = 0; j < table.sparse_value_name_size(); ++j) {
sparse_value_names_[table_id][j] = table.sparse_value_name(j);
}
sparse_grad_names_[table_id].resize(table.sparse_grad_name_size());
for (size_t j = 0; j < table.sparse_grad_name_size(); ++j) {
sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
}
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
dense_value_names_[table_id].resize(table.dense_value_name_size());
for (size_t j = 0; j < table.dense_value_name_size(); ++j) {
dense_value_names_[table_id][j] = table.dense_value_name(j);
}
dense_grad_names_[table_id].resize(table.dense_grad_name_size());
for (size_t j = 0; j < table.dense_grad_name_size(); ++j) {
dense_grad_names_[table_id][j] = table.dense_grad_name(j);
}
}
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
label_var_name_ = param_.label_var_name();
}
void DownpourWorker::CollectLabelInfo(size_t table_id) {
auto& feature = features_[table_id];
auto& feature_label = feature_labels_[table_id];
feature_label.resize(feature.size());
Variable* var = thread_scope_->FindVar(label_var_name_);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label_ptr = tensor->data<int64_t>();
int global_index = 0;
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1
for (auto ins_idx = 0u; ins_idx < tensor->lod()[0].size() - 1; ++ins_idx) {
for (; fea_idx < tensor->lod()[0][ins_idx]; ++fea_idx) {
// should be skipped feasign defined in protobuf
if (ids[fea_idx] == 0u) {
continue;
}
feature_label[global_index++] = static_cast<float>(label_ptr[ins_idx]);
}
}
}
CHECK(global_index == feature.size())
<< "expect fea info size:" << feature.size() << " real:" << global_index;
}
void DownpourWorker::FillSparseValue(size_t table_idx) {
auto table = param_.sparse_table(table_idx);
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
auto& fea_value = feature_values_[table_id];
auto fea_idx = 0u;
std::vector<float> init_value(table.emb_dim());
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
std::string slot_name = sparse_key_names_[table_id][i];
std::string emb_slot_name = sparse_value_names_[table_id][i];
Variable* var = thread_scope_->FindVar(slot_name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->mutable_data<float>({len, table.emb_dim()},
platform::CPUPlace());
memset(ptr, 0, sizeof(float) * len * table.emb_dim());
auto& tensor_lod = tensor->lod()[0];
LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod);
for (auto index = 0u; index < len; ++index) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2,
sizeof(float) * table.emb_dim());
fea_idx++;
}
}
}
void DownpourWorker::TrainFiles() {
platform::SetNumThreads(1);
thread_reader_->Start();
int batch_cnt = 0;
int cur_batch;
while ((cur_batch = thread_reader_->Next()) > 0) {
// pull sparse here
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
fleet_ptr_->PullSparseVarsSync(
*thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], param_.sparse_table(i).fea_dim());
CollectLabelInfo(i);
FillSparseValue(i);
}
// do computation here
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
// push gradients here
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid],
param_.sparse_table(i).emb_dim(), &feature_grads_[tid],
&push_sparse_status_);
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
}
push_dense_status_.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
}
push_sparse_status_.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
thread_scope_->DropKids();
++batch_cnt;
}
}
} // end namespace framework
} // end namespace paddle
...@@ -513,7 +513,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
  auto& push_g = _feature_push_value[table_id];
  check_pull_push_memory(features, &push_g, fea_dim);
  collect_feasign_info(table_id);
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
namespace paddle {
namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
void FleetWrapper::InitServer(const std::string& dist_desc, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_server(dist_desc, index);
is_initialized_ = true;
} else {
LOG(WARNING) << "Server can be initialized only once";
}
#endif
}
void FleetWrapper::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list,
int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc,
const_cast<uint64_t*>(host_sign_list.data()),
node_num, index);
is_initialized_ = true;
} else {
LOG(WARNING) << "Worker can be initialized only once";
}
#endif
}
void FleetWrapper::StopServer() {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->stop_server();
#endif
}
uint64_t FleetWrapper::RunServer() {
#ifdef PADDLE_WITH_PSLIB
return pslib_ptr_->run_server();
#else
return 0;
#endif
}
void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
node_num);
#endif
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
#ifdef PADDLE_WITH_PSLIB
std::vector<::std::future<int32_t>> pull_sparse_status;
pull_sparse_status.resize(0);
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
auto status = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
}
for (auto& t : pull_sparse_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
exit(-1);
}
}
#endif
}
void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
regions.reserve(var_names.size());
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status));
#endif
}
void FleetWrapper::PullDenseVarsSync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
regions.reserve(var_names.size());
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
status.wait();
#endif
}
void FleetWrapper::PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int count = tensor->numel();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, count);
regions.emplace_back(std::move(reg));
}
auto status = pslib_ptr_->_worker_ptr->push_dense(regions.data(),
regions.size(), table_id);
push_sparse_status->push_back(std::move(status));
#endif
}
void FleetWrapper::PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys, const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_key_names[i]);
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = scope.FindVar(sparse_key_names[i]);
CHECK(var != nullptr) << "var[" << sparse_key_names[i] << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>();
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += emb_dim;
continue;
}
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
g += emb_dim;
fea_idx++;
}
}
CHECK(fea_idx == fea_keys.size()) << "fea_idx: " << fea_idx
<< " features size: " << fea_keys.size();
std::vector<float*> push_g_vec;
for (auto i = 0u; i < fea_keys.size(); ++i) {
push_g_vec.push_back((*push_values)[i].data());
}
auto status = pslib_ptr_->_worker_ptr->push_sparse(
table_id, fea_keys.data(), (const float**)push_g_vec.data(),
fea_keys.size());
push_sparse_status->push_back(std::move(status));
#endif
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#endif
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
// A wrapper class for pslib.h. This class follows the singleton pattern,
// i.e. it is initialized only once in the current process.
// Example:
// std::shared_ptr<FleetWrapper> fleet_ptr =
// FleetWrapper::GetInstance();
// string dist_desc;
// fleet_ptr->InitServer(dist_desc, 0);
// interface design principles:
// Pull
// Sync: PullSparseVarsSync
// Async: PullSparseVarsAsync(not implemented currently)
// Push
// Sync: PushSparseVarsSync
// Async: PushSparseVarsAsync
// Push dense variables to server in Async mode
// Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status
class FleetWrapper {
public:
FleetWrapper() {}
virtual ~FleetWrapper() {}
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_dim);
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PullDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
void PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status);
// Push sparse variables with labels to server in Async mode
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad_names,
// fea_keys, fea_labels, sparse_grad_names
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
/*
void PushSparseVarsAsync(
const Scope& scope,
const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<std::string>& sparse_grad_names,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
*/
void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num,
int index);
void StopServer();
uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
static std::shared_ptr<FleetWrapper> s_instance_;
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
return s_instance_;
}
#ifdef PADDLE_WITH_PSLIB
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif
protected:
bool is_initialized_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
} // end namespace framework
} // end namespace paddle
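A minimal sketch expanding the usage comment in this header; the dist_desc string, host list, table id, and variable names are placeholders rather than values defined in this commit, and the server and worker calls would normally run on different nodes:

void FleetBootstrapSketch(const std::string& dist_desc,
                          const std::vector<uint64_t>& host_sign_list,
                          paddle::framework::Scope* scope) {
  auto fleet = paddle::framework::FleetWrapper::GetInstance();
  int node_num = static_cast<int>(host_sign_list.size());

  // On a server node:
  fleet->InitServer(dist_desc, /*index=*/0);
  uint64_t server_sign = fleet->RunServer();
  (void)server_sign;  // would be exchanged with the other nodes out of band

  // On a worker node:
  fleet->InitWorker(dist_desc, host_sign_list, node_num, /*index=*/0);
  fleet->GatherServers(host_sign_list, node_num);

  // Synchronously pull one dense table into variables of the given scope.
  std::vector<std::string> dense_vars = {"fc_0.w_0", "fc_0.b_0"};  // placeholders
  fleet->PullDenseVarsSync(*scope, /*table_id=*/0, dense_vars);

  fleet->StopServer();
}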
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace paddle {
namespace framework {
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
}
void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
CreateThreadScope(main_prog);
CreateThreadOperators(main_prog);
}
void HogwildWorker::TrainFilesWithProfiler() {
platform::SetNumThreads(1);
thread_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(ops_.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
int cur_batch;
int batch_cnt = 0;
timeline.Start();
while ((cur_batch = thread_reader_->Next()) > 0) {
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) {
timeline.Start();
ops_[i]->Run(*thread_scope_, place_);
timeline.Pause();
op_total_time[i] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
++batch_cnt;
thread_scope_->DropKids();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
/*
int fetch_var_num = fetch_var_names_.size();
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
}
*/
}
}
timeline.Start();
}
}
void HogwildWorker::TrainFiles() {
platform::SetNumThreads(1);
// how to accumulate fetched values here
thread_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
++batch_cnt;
thread_scope_->DropKids();
}
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here
workers_.resize(thread_num_);
readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
readers_[i] =
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
workers_[i]->SetDeviceIndex(i);
readers_[i]->Init(trainer_desc.data_desc());
workers_[i]->SetDataFeed(readers_[i]);
}
std::vector<std::string> filelist_vec;
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
filelist_vec.push_back(trainer_desc.filelist(i));
}
}
// call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) {
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetPlace(place);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->CreateDeviceResource(main_program); // Program
workers_[i]->BindingDataFeedMemory();
}
}
void MultiTrainer::Run() {
for (int thidx = 0; thidx < thread_num_; ++thidx) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
}
}
void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
void PullDenseWorker::Initialize(const TrainerDesc& param) {
running_ = false;
param_ = param.pull_dense_param();
threshold_ = param_.threshold();
thread_num_ = param_.device_num();
sleep_time_ms_ = param_.sleep_time_ms();
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
// setup dense variables for each table
int var_num = param_.dense_table(i).dense_value_name_size();
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
}
// setup training version for each table
training_versions_[tid].resize(thread_num_, 0);
last_versions_[tid] = 0;
current_version_[tid] = 0;
}
}
void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
for (auto& t : *status_vec) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(WARNING) << "Current Pull Dense Thread Failed Times"
<< ++pull_dense_fail_times_;
}
}
int MAX_FAIL_NUM = 20;
if (pull_dense_fail_times_ > MAX_FAIL_NUM) {
LOG(FATAL) << "Pull Dense Failed Times More Than " << MAX_FAIL_NUM
<< " Times";
exit(-1);
}
}
void PullDenseWorker::Stop() {
if (running_) {
running_ = false;
t_.join();
}
}
int PullDenseWorker::Start() {
running_ = true;
t_ = std::thread(&PullDenseWorker::Run, this);
return 0;
}
void PullDenseWorker::Run() {
while (running_) {
pull_dense_status_.resize(0);
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
if (CheckUpdateParam(tid)) {
fleet_ptr_->PullDenseVarsAsync(
*root_scope_, tid, dense_value_names_[tid], &pull_dense_status_);
ResetThreadVersion(tid);
}
}
if (pull_dense_status_.size() != 0) {
Wait(&pull_dense_status_);
}
usleep(sleep_time_ms_ * 1000);
}
}
void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
training_versions_[table_id][thread_id]++;
}
bool PullDenseWorker::CheckUpdateParam(uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
auto& version = training_versions_[table_id];
current_version_[table_id] =
*(std::min_element(version.begin(), version.end()));
if (current_version_[table_id] - last_versions_[table_id] < threshold_) {
return false;
}
return true;
}
void PullDenseWorker::ResetThreadVersion(uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
last_versions_[table_id] = current_version_[table_id];
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
namespace paddle {
namespace framework {
class TrainerBase {
public:
TrainerBase() {}
virtual ~TrainerBase() {}
// model memory is hosted in root_scope
void SetScope(Scope* root_scope);
void Initialize(const TrainerDesc& trainer_desc);
void SetDebug(const bool debug) { debug_ = debug; }
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
virtual void Run() = 0;
virtual void Finalize() = 0;
protected:
Scope* root_scope_;
bool debug_;
};
// general trainer for async execution
// local trainer and distributed trainer are supported
// depends on the assigned device_worker
class MultiTrainer : public TrainerBase {
public:
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
virtual void Run();
virtual void Finalize();
protected:
int thread_num_;
std::vector<std::thread> threads_;
std::vector<std::shared_ptr<DataFeed>> readers_;
std::vector<std::shared_ptr<DeviceWorker>> workers_;
};
class DistMultiTrainer : public MultiTrainer {
public:
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize();
protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "data_feed.proto";
package paddle.framework;
message TrainerDesc {
// class name used to create the trainer
// whether the trainer name and device worker name match
// is checked in the Python API
optional string class_name = 1;
// class name for creating device worker
optional string device_worker_name = 2;
// thread number
optional int32 thread_num = 3;
// whether we need to bind cpu
optional bool binding_cpu = 4 [ default = false ];
repeated string filelist = 5;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
optional PullDenseWorkerParameter pull_dense_param = 102;
// datafeed desc
optional DataFeedDesc data_desc = 201;
}
message HogwildWorkerParameter {}
message DownpourWorkerParameter {
repeated TableParameter sparse_table = 1;
repeated TableParameter dense_table = 2;
repeated string skip_ops = 3;
optional string label_var_name = 4;
}
message PullDenseWorkerParameter {
// dense table only and specialized usage
optional int32 threshold = 1 [ default = 1 ];
optional int32 device_num = 2;
optional int32 sleep_time_ms = 3 [ default = 2 ];
repeated TableParameter dense_table = 4;
}
message TableParameter {
// dense table only
optional int64 table_id = 1;
repeated string dense_value_name = 2;
repeated string dense_grad_name = 3;
repeated int32 dense_table_size = 4;
repeated int32 push_dense_wait_times = 5;
// sparse table only
repeated string sparse_key_name = 6;
repeated string sparse_value_name = 7;
repeated string sparse_grad_name = 8;
repeated int32 push_sparse_wait_times = 9;
// sparse table only and specialized usage
optional int32 emb_dim = 10;
optional int32 fea_dim = 11;
optional string label_var_name = 12;
}
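A minimal C++ sketch of filling a TrainerDesc that matches the schema above, using the generated protobuf setters; all concrete values (data feed name, file names, variable names, table id) are illustrative placeholders:

#include "paddle/fluid/framework/trainer_desc.pb.h"

paddle::framework::TrainerDesc MakeExampleTrainerDesc() {
  paddle::framework::TrainerDesc desc;
  desc.set_class_name("DistMultiTrainer");
  desc.set_device_worker_name("DownpourWorker");
  desc.set_thread_num(4);
  desc.add_filelist("part-00000");  // placeholder training file

  // The data feed name comes from data_feed.proto; the value here is assumed.
  desc.mutable_data_desc()->set_name("MultiSlotDataFeed");

  // PullDenseWorker configuration: one dense table with two variables.
  auto* pull_dense = desc.mutable_pull_dense_param();
  pull_dense->set_device_num(4);  // usually matches thread_num
  auto* dense_table = pull_dense->add_dense_table();
  dense_table->set_table_id(0);
  dense_table->add_dense_value_name("fc_0.w_0");  // placeholder variable names
  dense_table->add_dense_value_name("fc_0.b_0");
  return desc;
}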
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/trainer_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
trainerMap g_trainer_map;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std::string TrainerFactory::TrainerTypeList() {
std::string trainer_types;
for (auto iter = g_trainer_map.begin(); iter != g_trainer_map.end(); ++iter) {
if (iter != g_trainer_map.begin()) {
trainer_types += ", ";
}
trainer_types += iter->first;
}
return trainer_types;
}
std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
std::string trainer_class) {
if (g_trainer_map.count(trainer_class) < 1) {
exit(-1);
}
return g_trainer_map[trainer_class]();
}
REGISTER_TRAINER_CLASS(MultiTrainer);
REGISTER_TRAINER_CLASS(DistMultiTrainer);
} // namespace framework
} // namespace paddle
...@@ -27,6 +27,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
  if (var_type == proto::VarType::LOD_TENSOR) {
    var->GetMutable<LoDTensor>();
......
...@@ -18,5 +18,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
void InitializeVariable(Variable *var, proto::VarType::Type var_type);
}  // end namespace framework
}  // end namespace paddle