未验证 提交 9e045170 编写于 作者: X xujiaqi01 提交者: GitHub

add copy table (#21086)

* copy some feasigns and corresponding embeddings from one sparse table to another
* copy all feasigns and corresponding embeddings from one sparse table to another
* copy all dense params from one table to another
* copy some local vars to other local vars
上级 aeb88791
......@@ -21,6 +21,9 @@ limitations under the License. */
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <unordered_map> // NOLINT
#include <unordered_set> // NOLINT
#include <utility> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
......@@ -195,6 +198,9 @@ class DownpourWorker : public HogwildWorker {
void CollectLabelInfo(size_t table_id);
void AdjustInsWeight();
void DumpParam();
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();
private:
bool need_dump_param_;
......@@ -237,6 +243,12 @@ class DownpourWorker : public HogwildWorker {
std::vector<float> nid_show_;
// check nan and inf during training
std::vector<std::string> check_nan_var_names_;
// copy table
CopyTableConfig copy_table_config_;
std::map<uint64_t, uint64_t> table_dependency_;
std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
......@@ -93,6 +93,29 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
check_nan_var_names_.push_back(desc.check_nan_var_names(i));
}
copy_table_config_ = desc.copy_table_config();
for (int i = 0; i < copy_table_config_.src_sparse_tables_size(); ++i) {
uint64_t src_table = copy_table_config_.src_sparse_tables(i);
uint64_t dest_table = copy_table_config_.dest_sparse_tables(i);
VLOG(3) << "copy_sparse_tables_ push back " << src_table << "->"
<< dest_table;
copy_sparse_tables_.push_back(std::make_pair(src_table, dest_table));
}
for (int i = 0; i < copy_table_config_.src_dense_tables_size(); ++i) {
uint64_t src_table = copy_table_config_.src_dense_tables(i);
uint64_t dest_table = copy_table_config_.dest_dense_tables(i);
VLOG(3) << "copy_dense_tables_ push back " << src_table << "->"
<< dest_table;
copy_dense_tables_.push_back(std::make_pair(src_table, dest_table));
}
for (auto& m : copy_table_config_.table_denpendency_map()) {
if (sparse_key_names_.find(m.key()) != sparse_key_names_.end()) {
// currently only support one dependency
for (auto& value : m.values()) {
table_dependency_[m.key()] = value;
}
}
}
}
void DownpourWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
......@@ -404,6 +427,102 @@ void DownpourWorker::AdjustInsWeight() {
#endif
}
void DownpourWorker::CopySparseTable() {
for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
int64_t src_table = copy_sparse_tables_[i].first;
int64_t dest_table = copy_sparse_tables_[i].second;
int32_t feanum = 0;
if (src_table == dest_table) {
continue;
} else if (!copy_table_config_.sparse_copy_by_feasign()) {
if (feasign_set_.find(src_table) == feasign_set_.end()) {
continue;
} else if (feasign_set_[src_table].size() == 0) {
continue;
}
feanum = fleet_ptr_->CopyTable(src_table, dest_table);
} else {
std::vector<uint64_t> fea_vec(feasign_set_[src_table].begin(),
feasign_set_[src_table].end());
feanum = fleet_ptr_->CopyTableByFeasign(src_table, dest_table, fea_vec);
fea_vec.clear();
std::vector<uint64_t>().swap(fea_vec);
}
VLOG(3) << "copy feasign from table " << src_table << " to table "
<< dest_table << ", feasign num=" << feanum;
feasign_set_[src_table].clear();
std::unordered_set<uint64_t>().swap(feasign_set_[src_table]);
}
feasign_set_.clear();
}
void DownpourWorker::CopyDenseTable() {
if (thread_id_ != 0) {
return;
}
thread_local std::vector<std::future<int32_t>> pull_dense_status;
for (size_t i = 0; i < copy_dense_tables_.size(); ++i) {
uint64_t src_table = copy_dense_tables_[i].first;
uint64_t dest_table = copy_dense_tables_[i].second;
if (src_table == dest_table) {
continue;
}
int32_t dim = fleet_ptr_->CopyTable(src_table, dest_table);
VLOG(3) << "copy param from table " << src_table << " to table "
<< dest_table << ", dim=" << dim;
if (copy_table_config_.dense_pull_after_copy()) {
VLOG(3) << "dense pull after copy, table=" << dest_table;
pull_dense_status.resize(0);
fleet_ptr_->PullDenseVarsAsync(*root_scope_, dest_table,
dense_value_names_[dest_table],
&pull_dense_status);
for (auto& t : pull_dense_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(WARNING) << "pull dense after copy table failed,"
<< " table=" << dest_table;
}
}
}
}
}
void DownpourWorker::CopyDenseVars() {
if (thread_id_ != 0) {
return;
}
for (int i = 0; i < copy_table_config_.src_var_list_size(); ++i) {
auto& src_var_name = copy_table_config_.src_var_list(i);
auto& dest_var_name = copy_table_config_.dest_var_list(i);
if (src_var_name == dest_var_name) {
continue;
}
VLOG(3) << "copy dense var from " << src_var_name << " to "
<< dest_var_name;
Variable* src_var = thread_scope_->FindVar(src_var_name);
CHECK(src_var != nullptr) << src_var_name << " not found"; // NOLINT
LoDTensor* src_tensor = src_var->GetMutable<LoDTensor>();
CHECK(src_tensor != nullptr) << src_var_name
<< " tensor is null"; // NOLINT
float* src_data = src_tensor->data<float>();
Variable* dest_var = thread_scope_->FindVar(dest_var_name);
CHECK(dest_var != nullptr) << dest_var_name << " not found"; // NOLINT
LoDTensor* dest_tensor = dest_var->GetMutable<LoDTensor>();
CHECK(dest_tensor != nullptr) << dest_var_name
<< " tensor is null"; // NOLINT
float* dest_data = dest_tensor->data<float>();
CHECK(src_tensor->numel() == dest_tensor->numel())
<< "tensor numel not equal," << src_tensor->numel() << " vs "
<< dest_tensor->numel();
for (int i = 0; i < src_tensor->numel(); i++) {
dest_data[i] = src_data[i];
}
}
}
void DownpourWorker::TrainFilesWithProfiler() {
VLOG(3) << "Begin to train files with profiler";
platform::SetNumThreads(1);
......@@ -437,6 +556,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
double fill_sparse_time = 0.0;
double push_sparse_time = 0.0;
double push_dense_time = 0.0;
double copy_table_time = 0.0;
int cur_batch;
int batch_cnt = 0;
uint64_t total_inst = 0;
......@@ -445,6 +565,27 @@ void DownpourWorker::TrainFilesWithProfiler() {
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
timeline.Start();
if (copy_table_config_.need_copy()) {
VLOG(3) << "copy_sparse_tables_.size " << copy_sparse_tables_.size();
if (copy_table_config_.sparse_copy_by_feasign()) {
for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
uint64_t tid = copy_sparse_tables_[i].first;
feasign_set_[tid].insert(sparse_push_keys_[tid].begin(),
sparse_push_keys_[tid].end());
}
}
if (batch_cnt % copy_table_config_.batch_num() == 0) {
CopySparseTable();
CopyDenseTable();
CopyDenseVars();
}
}
timeline.Pause();
copy_table_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
VLOG(3) << "program config size: " << param_.program_config_size();
for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
......@@ -641,6 +782,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
collect_label_time / batch_cnt);
fprintf(stderr, "adjust ins weight time: %fs\n",
adjust_ins_weight_time / batch_cnt);
fprintf(stderr, "copy table time: %fs\n", copy_table_time / batch_cnt);
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "op run percent: %f\n", op_sum_time / total_time * 100);
......@@ -648,6 +790,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
pull_sparse_time / total_time * 100);
fprintf(stderr, "adjust ins weight time percent: %f\n",
adjust_ins_weight_time / total_time * 100);
fprintf(stderr, "copy table time percent: %f\n",
copy_table_time / total_time * 100);
fprintf(stderr, "collect label time percent: %f\n",
collect_label_time / total_time * 100);
fprintf(stderr, "fill sparse time percent: %f\n",
......@@ -661,6 +805,11 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
timeline.Start();
}
if (copy_table_config_.need_copy()) {
CopySparseTable();
CopyDenseTable();
CopyDenseVars();
}
}
void DownpourWorker::TrainFiles() {
......@@ -670,6 +819,20 @@ void DownpourWorker::TrainFiles() {
int batch_cnt = 0;
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
if (copy_table_config_.need_copy()) {
if (copy_table_config_.sparse_copy_by_feasign()) {
for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
uint64_t tid = copy_sparse_tables_[i].first;
feasign_set_[tid].insert(sparse_push_keys_[tid].begin(),
sparse_push_keys_[tid].end());
}
}
if (batch_cnt % copy_table_config_.batch_num() == 0) {
CopySparseTable();
CopyDenseTable();
CopyDenseVars();
}
}
// pull sparse here
for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
......@@ -850,6 +1013,11 @@ void DownpourWorker::TrainFiles() {
if (need_dump_field_) {
writer_.Flush();
}
if (copy_table_config_.need_copy()) {
CopySparseTable();
CopyDenseTable();
CopyDenseVars();
}
}
} // end namespace framework
......
......@@ -40,28 +40,6 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
template <class AR>
paddle::ps::Archive<AR>& operator<<(paddle::ps::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
ar << ins.GetOffset();
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
return ar;
}
template <class AR>
paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
ar >> ins.MutableOffset();
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
return ar;
}
#endif
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
#endif
......@@ -729,40 +707,6 @@ std::future<int32_t> FleetWrapper::SendClientToClientMsg(
return std::future<int32_t>();
}
template <typename T>
void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
for (size_t i = 0; i < t.size(); ++i) {
ar << *(t[i]);
}
*str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
#endif
}
template <typename T>
void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
if (str.length() == 0) {
return;
}
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
if (ar.cursor() == ar.finish()) {
return;
}
while (ar.cursor() < ar.finish()) {
t->push_back(ar.get<T>());
}
CHECK(ar.cursor() == ar.finish());
VLOG(3) << "Deserialize size " << t->size();
#else
VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
......@@ -781,10 +725,43 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
return r.engine;
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
std::vector<std::vector<MultiSlotType>>*, const std::string&);
int32_t FleetWrapper::CopyTable(const uint64_t src_table_id,
const uint64_t dest_table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->copy_table(src_table_id, dest_table_id);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "copy table failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::CopyTable does nothing when no pslib";
return 0;
#endif
}
int32_t FleetWrapper::CopyTableByFeasign(
const uint64_t src_table_id, const uint64_t dest_table_id,
const std::vector<uint64_t>& feasign_list) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->copy_table_by_feasign(
src_table_id, dest_table_id, feasign_list.data(), feasign_list.size());
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "copy table by feasign failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::CopyTableByFeasign does nothing when no pslib";
return 0;
#endif
}
} // end namespace framework
} // end namespace paddle
......@@ -67,11 +67,12 @@ class FleetWrapper {
client2client_max_retry_ = 3;
}
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Pull sparse variables from server in sync mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
......@@ -80,19 +81,24 @@ class FleetWrapper {
int fea_dim,
const std::vector<std::string>& var_emb_names);
// pull dense variables from server in sync mod
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// pull dense variables from server in async mod
// Param<in>: scope, table_id, var_names
// Param<out>: pull_dense_status
void PullDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
// push dense parameters(not gradients) to server in sync mode
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<in>: scope, table_id, var_names, scale_datanorm, batch_size
// Param<out>: push_sparse_status
void PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
......@@ -100,13 +106,14 @@ class FleetWrapper {
std::vector<::std::future<int32_t>>* push_sparse_status,
float scale_datanorm, int batch_size);
// push dense variables to server in sync mode
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push sparse variables with labels to server in Async mode
// Push sparse variables with labels to server in async mode
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad_names,
// fea_keys, fea_labels, sparse_grad_names
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
// sparse_grad_names, batch_size, use_cvm, dump_slot
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
......@@ -132,12 +139,17 @@ class FleetWrapper {
std::vector<::std::future<int32_t>>* push_sparse_status);
*/
// init server
void InitServer(const std::string& dist_desc, int index);
// init trainer
void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num,
int index);
// stop server
void StopServer();
// run server
uint64_t RunServer();
// gather server ip
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
// gather client ip
void GatherClients(const std::vector<uint64_t>& host_sign_list);
......@@ -145,7 +157,6 @@ class FleetWrapper {
std::vector<uint64_t> GetClientsInfo();
// create client to client connection
void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// load from paddle model
......@@ -164,37 +175,42 @@ class FleetWrapper {
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
// get save cache threshold
double GetCacheThreshold();
// shuffle cache model between servers
void CacheShuffle(int table_id, const std::string& path, const int mode,
const double cache_threshold);
// save cache model
// cache model can speed up online predict
int32_t SaveCache(int table_id, const std::string& path, const int mode);
// copy feasign key/value from src_table_id to dest_table_id
int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
// copy feasign key/value from src_table_id to dest_table_id
int32_t CopyTableByFeasign(const uint64_t src_table_id,
const uint64_t dest_table_id,
const std::vector<uint64_t>& feasign_list);
// clear all models, release their memory
void ClearModel();
// shrink sparse table
void ShrinkSparseTable(int table_id);
// shrink dense table
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay,
int emb_dim);
// register client to client communication
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
// register client to client communication
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
template <typename T>
void Serialize(const std::vector<T*>& t, std::string* str);
template <typename T>
void Deserialize(std::vector<T>* t, const std::string& str);
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
return s_instance_;
}
// this performs better than rand_r, especially large data
std::default_random_engine& LocalRandomEngine();
......
......@@ -40,10 +40,12 @@ message TrainerDesc {
repeated string dump_fields = 13;
optional string dump_converter = 14;
repeated string dump_param = 15;
optional int32 mpi_size = 16 [ default = -1 ];
optional int32 dump_file_num = 17 [ default = 16 ];
repeated string check_nan_var_names = 18;
optional CopyTableConfig copy_table_config = 19;
// adjust ins weight
optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......@@ -52,8 +54,6 @@ message TrainerDesc {
optional SectionWorkerParameter section_param = 104;
// datafeed desc
optional DataFeedDesc data_desc = 201;
// adjust ins weight
optional AdjustInsWeightConfig adjust_ins_weight_config = 301;
}
message HogwildWorkerParameter { repeated string skip_ops = 1; }
......@@ -108,6 +108,29 @@ message AdjustInsWeightConfig {
optional string ins_weight_slot = 5 [ default = "" ];
}
message TableDependencyMap {
required int32 key = 1;
repeated int32 values = 2;
}
message CopyTableConfig {
optional bool need_copy = 1 [ default = false ];
optional int32 batch_num = 2 [ default = 100 ];
repeated int32 src_sparse_tables = 3;
repeated int32 dest_sparse_tables = 4;
repeated int32 src_dense_tables = 5;
repeated int32 dest_dense_tables = 6;
repeated string src_var_list = 7;
repeated string dest_var_list = 8;
// when dest dense table has no grad, should pull explicitly
optional bool dense_pull_after_copy = 9 [ default = false ];
// copy feasigns or copy the whole table
optional bool sparse_copy_by_feasign = 10 [ default = true ];
// table dependency for pull/push
optional bool enable_dependency = 11 [ default = false ];
repeated TableDependencyMap table_denpendency_map = 12;
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
......
......@@ -67,7 +67,10 @@ void BindFleetWrapper(py::module* m) {
&framework::FleetWrapper::LoadFromPaddleModel)
.def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable)
.def("set_client2client_config",
&framework::FleetWrapper::SetClient2ClientConfig);
&framework::FleetWrapper::SetClient2ClientConfig)
.def("copy_table", &framework::FleetWrapper::CopyTable)
.def("copy_table_by_feasign",
&framework::FleetWrapper::CopyTableByFeasign);
} // end FleetWrapper
} // end namespace pybind
} // end namespace paddle
......@@ -343,7 +343,7 @@ class DownpourWorker(Worker):
target_table = None
for table in self._worker.sparse_table:
if table.table_id == table_id:
keys = self._worker.sparse_table[table_id].slot_key
keys = table.slot_key
key_names = [var.name for var in sorted_slot_key_vars]
for key_name in key_names:
if key_name not in keys:
......
......@@ -372,6 +372,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True
opt_info["adjust_ins_weight"] = strategy.get("adjust_ins_weight", {})
opt_info["copy_table"] = strategy.get("copy_table", {})
for loss in losses:
loss.block.program._fleet_opt = opt_info
......
......@@ -820,8 +820,9 @@ class FleetUtil(object):
"""
fleet._role_maker._barrier_worker()
if fleet._role_maker.is_first_worker():
tables = fleet._dist_desc.trainer_param.dense_table
prog_id = str(id(program))
tables = fleet._opt_info["program_id_to_worker"][prog_id].\
get_desc().dense_table
prog_conf = fleet._opt_info['program_configs'][prog_id]
prog_tables = {}
for key in prog_conf:
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defination of trainers."""
import sys
from os import path
......@@ -120,6 +121,78 @@ class TrainerDesc(object):
self.proto_desc.adjust_ins_weight_config.ins_weight_slot = \
config_dict.get("ins_weight_slot", "")
def _set_copy_table_config(self, config_dict):
config = self.proto_desc.copy_table_config
config.need_copy = config_dict.get("need_copy", False)
config.batch_num = config_dict.get("batch_num", 100)
src_sparse_tables = config_dict.get("src_sparse_tables", [])
if not isinstance(src_sparse_tables, list):
src_sparse_tables = [src_sparse_tables]
dest_sparse_tables = config_dict.get("dest_sparse_tables", [])
if not isinstance(dest_sparse_tables, list):
dest_sparse_tables = [dest_sparse_tables]
if len(src_sparse_tables) != len(dest_sparse_tables):
raise ValueError(
"len(src_sparse_tables) != len(dest_sparse_tables)," \
" %s vs %s" % (len(src_sparse_tables), \
len(dest_sparse_tables)))
for i in src_sparse_tables:
config.src_sparse_tables.append(i)
for i in dest_sparse_tables:
config.dest_sparse_tables.append(i)
src_dense_tables = config_dict.get("src_dense_tables", [])
if not isinstance(src_dense_tables, list):
src_dense_tables = [src_dense_tables]
dest_dense_tables = config_dict.get("dest_dense_tables", [])
if not isinstance(dest_dense_tables, list):
dest_dense_tables = [dest_dense_tables]
if len(src_dense_tables) != len(dest_dense_tables):
raise ValueError(
"len(src_dense_tables) != len(dest_dense_tables)," \
" %s vs %s" % (len(src_dense_tables), \
len(dest_dense_tables)))
for i in src_dense_tables:
config.src_dense_tables.append(i)
for i in dest_dense_tables:
config.dest_dense_tables.append(i)
# user can also specify dense variables to copy,
# instead of copy dense table
src_var_list = config_dict.get("src_var_list", [])
if not isinstance(src_var_list, list):
src_var_list = [src_var_list]
dest_var_list = config_dict.get("dest_var_list", [])
if not isinstance(dest_var_list, list):
dest_var_list = [dest_var_list]
if len(src_var_list) != len(dest_var_list):
raise ValueError(
"len(src_var_list) != len(dest_var_list), %s vs" \
" %s" % (len(src_var_list), len(dest_var_list)))
for i in src_var_list:
config.src_var_list.append(i)
for i in dest_var_list:
config.dest_var_list.append(i)
dependency_map = config_dict.get("dependency_map", {})
for key in dependency_map:
m = config.table_denpendency_map.add()
m.key = key
values = dependency_map[key]
if not isinstance(values, list):
values = [values]
if len(values) != 1:
raise ValueError("dependency len %s != 1" % len(values))
for value in values:
m.values.append(value)
config.dense_pull_after_copy = \
config_dict.get("dense_pull_after_copy", True)
config.enable_dependency = \
config_dict.get("enable_dependency", False)
config.sparse_copy_by_feasign = \
config_dict.get("sparse_copy_by_feasign", True)
def _desc(self):
from google.protobuf import text_format
return self.proto_desc.SerializeToString()
......@@ -151,6 +224,11 @@ class MultiTrainer(TrainerDesc):
class DistMultiTrainer(TrainerDesc):
"""
Implement of DistMultiTrainer.
It's for Distributed training.
"""
def __init__(self):
super(DistMultiTrainer, self).__init__()
pass
......@@ -170,6 +248,11 @@ class DistMultiTrainer(TrainerDesc):
class PipelineTrainer(TrainerDesc):
"""
Implement of PipelineTrainer.
It's for Pipeline.
"""
def __init__(self):
super(PipelineTrainer, self).__init__()
pass
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defination of TrainerFactory."""
import threading
import time
......@@ -24,6 +25,12 @@ __all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]
class TrainerFactory(object):
"""
Create trainer and device worker.
If opt_info is not None, it will get configs from opt_info,
otherwise create MultiTrainer and Hogwild.
"""
def __init__(self):
pass
......@@ -43,24 +50,44 @@ class TrainerFactory(object):
if "fleet_desc" in opt_info:
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None:
trainer._set_dump_slot(opt_info["dump_slot"])
if opt_info.get("mpi_rank") is not None:
trainer._set_mpi_rank(opt_info["mpi_rank"])
if opt_info.get("mpi_size") is not None:
trainer._set_mpi_size(opt_info["mpi_size"])
if opt_info.get("dump_fields") is not None:
trainer._set_dump_fields(opt_info["dump_fields"])
if opt_info.get("dump_fields_path") is not None:
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
if opt_info.get("dump_file_num") is not None:
trainer._set_dump_file_num(opt_info["dump_file_num"])
if opt_info.get("dump_converter") is not None:
trainer._set_dump_converter(opt_info["dump_converter"])
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_dump_param(opt_info["dump_param"])
if opt_info.get("adjust_ins_weight") is not None:
trainer._set_adjust_ins_weight(opt_info[
"adjust_ins_weight"])
if opt_info.get("copy_table") is not None:
trainer._set_copy_table_config(opt_info["copy_table"])
if opt_info.get("check_nan_var_names") is not None:
trainer._set_check_nan_var_names(opt_info[
"check_nan_var_names"])
if opt_info.get("dump_param") is not None:
trainer._set_dump_param(opt_info["dump_param"])
trainer._set_device_worker(device_worker)
return trainer
class FetchHandlerMonitor(object):
"""
Defination of FetchHandlerMonitor class,
it's for fetch handler.
"""
def __init__(self, scope, handler):
self.fetch_instance = handler
self.fetch_thread = threading.Thread(
......@@ -69,11 +96,21 @@ class FetchHandlerMonitor(object):
self.running = False
def start(self):
"""
start monitor,
it will start a monitor thread.
"""
self.running = True
self.fetch_thread.setDaemon(True)
self.fetch_thread.start()
def handler_decorator(self, fetch_scope, fetch_handler):
"""
decorator of handler,
Args:
fetch_scope(Scope): fetch scope
fetch_handler(Handler): fetch handler
"""
fetch_target_names = self.fetch_instance.fetch_target_names
period_secs = self.fetch_instance.period_secs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册