Unverified commit 893ea7e0, authored by Thunderbrook, committed by GitHub

[cherry-pick] find lookup table in order & support dump param (#21347)

* support dump param of model into afs (#20302)

* support dump param to afs
test=develop

* code style
test=develop

* code style
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop

* find lookup table in order (#20932)

test=develop

* cherry-pick
test=develop

* fix pslib core dump in stop worker
test=develop

* print table stat info for pslib
test=develop
Parent 5dbe9e59
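Taken together, the user-facing pieces of this change are driven from the fleet strategy dict. A minimal sketch of how they fit (the strategy keys follow the diffs below and the import path matches this PaddlePaddle version, while the optimizer, loss, parameter names, and AFS path are illustrative placeholders):

    # Sketch only: keys mirror this commit's diffs; names and paths are hypothetical.
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

    strategy = {
        "no_cvm": True,                             # embedding-only sparse accessor
        "sparse_accessor_class": "DownpourSparseValueAccessor",
        "dump_param": ["fc_0.w_0", "fc_0.b_0"],     # hypothetical parameter names
        "dump_fields_path": "afs:/user/demo/dump",  # hypothetical AFS path
    }
    optimizer = fleet.distributed_optimizer(base_optimizer, strategy)  # base_optimizer assumed defined
    optimizer.minimize(loss)                        # loss assumed defined
    fleet.print_table_stat(0)                       # new in this commit; table 0 illustrative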
@@ -105,7 +105,10 @@ class PullDenseWorker {
 // should incorporate different type of device
 class DeviceWorker {
  public:
-  DeviceWorker() { use_cvm_ = false; }
+  DeviceWorker() {
+    no_cvm_ = true;
+    use_cvm_ = false;
+  }
   virtual ~DeviceWorker() {}
   virtual void Initialize(const TrainerDesc& desc) = 0;
   virtual void SetDeviceIndex(int tid) = 0;
@@ -135,6 +138,7 @@ class DeviceWorker {
   int64_t batch_num_;
   FetchConfig fetch_config_;
   bool use_cvm_;
+  bool no_cvm_;
 };

 class CPUWorkerBase : public DeviceWorker {
@@ -203,6 +207,8 @@ class DownpourWorker : public HogwildWorker {
   void CopyDenseVars();

  private:
+  bool need_dump_param_;
+  std::vector<std::string> dump_param_;
   bool need_to_push_dense_;
   bool need_dump_field_;
   bool dump_slot_;
......
@@ -75,6 +75,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
   fleet_ptr_ = FleetWrapper::GetInstance();
   fetch_config_ = desc.fetch_config();
   use_cvm_ = desc.use_cvm();
+  // for sparse value accessor, embedding only
+  no_cvm_ = desc.no_cvm();
   scale_datanorm_ = desc.scale_datanorm();
   dump_slot_ = desc.dump_slot();
   dump_fields_.resize(desc.dump_fields_size());
@@ -82,6 +84,14 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
     dump_fields_[i] = desc.dump_fields(i);
   }
   adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
+  need_dump_param_ = false;
+  dump_param_.resize(desc.dump_param_size());
+  for (int i = 0; i < desc.dump_param_size(); ++i) {
+    dump_param_[i] = desc.dump_param(i);
+  }
+  if (desc.dump_param_size() != 0) {
+    need_dump_param_ = true;
+  }
   for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
     check_nan_var_names_.push_back(desc.check_nan_var_names(i));
   }
@@ -186,7 +196,26 @@ bool CheckValidOutput(LoDTensor* tensor, int batch_size) {
   return true;
 }

+void DownpourWorker::DumpParam() {
+  std::string os;
+  for (auto& param : dump_param_) {
+    os.clear();
+    os = param;
+    Variable* var = thread_scope_->FindVar(param);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    int64_t len = tensor->numel();
+    os += PrintLodTensor(tensor, 0, len);
+    writer_ << os;
+  }
+}
+
 void DownpourWorker::CollectLabelInfo(size_t table_idx) {
+  if (no_cvm_) {
+    return;
+  }
   uint64_t table_id = static_cast<uint64_t>(
       param_.program_config(0).pull_sparse_table_id(table_idx));
@@ -288,7 +317,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
   int nid_ins_index = 0;
   for (int index = 0; index < len; ++index) {
-    if (use_cvm_) {
+    if (use_cvm_ || no_cvm_) {
       if (ids[index] == 0u) {
         memcpy(ptr + table.emb_dim() * index, init_value.data(),
                sizeof(float) * table.emb_dim());
@@ -657,7 +686,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
             *thread_scope_, tid, features_[tid], feature_labels_[tid],
             sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
             &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
-            dump_slot_, &sparse_push_keys_[tid]);
+            dump_slot_, &sparse_push_keys_[tid], no_cvm_);
         timeline.Pause();
         push_sparse_time += timeline.ElapsedSec();
         total_time += timeline.ElapsedSec();
@@ -882,7 +911,7 @@ void DownpourWorker::TrainFiles() {
             *thread_scope_, tid, features_[tid], feature_labels_[tid],
             sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
             &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
-            dump_slot_, &sparse_push_keys_[tid]);
+            dump_slot_, &sparse_push_keys_[tid], no_cvm_);
       }
     }
@@ -977,6 +1006,9 @@ void DownpourWorker::TrainFiles() {
       }
       writer_ << ars[i];
     }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam();
+    }
   }
   PrintFetchVars();
......
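DumpParam above writes one record per listed parameter: the variable name followed by the flattened tensor values rendered by PrintLodTensor. A rough mock of the record shape (PrintLodTensor is not part of this diff, so the ":"-separated value formatting here is an assumption):

    # Hypothetical mock of one dumped record; the separator is assumed, not confirmed.
    def dump_param_record(name, flat_values):
        return name + "".join(":" + str(v) for v in flat_values)

    print(dump_param_record("fc_0.w_0", [0.12, -0.03]))  # fc_0.w_0:0.12:-0.03

Note that TrainFiles() calls DumpParam() on thread 0 only, so each parameter is dumped once rather than once per worker thread.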
@@ -91,6 +91,13 @@ void FleetWrapper::StopServer() {
 #endif
 }

+void FleetWrapper::FinalizeWorker() {
+#ifdef PADDLE_WITH_PSLIB
+  VLOG(3) << "Going to finalize worker";
+  pslib_ptr_->finalize_worker();
+#endif
+}
+
 uint64_t FleetWrapper::RunServer() {
 #ifdef PADDLE_WITH_PSLIB
   VLOG(3) << "Going to run server";
@@ -303,7 +310,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
     std::vector<std::vector<float>>* push_values,
     std::vector<::std::future<int32_t>>* push_sparse_status,
     const int batch_size, const bool use_cvm, const bool dump_slot,
-    std::vector<uint64_t>* sparse_push_keys) {
+    std::vector<uint64_t>* sparse_push_keys, const bool no_cvm) {
 #ifdef PADDLE_WITH_PSLIB
   int offset = 2;
   int slot_offset = 0;
@@ -314,6 +321,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
     offset = 0;
     grad_dim = emb_dim - 2;
   }
+  if (no_cvm) {
+    offset = 0;
+    grad_dim = emb_dim;
+  }
   if (dump_slot) {
     slot_offset = 1;
     show_index = 1;
@@ -370,12 +381,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
     }
     sparse_push_keys->push_back(ids[id_idx]);
     CHECK(fea_idx < (*push_values).size());
-    CHECK(fea_idx < fea_labels.size());
-    if (use_cvm) {
+    if (use_cvm || no_cvm) {
       memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
              sizeof(float) * emb_dim);
     } else {
+      CHECK(fea_idx < fea_labels.size());
       memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
              sizeof(float) * emb_dim);
       (*push_values)[fea_idx][show_index] = 1.0f;
@@ -549,6 +560,19 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
 #endif
 }

+void FleetWrapper::PrintTableStat(const uint64_t table_id) {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id);
+  ret.wait();
+  int32_t err_code = ret.get();
+  if (err_code == -1) {
+    LOG(ERROR) << "print table stat failed";
+  }
+#else
+  VLOG(0) << "FleetWrapper::PrintTableStat does nothing when no pslib";
+#endif
+}
+
 double FleetWrapper::GetCacheThreshold(int table_id) {
 #ifdef PADDLE_WITH_PSLIB
   double cache_threshold = 0.0;
......
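The no_cvm branch changes how each push record is laid out: by default the first two floats of a record are reserved for show/click, use_cvm assumes the gradient already carries them, and no_cvm pushes the bare embedding gradient with nothing reserved. A small sketch mirroring the offset bookkeeping in PushSparseVarsWithLabelAsync (Python for brevity):

    # Mirrors the offset / grad_dim / slot_offset branches in this diff.
    def push_offsets(emb_dim, use_cvm, no_cvm, dump_slot):
        offset, grad_dim = 2, emb_dim          # default: reserve show/click slots
        if use_cvm:
            offset, grad_dim = 0, emb_dim - 2  # gradient already carries cvm
        if no_cvm:
            offset, grad_dim = 0, emb_dim      # embedding only, nothing reserved
        slot_offset = 1 if dump_slot else 0    # leading slot id when dumping slots
        return offset, slot_offset, grad_dim

    print(push_offsets(8, use_cvm=False, no_cvm=True, dump_slot=False))  # (0, 0, 8)
    print(push_offsets(8, use_cvm=False, no_cvm=False, dump_slot=True))  # (2, 1, 8)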
@@ -124,7 +124,7 @@ class FleetWrapper {
       std::vector<std::vector<float>>* push_values,
       std::vector<::std::future<int32_t>>* push_sparse_status,
       const int batch_size, const bool use_cvm, const bool dump_slot,
-      std::vector<uint64_t>* sparse_push_keys);
+      std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
   // Push sparse variables to server in Async mode
   // Param<In>: scope, table_id, fea_keys, sparse_grad_names
@@ -147,6 +147,8 @@ class FleetWrapper {
                 int index);
   // stop server
   void StopServer();
+  // finalize worker so the worker can be stopped
+  void FinalizeWorker();
   // run server
   uint64_t RunServer();
   // gather server ip
@@ -165,6 +167,8 @@ class FleetWrapper {
                      std::string model_path, std::string model_proto_file,
                      std::vector<std::string> table_var_list,
                      bool load_combine);
+  void PrintTableStat(const uint64_t table_id);
+
   // mode = 0, load all feature
   // mode = 1, load delta feature, which means load diff
   void LoadModel(const std::string& path, const int mode);
......
@@ -105,7 +105,6 @@ class DistMultiTrainer : public MultiTrainer {
   bool need_dump_field_;
   std::string dump_fields_path_;
   std::string dump_converter_;
-  std::vector<std::string> dump_fields_;
   int mpi_rank_;
   int mpi_size_;
   int dump_file_num_;
......
@@ -40,12 +40,14 @@ message TrainerDesc {
   repeated string dump_fields = 13;
   optional string dump_converter = 14;
   repeated string dump_param = 15;
   optional int32 mpi_size = 16 [ default = -1 ];
   optional int32 dump_file_num = 17 [ default = 16 ];
   repeated string check_nan_var_names = 18;
   optional CopyTableConfig copy_table_config = 19;
   // adjust ins weight
   optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
+  optional bool no_cvm = 21 [ default = false ];

   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
......
@@ -55,6 +55,7 @@ void BindFleetWrapper(py::module* m) {
       .def("load_model", &framework::FleetWrapper::LoadModel)
       .def("clear_model", &framework::FleetWrapper::ClearModel)
       .def("stop_server", &framework::FleetWrapper::StopServer)
+      .def("finalize_worker", &framework::FleetWrapper::FinalizeWorker)
       .def("gather_servers", &framework::FleetWrapper::GatherServers)
       .def("gather_clients", &framework::FleetWrapper::GatherClients)
       .def("get_clients_info", &framework::FleetWrapper::GetClientsInfo)
@@ -62,6 +63,7 @@ void BindFleetWrapper(py::module* m) {
            &framework::FleetWrapper::CreateClient2ClientConnection)
       .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable)
       .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable)
+      .def("print_table_stat", &framework::FleetWrapper::PrintTableStat)
       .def("client_flush", &framework::FleetWrapper::ClientFlush)
       .def("load_from_paddle_model",
            &framework::FleetWrapper::LoadFromPaddleModel)
......
@@ -160,7 +160,8 @@ class DownpourSGD(DeviceWorker):
                     .sparse_table[i].slot_value)
             sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
                 i].slot_gradient)
-            if opt_info["use_cvm"]:
+            if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[
+                    "no_cvm"] == True:
                 sparse_table.emb_dim = \
                     self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                         i].accessor.fea_dim
......
@@ -182,6 +182,10 @@ class PSLib(Fleet):
         destroyed when stop() is called.
         """
         self._role_maker._barrier_worker()
+        # all workers should be finalized first
+        if self._role_maker.is_worker():
+            self._fleet_ptr.finalize_worker()
+        self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
             self._fleet_ptr.stop_server()
         self._role_maker._barrier_worker()
@@ -234,6 +238,25 @@ class PSLib(Fleet):
         """
         self._fleet_ptr.save_model(dirname)

+    def print_table_stat(self, table_id):
+        """
+        Print stat info of table_id: table id, feasign size, mf size.
+
+        Args:
+            table_id(int): the id of the table
+
+        Example:
+            .. code-block:: python
+
+              fleet.print_table_stat(0)
+
+        """
+        self._role_maker._barrier_worker()
+        if self._role_maker.is_first_worker():
+            self._fleet_ptr.print_table_stat(table_id)
+        self._role_maker._barrier_worker()
+
     def save_persistables(self, executor, dirname, main_program=None, **kwargs):
         """
         save persistable parameters,
......
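With finalize_worker exposed, shutdown becomes two-phase: every worker finalizes its pslib client behind a barrier before the first worker stops the server, which is what fixes the core dump mentioned in the commit message. A hedged usage sketch (assuming the modified method is PSLib.stop(), per its docstring; table id 0 is illustrative):

    fleet.print_table_stat(0)  # first worker logs table id, feasign size, mf size
    fleet.stop()               # barrier -> finalize_worker on every worker
                               # -> barrier -> stop_server on the first worker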
@@ -80,7 +80,8 @@ class DownpourServer(Server):
             'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \
             'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
             'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
-            'sparse_cache_file_num']
+            'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
+            'sparse_ada_epsilon', 'sparse_optimizer']

         for key in strategy:
             if key not in support_sparse_key_list:
@@ -108,9 +109,13 @@ class DownpourServer(Server):
             table.compress_in_save = strategy.get('sparse_compress_in_save',
                                                   True)
             table.shard_num = strategy.get('sparse_shard_num', 1000)
+            # DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
+            # DownpourCtrAccessor:          for ctr task, has cvm, slot, embedding and sgd info
+            # DownpourSparseValueAccessor:  for general task, has embedding and sgd info
             support_accessor_class = [
-                'DownpourFeatureValueAccessor', 'DownpourCtrAccessor'
+                'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
+                'DownpourSparseValueAccessor'
             ]
             if strategy.get('sparse_accessor_class') is not None:
                 accessor_class = strategy.get('sparse_accessor_class')
@@ -169,6 +174,69 @@ class DownpourServer(Server):
                 table1.converter = converter
                 table1.deconverter = deconverter

+                table2 = table.accessor.table_accessor_save_param.add()
+                table2.param = 2
+                table2.converter = converter
+                table2.deconverter = deconverter
+            elif accessor_class == 'DownpourSparseValueAccessor':
+                optimizer_name = strategy.get("sparse_optimizer", "adam")
+                table.accessor.sparse_commonsgd_param.name = optimizer_name
+                table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
+                table.accessor.fea_dim = int(table.accessor.embedx_dim)
+                if optimizer_name == "naive":
+                    table.accessor.sparse_commonsgd_param.naive.learning_rate = \
+                        strategy.get('sparse_learning_rate', 0.05)
+                    table.accessor.sparse_commonsgd_param.naive.initial_range = \
+                        strategy.get('sparse_initial_range', 1e-4)
+                    if strategy.get('sparse_weight_bounds') is None:
+                        table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
+                            [-10, 10])
+                    else:
+                        table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
+                            strategy.get('sparse_weight_bounds'))
+                elif optimizer_name == "adagrad":
+                    table.accessor.sparse_commonsgd_param.adagrad.learning_rate = \
+                        strategy.get('sparse_learning_rate', 0.05)
+                    table.accessor.sparse_commonsgd_param.adagrad.initial_range = \
+                        strategy.get('sparse_initial_range', 1e-4)
+                    table.accessor.sparse_commonsgd_param.adagrad.initial_g2sum = strategy.get(
+                        'sparse_initial_g2sum', 3)
+                    if strategy.get('sparse_weight_bounds') is None:
+                        table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
+                            [-10, 10])
+                    else:
+                        table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
+                            strategy.get('sparse_weight_bounds'))
+                elif optimizer_name == "adam":
+                    table.accessor.sparse_commonsgd_param.adam.learning_rate = \
+                        strategy.get('sparse_learning_rate', 0.001)
+                    table.accessor.sparse_commonsgd_param.adam.initial_range = \
+                        strategy.get('sparse_initial_range', 1e-4)
+                    table.accessor.sparse_commonsgd_param.adam.beta1_decay_rate = strategy.get(
+                        'sparse_beta1_decay_rate', 0.9)
+                    table.accessor.sparse_commonsgd_param.adam.beta2_decay_rate = strategy.get(
+                        'sparse_beta2_decay_rate', 0.999)
+                    table.accessor.sparse_commonsgd_param.adam.ada_epsilon = strategy.get(
+                        'sparse_ada_epsilon', 1e-8)
+                    if strategy.get('sparse_weight_bounds') is None:
+                        table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
+                            [-10, 10])
+                    else:
+                        table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
+                            strategy.get('sparse_weight_bounds'))
+                converter = strategy.get(
+                    'sparse_converter',
+                    "(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)")
+                deconverter = strategy.get(
+                    'sparse_deconverter',
+                    "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
+                )
+
+                table1 = table.accessor.table_accessor_save_param.add()
+                table1.param = 1
+                table1.converter = converter
+                table1.deconverter = deconverter
+
                 table2 = table.accessor.table_accessor_save_param.add()
                 table2.param = 2
                 table2.converter = converter
......
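The new DownpourSparseValueAccessor branch is configured entirely through strategy keys. A sketch of a strategy selecting it with the adam defaults shown above (all values illustrative; the keys come from support_sparse_key_list):

    strategy = {
        "sparse_accessor_class": "DownpourSparseValueAccessor",
        "sparse_optimizer": "adam",       # or "naive" / "adagrad"
        "sparse_embedx_dim": 8,           # fea_dim is derived from this
        "sparse_learning_rate": 0.001,
        "sparse_beta1_decay_rate": 0.9,
        "sparse_beta2_decay_rate": 0.999,
        "sparse_ada_epsilon": 1e-8,
        "sparse_weight_bounds": [-10, 10],
    }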
@@ -130,13 +130,22 @@ class DistributedAdam(DistributedOptimizerImplBase):
         find multi-sparse-table
         """
         table_names = set()
+        cnt = 0
+        tmp_list = []
+        ret_list = []
         for loss in losses:
             for op in loss.block.program.global_block().ops:
                 if op.type == "lookup_table":
                     if op.attr('is_distributed') is True:
                         table_name = op.input("W")[0]
-                        table_names.add(table_name)
-        return list(table_names)
+                        if table_name not in table_names:
+                            table_names.add(table_name)
+                            tmp_list.append([table_name, cnt])
+                            cnt += 1
+        tmp_list.sort(key=lambda k: k[1])
+        for x in tmp_list:
+            ret_list.append(x[0])
+        return ret_list

     def _minimize(self,
                   losses,
@@ -366,6 +375,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["fleet_desc"] = ps_param
         opt_info["worker_skipped_ops"] = worker_skipped_ops
         opt_info["use_cvm"] = strategy.get("use_cvm", False)
+        opt_info["no_cvm"] = strategy.get("no_cvm", False)
         opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
         opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
         opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
@@ -375,6 +385,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["dump_fields"] = strategy.get("dump_fields", [])
         opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
         opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
+        opt_info["dump_param"] = strategy.get("dump_param", [])
         if server._server.downpour_server_param.downpour_table_param[
                 0].accessor.accessor_class == "DownpourCtrAccessor":
             opt_info["dump_slot"] = True
......
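The rewritten _find_multi_distributed_lookup_table keeps first-seen program order instead of returning list(set(...)), whose iteration order is unspecified and can differ between worker processes (string hashing is randomized per process). A minimal illustration of the difference:

    names_in_program = ["emb_b", "emb_a", "emb_b"]  # lookup_table ops in program order

    seen, ordered = set(), []
    for name in names_in_program:
        if name not in seen:          # dedupe, but keep first-seen order
            seen.add(name)
            ordered.append(name)
    print(ordered)                    # always ['emb_b', 'emb_a']
    print(list({"emb_b", "emb_a"}))   # order unspecified; may differ per process

Because every worker now derives the same ordered table list, sparse table ids stay consistent across the fleet.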
@@ -76,6 +76,9 @@ class TrainerDesc(object):
     def _set_use_cvm(self, use_cvm=False):
         self.proto_desc.use_cvm = use_cvm

+    def _set_no_cvm(self, no_cvm=False):
+        self.proto_desc.no_cvm = no_cvm
+
     def _set_scale_datanorm(self, scale_datanorm=-1):
         self.proto_desc.scale_datanorm = scale_datanorm

@@ -101,6 +104,10 @@ class TrainerDesc(object):
     def _set_dump_converter(self, converter):
         self.proto_desc.dump_converter = converter

+    def _set_dump_param(self, dump_param):
+        for param in dump_param:
+            self.proto_desc.dump_param.append(param)
+
     def _set_check_nan_var_names(self, check_nan_var_names):
         for var in check_nan_var_names:
             self.proto_desc.check_nan_var_names.append(var)
......
@@ -52,6 +52,8 @@ class TrainerFactory(object):
                 trainer._set_fleet_desc(opt_info["fleet_desc"])
                 if opt_info.get("use_cvm") is not None:
                     trainer._set_use_cvm(opt_info["use_cvm"])
+                if opt_info.get("no_cvm") is not None:
+                    trainer._set_no_cvm(opt_info["no_cvm"])
                 if opt_info.get("scale_datanorm") is not None:
                     trainer._set_scale_datanorm(opt_info["scale_datanorm"])
                 if opt_info.get("dump_slot") is not None:
......