未验证 提交 349e82d6 编写于 作者: T Thunderbrook 提交者: GitHub

support general embedding params (#21217)

* general table

* add sparse table
test=develop

* no cvm
test=develop

* add no_cvm
test=develop

* add note
test=develop

* code style
test=develop

* code style
test=develop

* code style
test=develop

* code style
test=develop

* code style
test=develop

* add key of optimizer
test=develop
上级 3cb6c0a0
...@@ -105,7 +105,10 @@ class PullDenseWorker { ...@@ -105,7 +105,10 @@ class PullDenseWorker {
// should incorporate different type of device // should incorporate different type of device
class DeviceWorker { class DeviceWorker {
public: public:
DeviceWorker() { use_cvm_ = false; } DeviceWorker() {
no_cvm_ = true;
use_cvm_ = false;
}
virtual ~DeviceWorker() {} virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0; virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0; virtual void SetDeviceIndex(int tid) = 0;
...@@ -135,6 +138,7 @@ class DeviceWorker { ...@@ -135,6 +138,7 @@ class DeviceWorker {
int64_t batch_num_; int64_t batch_num_;
FetchConfig fetch_config_; FetchConfig fetch_config_;
bool use_cvm_; bool use_cvm_;
bool no_cvm_;
}; };
class CPUWorkerBase : public DeviceWorker { class CPUWorkerBase : public DeviceWorker {
......
...@@ -75,6 +75,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -75,6 +75,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance(); fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
// for sparse value accessor, embedding only
no_cvm_ = desc.no_cvm();
scale_datanorm_ = desc.scale_datanorm(); scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot(); dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size()); dump_fields_.resize(desc.dump_fields_size());
...@@ -211,6 +213,9 @@ void DownpourWorker::DumpParam() { ...@@ -211,6 +213,9 @@ void DownpourWorker::DumpParam() {
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
if (no_cvm_) {
return;
}
uint64_t table_id = static_cast<uint64_t>( uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx)); param_.program_config(0).pull_sparse_table_id(table_idx));
...@@ -312,7 +317,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -312,7 +317,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
int nid_ins_index = 0; int nid_ins_index = 0;
for (int index = 0; index < len; ++index) { for (int index = 0; index < len; ++index) {
if (use_cvm_) { if (use_cvm_ || no_cvm_) {
if (ids[index] == 0u) { if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(), memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
...@@ -681,7 +686,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -681,7 +686,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_, &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_, &sparse_push_keys_[tid]); dump_slot_, &sparse_push_keys_[tid], no_cvm_);
timeline.Pause(); timeline.Pause();
push_sparse_time += timeline.ElapsedSec(); push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
...@@ -906,7 +911,7 @@ void DownpourWorker::TrainFiles() { ...@@ -906,7 +911,7 @@ void DownpourWorker::TrainFiles() {
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_, &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_, &sparse_push_keys_[tid]); dump_slot_, &sparse_push_keys_[tid], no_cvm_);
} }
} }
......
...@@ -303,7 +303,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -303,7 +303,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot, const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys) { std::vector<uint64_t>* sparse_push_keys, const bool no_cvm) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
int offset = 2; int offset = 2;
int slot_offset = 0; int slot_offset = 0;
...@@ -314,6 +314,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -314,6 +314,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
offset = 0; offset = 0;
grad_dim = emb_dim - 2; grad_dim = emb_dim - 2;
} }
if (no_cvm) {
offset = 0;
grad_dim = emb_dim;
}
if (dump_slot) { if (dump_slot) {
slot_offset = 1; slot_offset = 1;
show_index = 1; show_index = 1;
...@@ -370,12 +374,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -370,12 +374,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
} }
sparse_push_keys->push_back(ids[id_idx]); sparse_push_keys->push_back(ids[id_idx]);
CHECK(fea_idx < (*push_values).size()); CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
if (use_cvm) { if (use_cvm || no_cvm) {
memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g, memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim); sizeof(float) * emb_dim);
} else { } else {
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g, memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim); sizeof(float) * emb_dim);
(*push_values)[fea_idx][show_index] = 1.0f; (*push_values)[fea_idx][show_index] = 1.0f;
......
...@@ -124,7 +124,7 @@ class FleetWrapper { ...@@ -124,7 +124,7 @@ class FleetWrapper {
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot, const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys); std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in Async mode // Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names // Param<In>: scope, table_id, fea_keys, sparse_grad_names
......
...@@ -46,6 +46,7 @@ message TrainerDesc { ...@@ -46,6 +46,7 @@ message TrainerDesc {
optional CopyTableConfig copy_table_config = 19; optional CopyTableConfig copy_table_config = 19;
// adjust ins weight // adjust ins weight
optional AdjustInsWeightConfig adjust_ins_weight_config = 20; optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
optional bool no_cvm = 21 [ default = false ];
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
......
...@@ -160,7 +160,8 @@ class DownpourSGD(DeviceWorker): ...@@ -160,7 +160,8 @@ class DownpourSGD(DeviceWorker):
.sparse_table[i].slot_value) .sparse_table[i].slot_value)
sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[ sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
i].slot_gradient) i].slot_gradient)
if opt_info["use_cvm"]: if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[
"no_cvm"] == True:
sparse_table.emb_dim = \ sparse_table.emb_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[ self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
i].accessor.fea_dim i].accessor.fea_dim
......
...@@ -80,7 +80,8 @@ class DownpourServer(Server): ...@@ -80,7 +80,8 @@ class DownpourServer(Server):
'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \ 'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \
'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \ 'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \ 'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
'sparse_cache_file_num'] 'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
'sparse_ada_epsilon', 'sparse_optimizer']
for key in strategy: for key in strategy:
if key not in support_sparse_key_list: if key not in support_sparse_key_list:
...@@ -108,9 +109,13 @@ class DownpourServer(Server): ...@@ -108,9 +109,13 @@ class DownpourServer(Server):
table.compress_in_save = strategy.get('sparse_compress_in_save', table.compress_in_save = strategy.get('sparse_compress_in_save',
True) True)
table.shard_num = strategy.get('sparse_shard_num', 1000) table.shard_num = strategy.get('sparse_shard_num', 1000)
# DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
# DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info
# DownpourSparseValueAccessor : for general task, has embedding and sgd info
support_accessor_class = [ support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor' 'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor'
] ]
if strategy.get('sparse_accessor_class') is not None: if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class') accessor_class = strategy.get('sparse_accessor_class')
...@@ -169,6 +174,69 @@ class DownpourServer(Server): ...@@ -169,6 +174,69 @@ class DownpourServer(Server):
table1.converter = converter table1.converter = converter
table1.deconverter = deconverter table1.deconverter = deconverter
table2 = table.accessor.table_accessor_save_param.add()
table2.param = 2
table2.converter = converter
table2.deconverter = deconverter
elif accessor_class == 'DownpourSparseValueAccessor':
optimizer_name = strategy.get("sparse_optimizer", "adam")
table.accessor.sparse_commonsgd_param.name = optimizer_name
table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
table.accessor.fea_dim = int(table.accessor.embedx_dim)
if optimizer_name == "naive":
table.accessor.sparse_commonsgd_param.naive.learning_rate = \
strategy.get('sparse_learning_rate', 0.05)
table.accessor.sparse_commonsgd_param.naive.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
elif optimizer_name == "adagrad":
table.accessor.sparse_commonsgd_param.adagrad.learning_rate = \
strategy.get('sparse_learning_rate', 0.05)
table.accessor.sparse_commonsgd_param.adagrad.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
table.accessor.sparse_commonsgd_param.adagrad.initial_g2sum = strategy.get(
'sparse_initial_g2sum', 3)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
elif optimizer_name == "adam":
table.accessor.sparse_commonsgd_param.adam.learning_rate = \
strategy.get('sparse_learning_rate', 0.001)
table.accessor.sparse_commonsgd_param.adam.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
table.accessor.sparse_commonsgd_param.adam.beta1_decay_rate = strategy.get(
'sparse_beta1_decay_rate', 0.9)
table.accessor.sparse_commonsgd_param.adam.beta2_decay_rate = strategy.get(
'sparse_beta2_decay_rate', 0.999)
table.accessor.sparse_commonsgd_param.adam.ada_epsilon = strategy.get(
'sparse_ada_epsilon', 1e-8)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
converter = strategy.get(
'sparse_converter',
"(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)")
deconverter = strategy.get(
'sparse_deconverter',
"(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
)
table1 = table.accessor.table_accessor_save_param.add()
table1.param = 1
table1.converter = converter
table1.deconverter = deconverter
table2 = table.accessor.table_accessor_save_param.add() table2 = table.accessor.table_accessor_save_param.add()
table2.param = 2 table2.param = 2
table2.converter = converter table2.converter = converter
......
...@@ -367,6 +367,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -367,6 +367,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["fleet_desc"] = ps_param opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops opt_info["worker_skipped_ops"] = worker_skipped_ops
opt_info["use_cvm"] = strategy.get("use_cvm", False) opt_info["use_cvm"] = strategy.get("use_cvm", False)
opt_info["no_cvm"] = strategy.get("no_cvm", False)
opt_info["stat_var_names"] = strategy.get("stat_var_names", []) opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1) opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names", opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
......
...@@ -76,6 +76,9 @@ class TrainerDesc(object): ...@@ -76,6 +76,9 @@ class TrainerDesc(object):
def _set_use_cvm(self, use_cvm=False): def _set_use_cvm(self, use_cvm=False):
self.proto_desc.use_cvm = use_cvm self.proto_desc.use_cvm = use_cvm
def _set_no_cvm(self, no_cvm=False):
self.proto_desc.no_cvm = no_cvm
def _set_scale_datanorm(self, scale_datanorm=-1): def _set_scale_datanorm(self, scale_datanorm=-1):
self.proto_desc.scale_datanorm = scale_datanorm self.proto_desc.scale_datanorm = scale_datanorm
......
...@@ -52,6 +52,8 @@ class TrainerFactory(object): ...@@ -52,6 +52,8 @@ class TrainerFactory(object):
trainer._set_fleet_desc(opt_info["fleet_desc"]) trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None: if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"]) trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("no_cvm") is not None:
trainer._set_no_cvm(opt_info["no_cvm"])
if opt_info.get("scale_datanorm") is not None: if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"]) trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None: if opt_info.get("dump_slot") is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册