未验证 提交 8113c913 编写于 作者: W wangguanqun 提交者: GitHub

double accessor and show_scale (#41943)

* double accessor and show_scale

* double accessor and show_scale

* rename

* fix bug in pslib config

* add unittest
上级 d3f95e5a
...@@ -10,7 +10,7 @@ Table: for param storage and update ...@@ -10,7 +10,7 @@ Table: for param storage and update
ValueAccessor: for pull param and push gradient ValueAccessor: for pull param and push gradient
-----CtrCommonAccessor: pull/push value with show/click, float type -----CtrCommonAccessor: pull/push value with show/click, float type
-----DownpourCtrDoubleAccessor: same as CtrCommonAccessor, other than show/click with double type -----CtrDoubleAccessor: same as CtrCommonAccessor, other than show/click with double type
-----SparseAccessor: used for common embedding, pull value without show/click, push value with show/click -----SparseAccessor: used for common embedding, pull value without show/click, push value with show/click
-----CommMergeAccessor: used for dense table only, for get param dim -----CommMergeAccessor: used for dense table only, for get param dim
......
...@@ -42,8 +42,7 @@ set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUT ...@@ -42,8 +42,7 @@ set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUT
set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto) cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_double_accessor SRCS ctr_double_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule) cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(ctr_accessor SRCS ctr_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table) cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -35,6 +35,10 @@ int CtrCommonAccessor::Initialize() { ...@@ -35,6 +35,10 @@ int CtrCommonAccessor::Initialize() {
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim(); common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
}
InitAccessorInfo(); InitAccessorInfo();
return 0; return 0;
} }
...@@ -233,6 +237,11 @@ int32_t CtrCommonAccessor::Update(float** update_values, ...@@ -233,6 +237,11 @@ int32_t CtrCommonAccessor::Update(float** update_values,
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
update_value[common_feature_value.UnseenDaysIndex()] = 0; update_value[common_feature_value.UnseenDaysIndex()] = 0;
// TODO(zhaocaibei123): add configure show_scale // TODO(zhaocaibei123): add configure show_scale
if (!_show_scale) {
push_show = 1;
}
VLOG(3) << "accessor show scale:" << _show_scale
<< ", push_show:" << push_show;
_embed_sgd_rule->UpdateValue( _embed_sgd_rule->UpdateValue(
update_value + common_feature_value.EmbedWIndex(), update_value + common_feature_value.EmbedWIndex(),
update_value + common_feature_value.EmbedG2SumIndex(), update_value + common_feature_value.EmbedG2SumIndex(),
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int DownpourCtrDoubleAccessor::Initialize() { int CtrDoubleAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1); _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
...@@ -34,14 +34,18 @@ int DownpourCtrDoubleAccessor::Initialize() { ...@@ -34,14 +34,18 @@ int DownpourCtrDoubleAccessor::Initialize() {
_ssd_unseenday_threshold = _ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold(); _config.ctr_accessor_param().ssd_unseenday_threshold();
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
}
InitAccessorInfo(); InitAccessorInfo();
return 0; return 0;
} }
void DownpourCtrDoubleAccessor::InitAccessorInfo() { void CtrDoubleAccessor::InitAccessorInfo() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
_accessor_info.dim = DownpourCtrDoubleFeatureValue::Dim(embedx_dim); _accessor_info.dim = CtrDoubleFeatureValue::Dim(embedx_dim);
_accessor_info.size = DownpourCtrDoubleFeatureValue::Size(embedx_dim); _accessor_info.size = CtrDoubleFeatureValue::Size(embedx_dim);
_accessor_info.select_dim = 3 + embedx_dim; _accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float); _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim; _accessor_info.update_dim = 4 + embedx_dim;
...@@ -49,7 +53,7 @@ void DownpourCtrDoubleAccessor::InitAccessorInfo() { ...@@ -49,7 +53,7 @@ void DownpourCtrDoubleAccessor::InitAccessorInfo() {
_accessor_info.mf_size = (embedx_dim + 1) * sizeof(float); _accessor_info.mf_size = (embedx_dim + 1) * sizeof(float);
} }
bool DownpourCtrDoubleAccessor::Shrink(float* value) { bool CtrDoubleAccessor::Shrink(float* value) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
// auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); // auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
...@@ -59,38 +63,37 @@ bool DownpourCtrDoubleAccessor::Shrink(float* value) { ...@@ -59,38 +63,37 @@ bool DownpourCtrDoubleAccessor::Shrink(float* value) {
_config.ctr_accessor_param().delete_after_unseen_days(); _config.ctr_accessor_param().delete_after_unseen_days();
auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first // time_decay first
DownpourCtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate; CtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate;
DownpourCtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate; CtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate;
// shrink after // shrink after
auto score = ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), auto score = ShowClickScore(CtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)); CtrDoubleFeatureValue::Click(value));
auto unseen_days = DownpourCtrDoubleFeatureValue::UnseenDays(value); auto unseen_days = CtrDoubleFeatureValue::UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) { if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true; return true;
} }
return false; return false;
} }
bool DownpourCtrDoubleAccessor::save_ssd(float* value) { bool CtrDoubleAccessor::SaveSSD(float* value) {
if (DownpourCtrDoubleFeatureValue::UnseenDays(value) > if (CtrDoubleFeatureValue::UnseenDays(value) > _ssd_unseenday_threshold) {
_ssd_unseenday_threshold) {
return true; return true;
} }
return false; return false;
} }
// bool DownpourCtrDoubleAccessor::save_cache( // bool CtrDoubleAccessor::save_cache(
// float* value, int param, double global_cache_threshold) { // float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), // if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
// DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold // CtrDoubleFeatureValue::Click(value)) >= base_threshold
// && DownpourCtrDoubleFeatureValue::UnseenDays(value) <= // && CtrDoubleFeatureValue::UnseenDays(value) <=
// delta_keep_days) { // delta_keep_days) {
// return DownpourCtrDoubleFeatureValue::Show(value) > // return CtrDoubleFeatureValue::Show(value) >
// global_cache_threshold; // global_cache_threshold;
// } // }
// return false; // return false;
// } // }
bool DownpourCtrDoubleAccessor::Save(float* value, int param) { bool CtrDoubleAccessor::Save(float* value, int param) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -109,14 +112,14 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { ...@@ -109,14 +112,14 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >= CtrDoubleFeatureValue::Click(value)) >=
base_threshold && base_threshold &&
DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold && CtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) { CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
if (param == 2) { if (param == 2) {
DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0; CtrDoubleFeatureValue::DeltaScore(value) = 0;
} }
return true; return true;
} else { } else {
...@@ -125,10 +128,10 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { ...@@ -125,10 +128,10 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
} }
// already decayed in shrink // already decayed in shrink
case 3: { case 3: {
// DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate; // CtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate;
// DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate; // CtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate;
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
// DownpourCtrDoubleFeatureValue::UnseenDays(value)++; // CtrDoubleFeatureValue::UnseenDays(value)++;
return true; return true;
} }
default: default:
...@@ -136,7 +139,7 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { ...@@ -136,7 +139,7 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
}; };
} }
void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) { void CtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -145,17 +148,17 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) { ...@@ -145,17 +148,17 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >= CtrDoubleFeatureValue::Click(value)) >=
base_threshold && base_threshold &&
DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold && CtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) { CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0; CtrDoubleFeatureValue::DeltaScore(value) = 0;
} }
} }
return; return;
case 3: { case 3: {
DownpourCtrDoubleFeatureValue::UnseenDays(value)++; CtrDoubleFeatureValue::UnseenDays(value)++;
} }
return; return;
default: default:
...@@ -163,123 +166,125 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) { ...@@ -163,123 +166,125 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
}; };
} }
int32_t DownpourCtrDoubleAccessor::Create(float** values, size_t num) { int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0; value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] = 0; value[CtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()) = 0; *(double*)(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()) = 0; *(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1; value[CtrDoubleFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->InitValue( _embed_sgd_rule->InitValue(
value + DownpourCtrDoubleFeatureValue::EmbedWIndex(), value + CtrDoubleFeatureValue::EmbedWIndex(),
value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()); value + CtrDoubleFeatureValue::EmbedG2SumIndex());
_embedx_sgd_rule->InitValue( _embedx_sgd_rule->InitValue(
value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), value + CtrDoubleFeatureValue::EmbedxWIndex(),
value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(), false); value + CtrDoubleFeatureValue::EmbedxG2SumIndex(), false);
} }
return 0; return 0;
} }
bool DownpourCtrDoubleAccessor::NeedExtendMF(float* value) { bool CtrDoubleAccessor::NeedExtendMF(float* value) {
auto show = auto show = ((double*)(value + CtrDoubleFeatureValue::ShowIndex()))[0];
((double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()))[0]; auto click = ((double*)(value + CtrDoubleFeatureValue::ClickIndex()))[0];
auto click =
((double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()))[0];
// float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() // float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff()
auto score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() + auto score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff(); click * _config.ctr_accessor_param().click_coeff();
//+ click * _config.ctr_accessor_param().click_coeff(); //+ click * _config.ctr_accessor_param().click_coeff();
return score >= _config.embedx_threshold(); return score >= _config.embedx_threshold();
} }
// from DownpourCtrFeatureValue to DownpourCtrPullValue // from CtrDoubleFeatureValue to CtrDoublePullValue
int32_t DownpourCtrDoubleAccessor::Select(float** select_values, int32_t CtrDoubleAccessor::Select(float** select_values, const float** values,
const float** values, size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item]; float* select_value = select_values[value_item];
float* value = const_cast<float*>(values[value_item]); float* value = const_cast<float*>(values[value_item]);
select_value[DownpourCtrDoublePullValue::ShowIndex()] = select_value[CtrDoublePullValue::ShowIndex()] =
(float)*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()); (float)*(double*)(value + CtrDoubleFeatureValue::ShowIndex());
select_value[DownpourCtrDoublePullValue::ClickIndex()] = select_value[CtrDoublePullValue::ClickIndex()] =
(float)*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()); (float)*(double*)(value + CtrDoubleFeatureValue::ClickIndex());
select_value[DownpourCtrDoublePullValue::EmbedWIndex()] = select_value[CtrDoublePullValue::EmbedWIndex()] =
value[DownpourCtrDoubleFeatureValue::EmbedWIndex()]; value[CtrDoubleFeatureValue::EmbedWIndex()];
memcpy(select_value + DownpourCtrDoublePullValue::EmbedxWIndex(), memcpy(select_value + CtrDoublePullValue::EmbedxWIndex(),
value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), value + CtrDoubleFeatureValue::EmbedxWIndex(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
} }
// from DownpourCtrPushValue to DownpourCtrPushValue // from CtrDoublePushValue to CtrDoublePushValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t DownpourCtrDoubleAccessor::Merge(float** update_values, int32_t CtrDoubleAccessor::Merge(float** update_values,
const float** other_update_values, const float** other_update_values,
size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
size_t total_dim = DownpourCtrDoublePushValue::Dim(embedx_dim); size_t total_dim = CtrDoublePushValue::Dim(embedx_dim);
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* other_update_value = other_update_values[value_item]; const float* other_update_value = other_update_values[value_item];
/**(double*)(update_value + DownpourCtrDoublePushValue::ShowIndex()) += /**(double*)(update_value + CtrDoublePushValue::ShowIndex()) +=
*(double*)(other_update_value + DownpourCtrDoublePushValue::ShowIndex()); *(double*)(other_update_value + CtrDoublePushValue::ShowIndex());
*(double*)(update_value + DownpourCtrDoublePushValue::ClickIndex()) += *(double*)(update_value + CtrDoublePushValue::ClickIndex()) +=
*(double*)(other_update_value + DownpourCtrDoublePushValue::ClickIndex()); *(double*)(other_update_value + CtrDoublePushValue::ClickIndex());
for (auto i = 3u; i < total_dim; ++i) { for (auto i = 3u; i < total_dim; ++i) {
update_value[i] += other_update_value[i]; update_value[i] += other_update_value[i];
}*/ }*/
for (auto i = 0u; i < total_dim; ++i) { for (auto i = 0u; i < total_dim; ++i) {
if (i != DownpourCtrDoublePushValue::SlotIndex()) { if (i != CtrDoublePushValue::SlotIndex()) {
update_value[i] += other_update_value[i]; update_value[i] += other_update_value[i];
} }
} }
} }
return 0; return 0;
} }
// from DownpourCtrPushValue to DownpourCtrFeatureValue // from CtrDoublePushValue to CtrDoubleFeatureValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t DownpourCtrDoubleAccessor::Update(float** update_values, int32_t CtrDoubleAccessor::Update(float** update_values,
const float** push_values, const float** push_values, size_t num) {
size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* push_value = push_values[value_item]; const float* push_value = push_values[value_item];
float push_show = push_value[DownpourCtrDoublePushValue::ShowIndex()]; float push_show = push_value[CtrDoublePushValue::ShowIndex()];
float push_click = push_value[DownpourCtrDoublePushValue::ClickIndex()]; float push_click = push_value[CtrDoublePushValue::ClickIndex()];
float slot = push_value[DownpourCtrDoublePushValue::SlotIndex()]; float slot = push_value[CtrDoublePushValue::SlotIndex()];
*(double*)(update_value + DownpourCtrDoubleFeatureValue::ShowIndex()) += *(double*)(update_value + CtrDoubleFeatureValue::ShowIndex()) +=
(double)push_show; (double)push_show;
*(double*)(update_value + DownpourCtrDoubleFeatureValue::ClickIndex()) += *(double*)(update_value + CtrDoubleFeatureValue::ClickIndex()) +=
(double)push_click; (double)push_click;
update_value[DownpourCtrDoubleFeatureValue::SlotIndex()] = slot; update_value[CtrDoubleFeatureValue::SlotIndex()] = slot;
update_value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] += update_value[CtrDoubleFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + //(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff(); // push_click * _config.ctr_accessor_param().click_coeff();
update_value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0; update_value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
if (!_show_scale) {
push_show = 1;
}
VLOG(3) << "accessor show scale:" << _show_scale
<< ", push_show:" << push_show;
_embed_sgd_rule->UpdateValue( _embed_sgd_rule->UpdateValue(
update_value + DownpourCtrDoubleFeatureValue::EmbedWIndex(), update_value + CtrDoubleFeatureValue::EmbedWIndex(),
update_value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex(), update_value + CtrDoubleFeatureValue::EmbedG2SumIndex(),
push_value + DownpourCtrDoublePushValue::EmbedGIndex(), push_show); push_value + CtrDoublePushValue::EmbedGIndex(), push_show);
_embedx_sgd_rule->UpdateValue( _embedx_sgd_rule->UpdateValue(
update_value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), update_value + CtrDoubleFeatureValue::EmbedxWIndex(),
update_value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(), update_value + CtrDoubleFeatureValue::EmbedxG2SumIndex(),
push_value + DownpourCtrDoublePushValue::EmbedxGIndex(), push_show); push_value + CtrDoublePushValue::EmbedxGIndex(), push_show);
} }
return 0; return 0;
} }
bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) { bool CtrDoubleAccessor::CreateValue(int stage, const float* value) {
// stage == 0, pull // stage == 0, pull
// stage == 1, push // stage == 1, push
if (stage == 0) { if (stage == 0) {
return true; return true;
} else if (stage == 1) { } else if (stage == 1) {
auto show = DownpourCtrDoublePushValue::Show(const_cast<float*>(value)); auto show = CtrDoublePushValue::Show(const_cast<float*>(value));
auto click = DownpourCtrDoublePushValue::Click(const_cast<float*>(value)); auto click = CtrDoublePushValue::Click(const_cast<float*>(value));
auto score = ShowClickScore(show, click); auto score = ShowClickScore(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
...@@ -293,23 +298,22 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) { ...@@ -293,23 +298,22 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) {
return true; return true;
} }
} }
double DownpourCtrDoubleAccessor::ShowClickScore(double show, double click) { double CtrDoubleAccessor::ShowClickScore(double show, double click) {
// auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); // auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
// auto click_coeff = _config.ctr_accessor_param().click_coeff(); // auto click_coeff = _config.ctr_accessor_param().click_coeff();
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff(); auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff; return (show - click) * nonclk_coeff + click * click_coeff;
} }
std::string DownpourCtrDoubleAccessor::ParseToString(const float* v, std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
int param_size) {
thread_local std::ostringstream os; thread_local std::ostringstream os;
os.clear(); os.clear();
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << (float)((double*)(v + 2))[0] << " " os << v[0] << " " << v[1] << " " << (float)((double*)(v + 2))[0] << " "
<< (float)((double*)(v + 4))[0] << " " << v[6] << " " << v[7] << " " << (float)((double*)(v + 4))[0] << " " << v[6] << " " << v[7] << " "
<< v[8]; << v[8];
auto show = DownpourCtrDoubleFeatureValue::Show(const_cast<float*>(v)); auto show = CtrDoubleFeatureValue::Show(const_cast<float*>(v));
auto click = DownpourCtrDoubleFeatureValue::Click(const_cast<float*>(v)); auto click = CtrDoubleFeatureValue::Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click); auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && param_size > 9) { if (score >= _config.embedx_threshold() && param_size > 9) {
os << " " << v[9]; os << " " << v[9];
...@@ -319,23 +323,22 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v, ...@@ -319,23 +323,22 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v,
} }
return os.str(); return os.str();
} }
int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str, int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
float data_buff[_accessor_info.dim + 2]; float data_buff[_accessor_info.dim + 2];
float* data_buff_ptr = data_buff; float* data_buff_ptr = data_buff;
_embedx_sgd_rule->InitValue( _embedx_sgd_rule->InitValue(
data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), data_buff_ptr + CtrDoubleFeatureValue::EmbedxWIndex(),
data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex()); data_buff_ptr + CtrDoubleFeatureValue::EmbedxG2SumIndex());
auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr); auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
CHECK(str_len >= 6) << "expect more than 6 real:" << str_len; CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
int show_index = DownpourCtrDoubleFeatureValue::ShowIndex(); int show_index = CtrDoubleFeatureValue::ShowIndex();
int click_index = DownpourCtrDoubleFeatureValue::ClickIndex(); int click_index = CtrDoubleFeatureValue::ClickIndex();
int embed_w_index = DownpourCtrDoubleFeatureValue::EmbedWIndex(); int embed_w_index = CtrDoubleFeatureValue::EmbedWIndex();
// no slot, embedx // no slot, embedx
int value_dim = _accessor_info.dim; int value_dim = _accessor_info.dim;
int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(); int embedx_g2sum_index = CtrDoubleFeatureValue::EmbedxG2SumIndex();
value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1; value[CtrDoubleFeatureValue::SlotIndex()] = -1;
// other case // other case
if (str_len == (value_dim - 1)) { if (str_len == (value_dim - 1)) {
// copy unseen_days..delta_score // copy unseen_days..delta_score
...@@ -344,8 +347,8 @@ int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str, ...@@ -344,8 +347,8 @@ int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str,
*(double*)(value + show_index) = (double)data_buff_ptr[2]; *(double*)(value + show_index) = (double)data_buff_ptr[2];
*(double*)(value + click_index) = (double)data_buff_ptr[3]; *(double*)(value + click_index) = (double)data_buff_ptr[3];
// copy others // copy others
value[DownpourCtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4]; value[CtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4];
value[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5]; value[CtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5];
memcpy(value + embedx_g2sum_index, data_buff_ptr + 6, memcpy(value + embedx_g2sum_index, data_buff_ptr + 6,
(embedx_dim + 1) * sizeof(float)); (embedx_dim + 1) * sizeof(float));
} else { } else {
......
...@@ -24,9 +24,9 @@ ...@@ -24,9 +24,9 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class DownpourCtrDoubleAccessor : public ValueAccessor { class CtrDoubleAccessor : public ValueAccessor {
public: public:
struct DownpourCtrDoubleFeatureValue { struct CtrDoubleFeatureValue {
/* /*
float unseen_days; float unseen_days;
float delta_score; float delta_score;
...@@ -45,60 +45,56 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -45,60 +45,56 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
} }
static int UnseenDaysIndex() { return 0; } static int UnseenDaysIndex() { return 0; }
static int DeltaScoreIndex() { static int DeltaScoreIndex() {
return DownpourCtrDoubleFeatureValue::UnseenDaysIndex() + 1; return CtrDoubleFeatureValue::UnseenDaysIndex() + 1;
} }
static int ShowIndex() { static int ShowIndex() {
return DownpourCtrDoubleFeatureValue::DeltaScoreIndex() + 1; return CtrDoubleFeatureValue::DeltaScoreIndex() + 1;
} }
// show is double // show is double
static int ClickIndex() { static int ClickIndex() { return CtrDoubleFeatureValue::ShowIndex() + 2; }
return DownpourCtrDoubleFeatureValue::ShowIndex() + 2;
}
// click is double // click is double
static int EmbedWIndex() { static int EmbedWIndex() { return CtrDoubleFeatureValue::ClickIndex() + 2; }
return DownpourCtrDoubleFeatureValue::ClickIndex() + 2;
}
static int EmbedG2SumIndex() { static int EmbedG2SumIndex() {
return DownpourCtrDoubleFeatureValue::EmbedWIndex() + 1; return CtrDoubleFeatureValue::EmbedWIndex() + 1;
} }
static int SlotIndex() { static int SlotIndex() {
return DownpourCtrDoubleFeatureValue::EmbedG2SumIndex() + 1; return CtrDoubleFeatureValue::EmbedG2SumIndex() + 1;
} }
static int EmbedxG2SumIndex() { static int EmbedxG2SumIndex() {
return DownpourCtrDoubleFeatureValue::SlotIndex() + 1; return CtrDoubleFeatureValue::SlotIndex() + 1;
} }
static int EmbedxWIndex() { static int EmbedxWIndex() {
return DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex() + 1; return CtrDoubleFeatureValue::EmbedxG2SumIndex() + 1;
} }
static float& UnseenDays(float* val) { static float& UnseenDays(float* val) {
return val[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()]; return val[CtrDoubleFeatureValue::UnseenDaysIndex()];
} }
static float& DeltaScore(float* val) { static float& DeltaScore(float* val) {
return val[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()]; return val[CtrDoubleFeatureValue::DeltaScoreIndex()];
} }
static double& Show(float* val) { static double& Show(float* val) {
return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0]; return ((double*)(val + CtrDoubleFeatureValue::ShowIndex()))[0];
} }
static double& Click(float* val) { static double& Click(float* val) {
return ((double*)(val + DownpourCtrDoubleFeatureValue::ClickIndex()))[0]; return ((double*)(val + CtrDoubleFeatureValue::ClickIndex()))[0];
} }
static float& Slot(float* val) { static float& Slot(float* val) {
return val[DownpourCtrDoubleFeatureValue::SlotIndex()]; return val[CtrDoubleFeatureValue::SlotIndex()];
} }
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrDoubleFeatureValue::EmbedWIndex()]; return val[CtrDoubleFeatureValue::EmbedWIndex()];
} }
static float& EmbedG2Sum(float* val) { static float& EmbedG2Sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()]; return val[CtrDoubleFeatureValue::EmbedG2SumIndex()];
} }
static float& EmbedxG2Sum(float* val) { static float& EmbedxG2Sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex()]; return val[CtrDoubleFeatureValue::EmbedxG2SumIndex()];
} }
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return (val + DownpourCtrDoubleFeatureValue::EmbedxWIndex()); return (val + CtrDoubleFeatureValue::EmbedxWIndex());
} }
}; };
struct DownpourCtrDoublePushValue { struct CtrDoublePushValue {
/* /*
float slot; float slot;
float show; float show;
...@@ -110,35 +106,27 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -110,35 +106,27 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
static int DimSize(int dim, int embedx_dim) { return sizeof(float); } static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int SlotIndex() { return 0; } static int SlotIndex() { return 0; }
static int ShowIndex() { static int ShowIndex() { return CtrDoublePushValue::SlotIndex() + 1; }
return DownpourCtrDoublePushValue::SlotIndex() + 1; static int ClickIndex() { return CtrDoublePushValue::ShowIndex() + 1; }
} static int EmbedGIndex() { return CtrDoublePushValue::ClickIndex() + 1; }
static int ClickIndex() { static int EmbedxGIndex() { return CtrDoublePushValue::EmbedGIndex() + 1; }
return DownpourCtrDoublePushValue::ShowIndex() + 1;
}
static int EmbedGIndex() {
return DownpourCtrDoublePushValue::ClickIndex() + 1;
}
static int EmbedxGIndex() {
return DownpourCtrDoublePushValue::EmbedGIndex() + 1;
}
static float& Slot(float* val) { static float& Slot(float* val) {
return val[DownpourCtrDoublePushValue::SlotIndex()]; return val[CtrDoublePushValue::SlotIndex()];
} }
static float& Show(float* val) { static float& Show(float* val) {
return val[DownpourCtrDoublePushValue::ShowIndex()]; return val[CtrDoublePushValue::ShowIndex()];
} }
static float& Click(float* val) { static float& Click(float* val) {
return val[DownpourCtrDoublePushValue::ClickIndex()]; return val[CtrDoublePushValue::ClickIndex()];
} }
static float& EmbedG(float* val) { static float& EmbedG(float* val) {
return val[DownpourCtrDoublePushValue::EmbedGIndex()]; return val[CtrDoublePushValue::EmbedGIndex()];
} }
static float* EmbedxG(float* val) { static float* EmbedxG(float* val) {
return val + DownpourCtrDoublePushValue::EmbedxGIndex(); return val + CtrDoublePushValue::EmbedxGIndex();
} }
}; };
struct DownpourCtrDoublePullValue { struct CtrDoublePullValue {
/* /*
float show; float show;
float click; float click;
...@@ -153,20 +141,20 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -153,20 +141,20 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
static int EmbedWIndex() { return 2; } static int EmbedWIndex() { return 2; }
static int EmbedxWIndex() { return 3; } static int EmbedxWIndex() { return 3; }
static float& Show(float* val) { static float& Show(float* val) {
return val[DownpourCtrDoublePullValue::ShowIndex()]; return val[CtrDoublePullValue::ShowIndex()];
} }
static float& Click(float* val) { static float& Click(float* val) {
return val[DownpourCtrDoublePullValue::ClickIndex()]; return val[CtrDoublePullValue::ClickIndex()];
} }
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrDoublePullValue::EmbedWIndex()]; return val[CtrDoublePullValue::EmbedWIndex()];
} }
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return val + DownpourCtrDoublePullValue::EmbedxWIndex(); return val + CtrDoublePullValue::EmbedxWIndex();
} }
}; };
DownpourCtrDoubleAccessor() {} CtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {} virtual ~CtrDoubleAccessor() {}
virtual int Initialize(); virtual int Initialize();
// 初始化AccessorInfo // 初始化AccessorInfo
virtual void InitAccessorInfo(); virtual void InitAccessorInfo();
...@@ -182,7 +170,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -182,7 +170,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// update delta_score and unseen_days after save // update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) override; virtual void UpdateStatAfterSave(float* value, int param) override;
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
virtual bool save_ssd(float* value); virtual bool SaveSSD(float* value);
// virtual bool save_cache(float* value, int param, double // virtual bool save_cache(float* value, int param, double
// global_cache_threshold) override; // global_cache_threshold) override;
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
...@@ -206,14 +194,14 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -206,14 +194,14 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
virtual float GetField(float* value, const std::string& name) override { virtual float GetField(float* value, const std::string& name) override {
CHECK(name == "show"); CHECK(name == "show");
if (name == "show") { if (name == "show") {
return (float)DownpourCtrDoubleFeatureValue::Show(value); return (float)CtrDoubleFeatureValue::Show(value);
} }
return 0.0; return 0.0;
} }
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, show) // DEFINE_GET_INDEX(CtrDoubleFeatureValue, show)
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, click) // DEFINE_GET_INDEX(CtrDoubleFeatureValue, click)
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embed_w) // DEFINE_GET_INDEX(CtrDoubleFeatureValue, embed_w)
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embedx_w) // DEFINE_GET_INDEX(CtrDoubleFeatureValue, embedx_w)
private: private:
double ShowClickScore(double show, double click); double ShowClickScore(double show, double click);
...@@ -222,6 +210,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -222,6 +210,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
SparseValueSGDRule* _embedx_sgd_rule; SparseValueSGDRule* _embedx_sgd_rule;
float _show_click_decay_rate; float _show_click_decay_rate;
int32_t _ssd_unseenday_threshold; int32_t _ssd_unseenday_threshold;
bool _show_scale = false;
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/distributed/ps/table/memory_dense_table.h" #include "paddle/fluid/distributed/ps/table/memory_dense_table.h"
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h" #include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_double_accessor.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h" #include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" #include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_accessor.h" #include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
...@@ -39,6 +40,7 @@ REGISTER_PSCORE_CLASS(Table, MemorySparseTable); ...@@ -39,6 +40,7 @@ REGISTER_PSCORE_CLASS(Table, MemorySparseTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable); REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrDoubleAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, SparseAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, SparseAccessor);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
......
...@@ -153,6 +153,7 @@ message CtrAccessorParameter { ...@@ -153,6 +153,7 @@ message CtrAccessorParameter {
// will be delete in shrink_model // will be delete in shrink_model
optional int32 ssd_unseenday_threshold = 9 optional int32 ssd_unseenday_threshold = 9
[ default = 1 ]; // threshold to save ssd [ default = 1 ]; // threshold to save ssd
optional bool show_scale = 10 [ default = true ];
} }
message TensorAccessorParameter { message TensorAccessorParameter {
......
...@@ -258,6 +258,7 @@ message CtrAccessorParameter { ...@@ -258,6 +258,7 @@ message CtrAccessorParameter {
[ default = 0.8 ]; // threshold to shrink a feasign [ default = 0.8 ]; // threshold to shrink a feasign
optional float delete_after_unseen_days = 8 [ default = 30 ]; optional float delete_after_unseen_days = 8 [ default = 30 ];
optional int32 ssd_unseenday_threshold = 9 [ default = 1 ]; optional int32 ssd_unseenday_threshold = 9 [ default = 1 ];
optional bool show_scale = 10 [ default = true ];
} }
message TableAccessorSaveParameter { message TableAccessorSaveParameter {
......
...@@ -611,12 +611,15 @@ class DistributedStrategy(object): ...@@ -611,12 +611,15 @@ class DistributedStrategy(object):
"DownpourCtrAccessor") "DownpourCtrAccessor")
if accessor_class not in support_sparse_accessor_class: if accessor_class not in support_sparse_accessor_class:
raise ValueError( raise ValueError(
"support sparse_accessor_class: [''DownpourSparseValueAccessor', 'DownpourCtrAccessor', 'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor', 'DownpourDoubleUnitAccessor'], but actual %s" "support sparse_accessor_class: ['DownpourSparseValueAccessor', 'DownpourCtrAccessor', 'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor', 'DownpourDoubleUnitAccessor'], but actual %s"
% (accessor_class)) % (accessor_class))
if configs.get("use_cvm", True): if accessor_class.find("Double") >= 0:
table_data.accessor.accessor_class = 'CtrCommonAccessor' table_data.accessor.accessor_class = 'CtrDoubleAccessor'
else: else:
table_data.accessor.accessor_class = 'CtrCommonAccessor'
if not configs.get("use_cvm", True):
table_data.accessor.accessor_class = 'SparseAccessor' table_data.accessor.accessor_class = 'SparseAccessor'
table_data.accessor.embedx_dim = config.get('sparse_embedx_dim', 8) table_data.accessor.embedx_dim = config.get('sparse_embedx_dim', 8)
...@@ -624,6 +627,11 @@ class DistributedStrategy(object): ...@@ -624,6 +627,11 @@ class DistributedStrategy(object):
table_data.accessor.embedx_threshold = config.get( table_data.accessor.embedx_threshold = config.get(
'sparse_embedx_threshold', 10) 'sparse_embedx_threshold', 10)
if accessor_class == 'DownpourUnitAccessor':
table_data.accessor.ctr_accessor_param.show_scale = False
else:
table_data.accessor.ctr_accessor_param.show_scale = True
table_data.accessor.ctr_accessor_param.nonclk_coeff = config.get( table_data.accessor.ctr_accessor_param.nonclk_coeff = config.get(
'sparse_nonclk_coeff', 0.1) 'sparse_nonclk_coeff', 0.1)
table_data.accessor.ctr_accessor_param.click_coeff = config.get( table_data.accessor.ctr_accessor_param.click_coeff = config.get(
......
...@@ -310,9 +310,22 @@ class TestStrategyConfig(unittest.TestCase): ...@@ -310,9 +310,22 @@ class TestStrategyConfig(unittest.TestCase):
"embed_sparse_optimizer": "std_adagrad" "embed_sparse_optimizer": "std_adagrad"
} }
strategy.fleet_desc_configs = configs strategy.fleet_desc_configs = configs
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.ctr_accessor_param.show_scale, False)
self.assertEqual(strategy.sparse_table_configs[0] self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.adagrad.initial_range, 0) .accessor.embed_sgd_param.adagrad.initial_range, 0)
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {}
configs['emb'] = {
"sparse_accessor_class": "DownpourCtrDoubleAccessor",
"embed_sparse_optimizer": "std_adagrad"
}
strategy.fleet_desc_configs = configs
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.adagrad.initial_range,
0.0001)
def test_trainer_desc_configs(self): def test_trainer_desc_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
configs = { configs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册