Unverified · Commit 349e82d6 · Authored by: T Thunderbrook · Committed by: GitHub

support general embedding params (#21217)

* general table

* add sparse table
test=develop

* no cvm
test=develop

* add no_cvm
test=develop

* add note
test=develop

* code style
test=develop

* code style
test=develop

* code style
test=develop

* code style
test=develop

* code style
test=develop

* add key of optimizer
test=develop
Parent commit: 3cb6c0a0
......@@ -105,7 +105,10 @@ class PullDenseWorker {
// should incorporate different type of device
class DeviceWorker {
public:
DeviceWorker() { use_cvm_ = false; }
// Default-construct with use_cvm_ off and no_cvm_ on; both flags are
// re-read from the TrainerDesc proto in Initialize() before any training.
// NOTE(review): no_cvm_ starts as true here although the TrainerDesc proto
// declares `no_cvm` with default false — presumably Initialize() always
// overwrites it, but confirm no code path uses the worker uninitialized.
DeviceWorker() {
  use_cvm_ = false;
  no_cvm_ = true;
}
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
......@@ -135,6 +138,7 @@ class DeviceWorker {
int64_t batch_num_;
FetchConfig fetch_config_;
bool use_cvm_;
bool no_cvm_;
};
class CPUWorkerBase : public DeviceWorker {
......
......@@ -75,6 +75,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm();
// for sparse value accessor, embedding only
no_cvm_ = desc.no_cvm();
scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size());
......@@ -211,6 +213,9 @@ void DownpourWorker::DumpParam() {
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
if (no_cvm_) {
return;
}
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
......@@ -312,7 +317,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
int nid_ins_index = 0;
for (int index = 0; index < len; ++index) {
if (use_cvm_) {
if (use_cvm_ || no_cvm_) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim());
......@@ -681,7 +686,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_, &sparse_push_keys_[tid]);
dump_slot_, &sparse_push_keys_[tid], no_cvm_);
timeline.Pause();
push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
......@@ -906,7 +911,7 @@ void DownpourWorker::TrainFiles() {
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_, &sparse_push_keys_[tid]);
dump_slot_, &sparse_push_keys_[tid], no_cvm_);
}
}
......
......@@ -303,7 +303,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys) {
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
int slot_offset = 0;
......@@ -314,6 +314,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
offset = 0;
grad_dim = emb_dim - 2;
}
if (no_cvm) {
offset = 0;
grad_dim = emb_dim;
}
if (dump_slot) {
slot_offset = 1;
show_index = 1;
......@@ -370,12 +374,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
}
sparse_push_keys->push_back(ids[id_idx]);
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
if (use_cvm) {
if (use_cvm || no_cvm) {
memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim);
} else {
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][show_index] = 1.0f;
......
......@@ -124,7 +124,7 @@ class FleetWrapper {
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys);
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
......
......@@ -46,6 +46,7 @@ message TrainerDesc {
optional CopyTableConfig copy_table_config = 19;
// adjust ins weight
optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
optional bool no_cvm = 21 [ default = false ];
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......
......@@ -160,7 +160,8 @@ class DownpourSGD(DeviceWorker):
.sparse_table[i].slot_value)
sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
i].slot_gradient)
if opt_info["use_cvm"]:
if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[
"no_cvm"] == True:
sparse_table.emb_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
i].accessor.fea_dim
......
......@@ -80,7 +80,8 @@ class DownpourServer(Server):
'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \
'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
'sparse_cache_file_num']
'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
'sparse_ada_epsilon', 'sparse_optimizer']
for key in strategy:
if key not in support_sparse_key_list:
......@@ -108,9 +109,13 @@ class DownpourServer(Server):
table.compress_in_save = strategy.get('sparse_compress_in_save',
True)
table.shard_num = strategy.get('sparse_shard_num', 1000)
# DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
# DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info
# DownpourSparseValueAccessor : for general task, has embedding and sgd info
support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor'
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor'
]
if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class')
......@@ -169,6 +174,69 @@ class DownpourServer(Server):
table1.converter = converter
table1.deconverter = deconverter
table2 = table.accessor.table_accessor_save_param.add()
table2.param = 2
table2.converter = converter
table2.deconverter = deconverter
elif accessor_class == 'DownpourSparseValueAccessor':
optimizer_name = strategy.get("sparse_optimizer", "adam")
table.accessor.sparse_commonsgd_param.name = optimizer_name
table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
table.accessor.fea_dim = int(table.accessor.embedx_dim)
if optimizer_name == "naive":
table.accessor.sparse_commonsgd_param.naive.learning_rate = \
strategy.get('sparse_learning_rate', 0.05)
table.accessor.sparse_commonsgd_param.naive.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
elif optimizer_name == "adagrad":
table.accessor.sparse_commonsgd_param.adagrad.learning_rate = \
strategy.get('sparse_learning_rate', 0.05)
table.accessor.sparse_commonsgd_param.adagrad.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
table.accessor.sparse_commonsgd_param.adagrad.initial_g2sum = strategy.get(
'sparse_initial_g2sum', 3)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
elif optimizer_name == "adam":
table.accessor.sparse_commonsgd_param.adam.learning_rate = \
strategy.get('sparse_learning_rate', 0.001)
table.accessor.sparse_commonsgd_param.adam.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
table.accessor.sparse_commonsgd_param.adam.beta1_decay_rate = strategy.get(
'sparse_beta1_decay_rate', 0.9)
table.accessor.sparse_commonsgd_param.adam.beta2_decay_rate = strategy.get(
'sparse_beta2_decay_rate', 0.999)
table.accessor.sparse_commonsgd_param.adam.ada_epsilon = strategy.get(
'sparse_ada_epsilon', 1e-8)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
converter = strategy.get(
'sparse_converter',
"(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)")
deconverter = strategy.get(
'sparse_deconverter',
"(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
)
table1 = table.accessor.table_accessor_save_param.add()
table1.param = 1
table1.converter = converter
table1.deconverter = deconverter
table2 = table.accessor.table_accessor_save_param.add()
table2.param = 2
table2.converter = converter
......
......@@ -367,6 +367,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
opt_info["use_cvm"] = strategy.get("use_cvm", False)
opt_info["no_cvm"] = strategy.get("no_cvm", False)
opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
......
......@@ -76,6 +76,9 @@ class TrainerDesc(object):
def _set_use_cvm(self, use_cvm=False):
self.proto_desc.use_cvm = use_cvm
def _set_no_cvm(self, no_cvm=False):
self.proto_desc.no_cvm = no_cvm
def _set_scale_datanorm(self, scale_datanorm=-1):
self.proto_desc.scale_datanorm = scale_datanorm
......
......@@ -52,6 +52,8 @@ class TrainerFactory(object):
trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("no_cvm") is not None:
trainer._set_no_cvm(opt_info["no_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.