未验证 提交 0b0c2768 编写于 作者: Y yaoxuefeng 提交者: GitHub

modify api name of ps accessor (#41207)

* modify api name of ps accessor

* update

* code format
上级 8aef685b
...@@ -1520,7 +1520,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data, ...@@ -1520,7 +1520,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
merge_data_shell[i] = merge_data + i; merge_data_shell[i] = merge_data + i;
another_data_shell[i] = another_data + i; another_data_shell[i] = another_data + i;
} }
accessor->merge(merge_data_shell, another_data_shell, 1); accessor->Merge(merge_data_shell, another_data_shell, 1);
} }
int BrpcPsClient::push_sparse_async_shard_merge( int BrpcPsClient::push_sparse_async_shard_merge(
...@@ -1759,7 +1759,7 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1759,7 +1759,7 @@ void BrpcPsClient::push_dense_task_consume() {
async_task]() -> int { async_task]() -> int {
auto &tmp_task_vec = *(async_task->data()); auto &tmp_task_vec = *(async_task->data());
const float *merge_data = tmp_task_vec.data(); const float *merge_data = tmp_task_vec.data();
accessor->merge(&total_send_data, &merge_data, accessor->Merge(&total_send_data, &merge_data,
total_send_data_size); total_send_data_size);
#pragma optimize("", off) #pragma optimize("", off)
auto *debug_closure = closure; auto *debug_closure = closure;
......
...@@ -206,7 +206,8 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, ...@@ -206,7 +206,8 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
} }
auto res_data = butil::get_object<std::vector<float>>(); auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->select_size() / sizeof(float)); res_data->resize(num * table->value_accesor()->GetTableInfo(SELECT_SIZE) /
sizeof(float));
TableContext table_context; TableContext table_context;
table_context.value_type = Dense; table_context.value_type = Dense;
table_context.pull_context.values = res_data->data(); table_context.pull_context.values = res_data->data();
...@@ -385,7 +386,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, ...@@ -385,7 +386,7 @@ int32_t BrpcPsService::pull_sparse(Table *table,
CostTimer timer("pserver_server_pull_sparse"); CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->value_accesor()->select_dim(); auto dim = table->value_accesor()->GetTableInfo(SELECT_DIM);
thread_local std::string req_buffer; thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size); req_buffer.reserve(req_buffer_size);
......
...@@ -46,8 +46,8 @@ int32_t PSClient::configure( ...@@ -46,8 +46,8 @@ int32_t PSClient::configure(
auto *accessor = CREATE_PSCORE_CLASS( auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor, ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class()); work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor()); accessor->Configure(work_param.downpour_table_param(i).accessor());
accessor->initialize(); accessor->Initialize();
_table_accessors[work_param.downpour_table_param(i).table_id()].reset( _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor); accessor);
} }
......
...@@ -174,7 +174,8 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -174,7 +174,8 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
auto* accessor = table_accessor(table_id); auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = table(table_id);
uint32_t num_per_shard = dense_dim_per_shard(accessor->fea_dim(), 1); uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(num_per_shard); region_buffer.resize(num_per_shard);
table_ptr->pull_dense(region_buffer.data(), region_buffer.size()); table_ptr->pull_dense(region_buffer.data(), region_buffer.size());
...@@ -219,7 +220,8 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -219,7 +220,8 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
auto* table_ptr = table(table_id); auto* table_ptr = table(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->fea_dim(), 1), 0); region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1),
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
...@@ -252,7 +254,7 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -252,7 +254,7 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
auto* table_ptr = table(table_id); auto* table_ptr = table(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->fea_dim(), 1)); region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1));
size_t data_size = region_buffer.size(); size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
......
...@@ -72,7 +72,7 @@ class ValueAccessor { ...@@ -72,7 +72,7 @@ class ValueAccessor {
ValueAccessor() {} ValueAccessor() {}
virtual ~ValueAccessor() {} virtual ~ValueAccessor() {}
virtual int configure(const TableAccessorParameter& parameter) { virtual int Configure(const TableAccessorParameter& parameter) {
_config = parameter; _config = parameter;
// data_convert结构体初始化 // data_convert结构体初始化
if (_config.table_accessor_save_param_size() != 0) { if (_config.table_accessor_save_param_size() != 0) {
...@@ -88,38 +88,15 @@ class ValueAccessor { ...@@ -88,38 +88,15 @@ class ValueAccessor {
} }
return 0; return 0;
} }
virtual int initialize() = 0; virtual int Initialize() = 0;
virtual void SetTableInfo(AccessorInfo& info) = 0; virtual void SetTableInfo(AccessorInfo& info) = 0;
virtual size_t GetTableInfo(InfoKey key) = 0; virtual size_t GetTableInfo(InfoKey key) = 0;
// value维度 virtual bool NeedExtendMF(float* value) { return false; }
virtual size_t dim() = 0; virtual bool HasMF(size_t size) { return false; }
// value各个维度的size
virtual size_t dim_size(size_t dim) = 0;
// value各维度相加总size
virtual size_t size() = 0;
// value中mf动态长度部分总size大小, sparse下生效
virtual size_t mf_size() { return 0; }
virtual bool need_extend_mf(float* value) { return false; }
virtual bool has_mf(size_t size) { return false; }
// pull value维度
virtual size_t select_dim() = 0;
// pull value各个维度的size
virtual size_t select_dim_size(size_t dim) = 0;
// pull value各维度相加总size
virtual size_t select_size() = 0;
// push value维度
virtual size_t update_dim() = 0;
// push value各个维度的size
virtual size_t update_dim_size(size_t dim) = 0;
// push value各维度相加总size
virtual size_t update_size() = 0;
// fea total for dense
virtual size_t fea_dim() { return _config.fea_dim(); }
// converter for save // converter for save
virtual std::string get_converter(int param) { virtual std::string GetConverter(int param) {
auto itr = _data_coverter_map.find(param); auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) { if (itr == _data_coverter_map.end()) {
return ""; return "";
...@@ -128,7 +105,7 @@ class ValueAccessor { ...@@ -128,7 +105,7 @@ class ValueAccessor {
} }
} }
// deconverter for load // deconverter for load
virtual std::string get_deconverter(int param) { virtual std::string GetDeconverter(int param) {
auto itr = _data_coverter_map.find(param); auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) { if (itr == _data_coverter_map.end()) {
return ""; return "";
...@@ -137,47 +114,47 @@ class ValueAccessor { ...@@ -137,47 +114,47 @@ class ValueAccessor {
} }
} }
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool shrink(float* value) = 0; virtual bool Shrink(float* value) = 0;
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
virtual bool save(float* value, int param) = 0; virtual bool Save(float* value, int param) = 0;
// update delta_score and unseen_days after save // update delta_score and unseen_days after save
virtual void update_stat_after_save(float* value, int param) {} virtual void UpdateStatAfterSave(float* value, int param) {}
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
virtual int32_t create(float** value, size_t num) = 0; virtual int32_t Create(float** value, size_t num) = 0;
virtual bool create_value(int type, const float* value) { return true; } virtual bool CreateValue(int type, const float* value) { return true; }
// 从values中选取到select_values中 // 从values中选取到select_values中
virtual int32_t select(float** select_values, const float** values, virtual int32_t Select(float** select_values, const float** values,
size_t num) = 0; size_t num) = 0;
// 将update_values聚合到一起 // 将update_values聚合到一起
virtual int32_t merge(float** update_values, virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num) = 0; const float** other_update_values, size_t num) = 0;
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it); // virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
virtual int32_t update(float** values, const float** update_values, virtual int32_t Update(float** values, const float** update_values,
size_t num) = 0; size_t num) = 0;
// used to save model, will filter feature // used to save model, will filter feature
virtual std::string parse_to_string(const float* value, int param) = 0; virtual std::string ParseToString(const float* value, int param) = 0;
// parse value from string, used to load model // parse value from string, used to load model
virtual int32_t parse_from_string(const std::string& data, float* value) = 0; virtual int32_t ParseFromString(const std::string& data, float* value) = 0;
virtual FsDataConverter converter(int param) { virtual FsDataConverter Converter(int param) {
FsDataConverter data_convert; FsDataConverter data_convert;
data_convert.converter = this->get_converter(param); data_convert.converter = this->GetConverter(param);
data_convert.deconverter = this->get_deconverter(param); data_convert.deconverter = this->GetDeconverter(param);
return data_convert; return data_convert;
} }
virtual int set_weight(float** values, const float** update_values, virtual int SetWeight(float** values, const float** update_values,
size_t num) { size_t num) {
return 0; return 0;
} }
virtual float get_field(float* value, const std::string& name) { return 0.0; } virtual float GetField(float* value, const std::string& name) { return 0.0; }
#define DEFINE_GET_INDEX(class, field) \ #define DEFINE_GET_INDEX(class, field) \
virtual int get_##field##_index() override { return class ::field##_index(); } virtual int get_##field##_index() override { return class ::field##_index(); }
......
...@@ -232,9 +232,9 @@ int32_t CommonDenseTable::load(const std::string& path, ...@@ -232,9 +232,9 @@ int32_t CommonDenseTable::load(const std::string& path,
int load_param = atoi(param.c_str()); int load_param = atoi(param.c_str());
FsChannelConfig channel_config; FsChannelConfig channel_config;
channel_config.converter = _value_accesor->converter(load_param).converter; channel_config.converter = _value_accesor->Converter(load_param).converter;
channel_config.deconverter = channel_config.deconverter =
_value_accesor->converter(load_param).deconverter; _value_accesor->Converter(load_param).deconverter;
bool is_read_failed = false; bool is_read_failed = false;
int err_no = 0; int err_no = 0;
int retry_num = 0; int retry_num = 0;
...@@ -329,9 +329,9 @@ int32_t CommonDenseTable::save(const std::string& path, ...@@ -329,9 +329,9 @@ int32_t CommonDenseTable::save(const std::string& path,
"%s/part-%03d", table_dir(path).c_str(), _shard_idx); "%s/part-%03d", table_dir(path).c_str(), _shard_idx);
} }
_afs_client.remove(channel_config.path); _afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter = channel_config.deconverter =
_value_accesor->converter(save_param).deconverter; _value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false; bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param( std::vector<std::vector<std::string>> result_buffer_param(
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int CtrCommonAccessor::initialize() { int CtrCommonAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);
...@@ -39,73 +39,72 @@ int CtrCommonAccessor::initialize() { ...@@ -39,73 +39,72 @@ int CtrCommonAccessor::initialize() {
} }
void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) { void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim(); info.dim = Dim();
info.size = size(); info.size = Size();
info.select_dim = select_dim(); info.select_dim = SelectDim();
info.select_size = select_size(); info.select_size = SelectSize();
info.update_dim = update_dim(); info.update_dim = UpdateDim();
info.update_size = update_size(); info.update_size = UpdateSize();
info.mf_size = mf_size(); info.mf_size = MFSize();
info.fea_dim = fea_dim();
} }
size_t CtrCommonAccessor::GetTableInfo(InfoKey key) { size_t CtrCommonAccessor::GetTableInfo(InfoKey key) {
switch (key) { switch (key) {
case DIM: case DIM:
return dim(); return Dim();
case SIZE: case SIZE:
return size(); return Size();
case SELECT_DIM: case SELECT_DIM:
return select_dim(); return SelectDim();
case SELECT_SIZE: case SELECT_SIZE:
return select_size(); return SelectSize();
case UPDATE_DIM: case UPDATE_DIM:
return update_dim(); return UpdateDim();
case UPDATE_SIZE: case UPDATE_SIZE:
return update_size(); return UpdateSize();
case MF_SIZE: case MF_SIZE:
return mf_size(); return MFSize();
case FEA_DIM: default:
return fea_dim(); return 0;
} }
return 0; return 0;
} }
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); } size_t CtrCommonAccessor::Dim() { return common_feature_value.Dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) { size_t CtrCommonAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return common_feature_value.dim_size(dim, embedx_dim); return common_feature_value.DimSize(dim, embedx_dim);
} }
size_t CtrCommonAccessor::size() { return common_feature_value.size(); } size_t CtrCommonAccessor::Size() { return common_feature_value.Size(); }
size_t CtrCommonAccessor::mf_size() { size_t CtrCommonAccessor::MFSize() {
return (_config.embedx_dim() + common_feature_value.embedx_sgd_dim) * return (_config.embedx_dim() + common_feature_value.embedx_sgd_dim) *
sizeof(float); // embedx embedx_g2sum sizeof(float); // embedx embedx_g2sum
} }
// pull value // pull value
size_t CtrCommonAccessor::select_dim() { size_t CtrCommonAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim; return 3 + embedx_dim;
} }
size_t CtrCommonAccessor::select_dim_size(size_t dim) { return sizeof(float); } size_t CtrCommonAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
size_t CtrCommonAccessor::select_size() { return select_dim() * sizeof(float); } size_t CtrCommonAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value // push value
size_t CtrCommonAccessor::update_dim() { size_t CtrCommonAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim; return 4 + embedx_dim;
} }
size_t CtrCommonAccessor::update_dim_size(size_t dim) { return sizeof(float); } size_t CtrCommonAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
size_t CtrCommonAccessor::update_size() { return update_dim() * sizeof(float); } size_t CtrCommonAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool CtrCommonAccessor::shrink(float* value) { bool CtrCommonAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delete_after_unseen_days = auto delete_after_unseen_days =
...@@ -113,12 +112,12 @@ bool CtrCommonAccessor::shrink(float* value) { ...@@ -113,12 +112,12 @@ bool CtrCommonAccessor::shrink(float* value) {
auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first // time_decay first
common_feature_value.show(value) *= _show_click_decay_rate; common_feature_value.Show(value) *= _show_click_decay_rate;
common_feature_value.click(value) *= _show_click_decay_rate; common_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after // shrink after
auto score = show_click_score(common_feature_value.show(value), auto score = show_click_score(common_feature_value.Show(value),
common_feature_value.click(value)); common_feature_value.Click(value));
auto unseen_days = common_feature_value.unseen_days(value); auto unseen_days = common_feature_value.unseen_days(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) { if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true; return true;
...@@ -126,7 +125,7 @@ bool CtrCommonAccessor::shrink(float* value) { ...@@ -126,7 +125,7 @@ bool CtrCommonAccessor::shrink(float* value) {
return false; return false;
} }
bool CtrCommonAccessor::save(float* value, int param) { bool CtrCommonAccessor::Save(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -142,8 +141,8 @@ bool CtrCommonAccessor::save(float* value, int param) { ...@@ -142,8 +141,8 @@ bool CtrCommonAccessor::save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
if (show_click_score(common_feature_value.show(value), if (show_click_score(common_feature_value.Show(value),
common_feature_value.click(value)) >= common_feature_value.Click(value)) >=
base_threshold && base_threshold &&
common_feature_value.delta_score(value) >= delta_threshold && common_feature_value.delta_score(value) >= delta_threshold &&
common_feature_value.unseen_days(value) <= delta_keep_days) { common_feature_value.unseen_days(value) <= delta_keep_days) {
...@@ -171,7 +170,7 @@ bool CtrCommonAccessor::save(float* value, int param) { ...@@ -171,7 +170,7 @@ bool CtrCommonAccessor::save(float* value, int param) {
} }
} }
void CtrCommonAccessor::update_stat_after_save(float* value, int param) { void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -180,8 +179,8 @@ void CtrCommonAccessor::update_stat_after_save(float* value, int param) { ...@@ -180,8 +179,8 @@ void CtrCommonAccessor::update_stat_after_save(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
if (show_click_score(common_feature_value.show(value), if (show_click_score(common_feature_value.Show(value),
common_feature_value.click(value)) >= common_feature_value.Click(value)) >=
base_threshold && base_threshold &&
common_feature_value.delta_score(value) >= delta_threshold && common_feature_value.delta_score(value) >= delta_threshold &&
common_feature_value.unseen_days(value) <= delta_keep_days) { common_feature_value.unseen_days(value) <= delta_keep_days) {
...@@ -198,52 +197,52 @@ void CtrCommonAccessor::update_stat_after_save(float* value, int param) { ...@@ -198,52 +197,52 @@ void CtrCommonAccessor::update_stat_after_save(float* value, int param) {
} }
} }
int32_t CtrCommonAccessor::create(float** values, size_t num) { int32_t CtrCommonAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[common_feature_value.unseen_days_index()] = 0; value[common_feature_value.unseen_days_index()] = 0;
value[common_feature_value.delta_score_index()] = 0; value[common_feature_value.delta_score_index()] = 0;
value[common_feature_value.show_index()] = 0; value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.click_index()] = 0; value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.slot_index()] = -1; value[common_feature_value.SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->init_value(
value + common_feature_value.embed_w_index(), value + common_feature_value.Embed_W_Index(),
value + common_feature_value.embed_g2sum_index()); value + common_feature_value.embed_g2sum_index());
_embedx_sgd_rule->init_value( _embedx_sgd_rule->init_value(
value + common_feature_value.embedx_w_index(), value + common_feature_value.Embedx_W_Index(),
value + common_feature_value.embedx_g2sum_index(), false); value + common_feature_value.embedx_g2sum_index(), false);
} }
return 0; return 0;
} }
bool CtrCommonAccessor::need_extend_mf(float* value) { bool CtrCommonAccessor::NeedExtendMF(float* value) {
float show = value[common_feature_value.show_index()]; float show = value[common_feature_value.ShowIndex()];
float click = value[common_feature_value.click_index()]; float click = value[common_feature_value.ClickIndex()];
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() + float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff(); click * _config.ctr_accessor_param().click_coeff();
return score >= _config.embedx_threshold(); return score >= _config.embedx_threshold();
} }
bool CtrCommonAccessor::has_mf(size_t size) { bool CtrCommonAccessor::HasMF(size_t size) {
return size > common_feature_value.embedx_g2sum_index(); return size > common_feature_value.embedx_g2sum_index();
} }
// from CommonFeatureValue to CtrCommonPullValue // from CommonFeatureValue to CtrCommonPullValue
int32_t CtrCommonAccessor::select(float** select_values, const float** values, int32_t CtrCommonAccessor::Select(float** select_values, const float** values,
size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item]; float* select_value = select_values[value_item];
const float* value = values[value_item]; const float* value = values[value_item];
select_value[CtrCommonPullValue::show_index()] = select_value[CtrCommonPullValue::ShowIndex()] =
value[common_feature_value.show_index()]; value[common_feature_value.ShowIndex()];
select_value[CtrCommonPullValue::click_index()] = select_value[CtrCommonPullValue::ClickIndex()] =
value[common_feature_value.click_index()]; value[common_feature_value.ClickIndex()];
select_value[CtrCommonPullValue::embed_w_index()] = select_value[CtrCommonPullValue::Embed_W_Index()] =
value[common_feature_value.embed_w_index()]; value[common_feature_value.Embed_W_Index()];
memcpy(select_value + CtrCommonPullValue::embedx_w_index(), memcpy(select_value + CtrCommonPullValue::Embedx_W_Index(),
value + common_feature_value.embedx_w_index(), value + common_feature_value.Embedx_W_Index(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -252,16 +251,16 @@ int32_t CtrCommonAccessor::select(float** select_values, const float** values, ...@@ -252,16 +251,16 @@ int32_t CtrCommonAccessor::select(float** select_values, const float** values,
// from CtrCommonPushValue to CtrCommonPushValue // from CtrCommonPushValue to CtrCommonPushValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t CtrCommonAccessor::merge(float** update_values, int32_t CtrCommonAccessor::Merge(float** update_values,
const float** other_update_values, const float** other_update_values,
size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
size_t total_dim = CtrCommonPushValue::dim(embedx_dim); size_t total_dim = CtrCommonPushValue::Dim(embedx_dim);
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* other_update_value = other_update_values[value_item]; const float* other_update_value = other_update_values[value_item];
for (auto i = 0u; i < total_dim; ++i) { for (auto i = 0u; i < total_dim; ++i) {
if (i != CtrCommonPushValue::slot_index()) { if (i != CtrCommonPushValue::SlotIndex()) {
update_value[i] += other_update_value[i]; update_value[i] += other_update_value[i];
} }
} }
...@@ -272,43 +271,43 @@ int32_t CtrCommonAccessor::merge(float** update_values, ...@@ -272,43 +271,43 @@ int32_t CtrCommonAccessor::merge(float** update_values,
// from CtrCommonPushValue to CommonFeatureValue // from CtrCommonPushValue to CommonFeatureValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t CtrCommonAccessor::update(float** update_values, int32_t CtrCommonAccessor::Update(float** update_values,
const float** push_values, size_t num) { const float** push_values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* push_value = push_values[value_item]; const float* push_value = push_values[value_item];
float push_show = push_value[CtrCommonPushValue::show_index()]; float push_show = push_value[CtrCommonPushValue::ShowIndex()];
float push_click = push_value[CtrCommonPushValue::click_index()]; float push_click = push_value[CtrCommonPushValue::ClickIndex()];
float slot = push_value[CtrCommonPushValue::slot_index()]; float slot = push_value[CtrCommonPushValue::SlotIndex()];
update_value[common_feature_value.show_index()] += push_show; update_value[common_feature_value.ShowIndex()] += push_show;
update_value[common_feature_value.click_index()] += push_click; update_value[common_feature_value.ClickIndex()] += push_click;
update_value[common_feature_value.slot_index()] = slot; update_value[common_feature_value.SlotIndex()] = slot;
update_value[common_feature_value.delta_score_index()] += update_value[common_feature_value.delta_score_index()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
update_value[common_feature_value.unseen_days_index()] = 0; update_value[common_feature_value.unseen_days_index()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->update_value(
update_value + common_feature_value.embed_w_index(), update_value + common_feature_value.Embed_W_Index(),
update_value + common_feature_value.embed_g2sum_index(), update_value + common_feature_value.embed_g2sum_index(),
push_value + CtrCommonPushValue::embed_g_index()); push_value + CtrCommonPushValue::Embed_G_Index());
_embedx_sgd_rule->update_value( _embedx_sgd_rule->update_value(
update_value + common_feature_value.embedx_w_index(), update_value + common_feature_value.Embedx_W_Index(),
update_value + common_feature_value.embedx_g2sum_index(), update_value + common_feature_value.embedx_g2sum_index(),
push_value + CtrCommonPushValue::embedx_g_index()); push_value + CtrCommonPushValue::Embedx_G_Index());
} }
return 0; return 0;
} }
bool CtrCommonAccessor::create_value(int stage, const float* value) { bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
// stage == 0, pull // stage == 0, pull
// stage == 1, push // stage == 1, push
if (stage == 0) { if (stage == 0) {
return true; return true;
} else if (stage == 1) { } else if (stage == 1) {
// operation // operation
auto show = CtrCommonPushValue::show(const_cast<float*>(value)); auto show = CtrCommonPushValue::Show(const_cast<float*>(value));
auto click = CtrCommonPushValue::click(const_cast<float*>(value)); auto click = CtrCommonPushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
...@@ -329,34 +328,34 @@ float CtrCommonAccessor::show_click_score(float show, float click) { ...@@ -329,34 +328,34 @@ float CtrCommonAccessor::show_click_score(float show, float click) {
return (show - click) * nonclk_coeff + click * click_coeff; return (show - click) * nonclk_coeff + click * click_coeff;
} }
std::string CtrCommonAccessor::parse_to_string(const float* v, int param) { std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
thread_local std::ostringstream os; thread_local std::ostringstream os;
os.clear(); os.clear();
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5]; << v[5];
for (int i = common_feature_value.embed_g2sum_index(); for (int i = common_feature_value.embed_g2sum_index();
i < common_feature_value.embedx_w_index(); i++) { i < common_feature_value.Embedx_W_Index(); i++) {
os << " " << v[i]; os << " " << v[i];
} }
auto show = common_feature_value.show(const_cast<float*>(v)); auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.click(const_cast<float*>(v)); auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score >= _config.embedx_threshold() && if (score >= _config.embedx_threshold() &&
param > common_feature_value.embedx_w_index()) { param > common_feature_value.Embedx_W_Index()) {
for (auto i = common_feature_value.embedx_w_index(); for (auto i = common_feature_value.Embedx_W_Index();
i < common_feature_value.dim(); ++i) { i < common_feature_value.Dim(); ++i) {
os << " " << v[i]; os << " " << v[i];
} }
} }
return os.str(); return os.str();
} }
int CtrCommonAccessor::parse_from_string(const std::string& str, float* value) { int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
_embedx_sgd_rule->init_value( _embedx_sgd_rule->init_value(
value + common_feature_value.embedx_w_index(), value + common_feature_value.Embedx_W_Index(),
value + common_feature_value.embedx_g2sum_index()); value + common_feature_value.embedx_g2sum_index());
auto ret = paddle::string::str_to_float(str.data(), value); auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret; CHECK(ret >= 6) << "expect more than 6 real:" << ret;
......
...@@ -40,27 +40,27 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -40,27 +40,27 @@ class CtrCommonAccessor : public ValueAccessor {
std::<vector>float embedx_g2sum; std::<vector>float embedx_g2sum;
*/ */
// Number of floats in a stored value: six fixed slots plus the embed SGD
// state, the embedx SGD state and the embedx weights.
int Dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
// Byte size of a single slot; both arguments are unused.
int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
// Byte size of a whole stored value.
int Size() { return Dim() * sizeof(float); }

// Slot offsets, each derived from the previous one.
int SlotIndex() { return 0; }
int unseen_days_index() { return SlotIndex() + 1; }
int delta_score_index() { return unseen_days_index() + 1; }
int ShowIndex() { return delta_score_index() + 1; }
int ClickIndex() { return ShowIndex() + 1; }
int Embed_W_Index() { return ClickIndex() + 1; }
int embed_g2sum_index() { return Embed_W_Index() + 1; }
// The embedx weights start after `embed_sgd_dim` floats of embed SGD state.
int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; }
// The embedx SGD state starts after `embedx_dim` embedx weights.
int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; }

// Mutable references to the slot at the matching offset of `val`.
float& unseen_days(float* val) { return val[unseen_days_index()]; }
float& delta_score(float* val) { return val[delta_score_index()]; }
float& Show(float* val) { return val[ShowIndex()]; }
float& Click(float* val) { return val[ClickIndex()]; }
float& Slot(float* val) { return val[SlotIndex()]; }
float& EmbedW(float* val) { return val[Embed_W_Index()]; }
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
float& EmbedxW(float* val) { return val[Embedx_W_Index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
int embed_sgd_dim; int embed_sgd_dim;
...@@ -77,31 +77,31 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -77,31 +77,31 @@ class CtrCommonAccessor : public ValueAccessor {
std::vector<float> embedx_g; std::vector<float> embedx_g;
*/ */
// Push-value layout helpers. Offsets: slot=0, show=1, click=2, embed_g=3,
// embedx_g starts at 4.

// Number of floats in a push value holding `embedx_dim` embedx gradients.
static int Dim(int embedx_dim) { return 4 + embedx_dim; }
// Byte size of a single slot; both arguments are unused.
static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
// Byte size of a whole push value.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }

// Slot offsets, each derived from the previous one.
static int SlotIndex() { return 0; }
static int ShowIndex() { return SlotIndex() + 1; }
static int ClickIndex() { return ShowIndex() + 1; }
static int Embed_G_Index() { return ClickIndex() + 1; }
static int Embedx_G_Index() { return Embed_G_Index() + 1; }

// Mutable references to the scalar slots of `val`.
static float& Slot(float* val) { return val[SlotIndex()]; }
static float& Show(float* val) { return val[ShowIndex()]; }
static float& Click(float* val) { return val[ClickIndex()]; }
static float& EmbedG(float* val) { return val[Embed_G_Index()]; }
// Pointer to the first embedx gradient inside `val`.
static float* EmbedxG(float* val) { return val + Embedx_G_Index(); }
}; };
...@@ -113,90 +113,90 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -113,90 +113,90 @@ class CtrCommonAccessor : public ValueAccessor {
std::vector<float> embedx_w; std::vector<float> embedx_w;
*/ */
// Pull-value layout helpers. Offsets: show=0, click=1, embed_w=2, embedx_w
// starts at 3.

// Number of floats in a pull value holding `embedx_dim` embedx weights.
static int Dim(int embedx_dim) { return 3 + embedx_dim; }
// Byte size of a single slot; the argument is unused.
static int DimSize(size_t dim) { return sizeof(float); }
// Byte size of a whole pull value.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }

// Fixed slot offsets.
static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; }
static int Embedx_W_Index() { return 3; }

// Mutable references to the scalar slots of `val`.
static float& Show(float* val) { return val[ShowIndex()]; }
static float& Click(float* val) { return val[ClickIndex()]; }
static float& EmbedW(float* val) { return val[Embed_W_Index()]; }
// Pointer to the first embedx weight inside `val`.
static float* EmbedxW(float* val) { return val + Embedx_W_Index(); }
}; };
CtrCommonAccessor() {} CtrCommonAccessor() {}
virtual int initialize(); virtual int Initialize();
virtual ~CtrCommonAccessor() {} virtual ~CtrCommonAccessor() {}
virtual void SetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key); virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); size_t Dim();
// value各个维度的size // value各个维度的size
virtual size_t dim_size(size_t dim); size_t DimSize(size_t dim);
// value各维度相加总size // value各维度相加总size
virtual size_t size(); size_t Size();
// value中mf动态长度部分总size大小, sparse下生效 // value中mf动态长度部分总size大小, sparse下生效
virtual size_t mf_size(); size_t MFSize();
// pull value维度 // pull value维度
virtual size_t select_dim(); size_t SelectDim();
// pull value各个维度的size // pull value各个维度的size
virtual size_t select_dim_size(size_t dim); size_t SelectDimSize(size_t dim);
// pull value各维度相加总size // pull value各维度相加总size
virtual size_t select_size(); size_t SelectSize();
// push value维度 // push value维度
virtual size_t update_dim(); size_t UpdateDim();
// push value各个维度的size // push value各个维度的size
virtual size_t update_dim_size(size_t dim); size_t UpdateDimSize(size_t dim);
// push value各维度相加总size // push value各维度相加总size
virtual size_t update_size(); size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool shrink(float* value); virtual bool Shrink(float* value);
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
// virtual bool save_ssd(float* value); // virtual bool save_ssd(float* value);
virtual bool need_extend_mf(float* value); virtual bool NeedExtendMF(float* value);
virtual bool has_mf(size_t size); virtual bool HasMF(size_t size);
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
// param = 0, save all feature // param = 0, save all feature
// param = 1, save delta feature // param = 1, save delta feature
// param = 2, save xbox base feature // param = 2, save xbox base feature
bool save(float* value, int param) override; bool Save(float* value, int param) override;
// update delta_score and unseen_days after save // update delta_score and unseen_days after save
void update_stat_after_save(float* value, int param) override; void UpdateStatAfterSave(float* value, int param) override;
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
// 要求value的内存由外部调用者分配完毕 // 要求value的内存由外部调用者分配完毕
virtual int32_t create(float** value, size_t num); virtual int32_t Create(float** value, size_t num);
// 从values中选取到select_values中 // 从values中选取到select_values中
virtual int32_t select(float** select_values, const float** values, virtual int32_t Select(float** select_values, const float** values,
size_t num); size_t num);
// 将update_values聚合到一起 // 将update_values聚合到一起
virtual int32_t merge(float** update_values, virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num); const float** other_update_values, size_t num);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it); // virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
virtual int32_t update(float** values, const float** update_values, virtual int32_t Update(float** values, const float** update_values,
size_t num); size_t num);
std::string parse_to_string(const float* value, int param) override; std::string ParseToString(const float* value, int param) override;
int32_t parse_from_string(const std::string& str, float* v) override; int32_t ParseFromString(const std::string& str, float* v) override;
virtual bool create_value(int type, const float* value); virtual bool CreateValue(int type, const float* value);
// 这个接口目前只用来取show // 这个接口目前只用来取show
float get_field(float* value, const std::string& name) override { float GetField(float* value, const std::string& name) override {
// CHECK(name == "show"); // CHECK(name == "show");
if (name == "show") { if (name == "show") {
return common_feature_value.show(value); return common_feature_value.Show(value);
} }
return 0.0; return 0.0;
} }
......
...@@ -38,36 +38,36 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -38,36 +38,36 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
float embedx_g2sum; float embedx_g2sum;
std::vector<float> embedx_w; std::vector<float> embedx_w;
*/ */
// Number of fields in a stored value: eight fixed fields plus `embedx_dim`
// embedx weights.
static int Dim(int embedx_dim) { return 8 + embedx_dim; }
// Byte size reported for a single field; both arguments are unused.
static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
// Byte size of a whole stored value. show and click are doubles, so each
// takes one extra float-sized slot (hence the +2).
static int Size(int embedx_dim) {
  return (Dim(embedx_dim) + 2) * sizeof(float);
}

// Float offsets of the fields, each derived from the previous one.
static int unseen_days_index() { return 0; }
static int delta_score_index() { return unseen_days_index() + 1; }
static int ShowIndex() { return delta_score_index() + 1; }
// show is double
static int ClickIndex() { return ShowIndex() + 2; }
// click is double
static int Embed_W_Index() { return ClickIndex() + 2; }
static int embed_g2sum_index() { return Embed_W_Index() + 1; }
static int SlotIndex() { return embed_g2sum_index() + 1; }
static int embedx_g2sum_index() { return SlotIndex() + 1; }
static int Embedx_W_Index() { return embedx_g2sum_index() + 1; }
static float& unseen_days(float* val) { static float& unseen_days(float* val) {
...@@ -76,17 +76,17 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -76,17 +76,17 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
static float& delta_score(float* val) { static float& delta_score(float* val) {
return val[DownpourCtrDoubleFeatureValue::delta_score_index()]; return val[DownpourCtrDoubleFeatureValue::delta_score_index()];
} }
static double& show(float* val) { static double& Show(float* val) {
return ((double*)(val + DownpourCtrDoubleFeatureValue::show_index()))[0]; return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0];
} }
static double& click(float* val) { static double& Click(float* val) {
return ((double*)(val + DownpourCtrDoubleFeatureValue::click_index()))[0]; return ((double*)(val + DownpourCtrDoubleFeatureValue::ClickIndex()))[0];
} }
static float& slot(float* val) { static float& Slot(float* val) {
return val[DownpourCtrDoubleFeatureValue::slot_index()]; return val[DownpourCtrDoubleFeatureValue::SlotIndex()];
} }
static float& embed_w(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrDoubleFeatureValue::embed_w_index()]; return val[DownpourCtrDoubleFeatureValue::Embed_W_Index()];
} }
static float& embed_g2sum(float* val) { static float& embed_g2sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::embed_g2sum_index()]; return val[DownpourCtrDoubleFeatureValue::embed_g2sum_index()];
...@@ -94,8 +94,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -94,8 +94,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
static float& embedx_g2sum(float* val) { static float& embedx_g2sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::embedx_g2sum_index()]; return val[DownpourCtrDoubleFeatureValue::embedx_g2sum_index()];
} }
static float* embedx_w(float* val) { static float* EmbedxW(float* val) {
return (val + DownpourCtrDoubleFeatureValue::embedx_w_index()); return (val + DownpourCtrDoubleFeatureValue::Embedx_W_Index());
} }
}; };
struct DownpourCtrDoublePushValue { struct DownpourCtrDoublePushValue {
...@@ -106,36 +106,36 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -106,36 +106,36 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
float embed_g; float embed_g;
std::vector<float> embedx_g; std::vector<float> embedx_g;
*/ */
// Push-value layout helpers. Offsets: slot=0, show=1, click=2, embed_g=3,
// embedx_g starts at 4.

// Number of floats in a push value holding `embedx_dim` embedx gradients.
static int Dim(int embedx_dim) { return 4 + embedx_dim; }
// Byte size of a single slot; both arguments are unused.
static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
// Byte size of a whole push value.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }

// Slot offsets, each derived from the previous one.
static int SlotIndex() { return 0; }
static int ShowIndex() { return SlotIndex() + 1; }
static int ClickIndex() { return ShowIndex() + 1; }
static int Embed_G_Index() { return ClickIndex() + 1; }
static int Embedx_G_Index() { return Embed_G_Index() + 1; }

// Mutable references to the scalar slots of `val`.
static float& Slot(float* val) { return val[SlotIndex()]; }
static float& Show(float* val) { return val[ShowIndex()]; }
static float& Click(float* val) { return val[ClickIndex()]; }
static float& EmbedG(float* val) { return val[Embed_G_Index()]; }
// Pointer to the first embedx gradient inside `val`.
static float* EmbedxG(float* val) { return val + Embedx_G_Index(); }
}; };
struct DownpourCtrDoublePullValue { struct DownpourCtrDoublePullValue {
...@@ -145,88 +145,88 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -145,88 +145,88 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
float embed_w; float embed_w;
std::vector<float> embedx_w; std::vector<float> embedx_w;
*/ */
// Pull-value layout helpers. Offsets: show=0, click=1, embed_w=2, embedx_w
// starts at 3.

// Number of floats in a pull value holding `embedx_dim` embedx weights.
static int Dim(int embedx_dim) { return 3 + embedx_dim; }
// Byte size of a single slot; the argument is unused.
static int DimSize(size_t dim) { return sizeof(float); }
// Byte size of a whole pull value.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }

// Fixed slot offsets.
static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; }
static int Embedx_W_Index() { return 3; }

// Mutable references to the scalar slots of `val`.
static float& Show(float* val) { return val[ShowIndex()]; }
static float& Click(float* val) { return val[ClickIndex()]; }
static float& EmbedW(float* val) { return val[Embed_W_Index()]; }
// Pointer to the first embedx weight inside `val`.
static float* EmbedxW(float* val) { return val + Embedx_W_Index(); }
}; };
DownpourCtrDoubleAccessor() {} DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {}
virtual int initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key); virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); size_t Dim();
// value各个维度的size // value各个维度的size
virtual size_t dim_size(size_t dim); size_t DimSize(size_t dim);
// value各维度相加总size // value各维度相加总size
virtual size_t size(); size_t Size();
// value中mf动态长度部分总size大小, sparse下生效 // value中mf动态长度部分总size大小, sparse下生效
virtual size_t mf_size(); size_t MFSize();
// pull value维度 // pull value维度
virtual size_t select_dim(); size_t SelectDim();
// pull value各个维度的size // pull value各个维度的size
virtual size_t select_dim_size(size_t dim); size_t SelectDimSize(size_t dim);
// pull value各维度相加总size // pull value各维度相加总size
virtual size_t select_size(); size_t SelectSize();
// push value维度 // push value维度
virtual size_t update_dim(); size_t UpdateDim();
// push value各个维度的size // push value各个维度的size
virtual size_t update_dim_size(size_t dim); size_t UpdateDimSize(size_t dim);
// push value各维度相加总size // push value各维度相加总size
virtual size_t update_size(); size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool shrink(float* value); virtual bool Shrink(float* value);
virtual bool need_extend_mf(float* value); virtual bool NeedExtendMF(float* value);
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
// param = 0, save all feature // param = 0, save all feature
// param = 1, save delta feature // param = 1, save delta feature
// param = 3, save all feature with time decay // param = 3, save all feature with time decay
virtual bool save(float* value, int param) override; virtual bool Save(float* value, int param) override;
// update delta_score and unseen_days after save // update delta_score and unseen_days after save
virtual void update_stat_after_save(float* value, int param) override; virtual void UpdateStatAfterSave(float* value, int param) override;
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
virtual bool save_ssd(float* value); virtual bool save_ssd(float* value);
// virtual bool save_cache(float* value, int param, double // virtual bool save_cache(float* value, int param, double
// global_cache_threshold) override; // global_cache_threshold) override;
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
// 要求value的内存由外部调用者分配完毕 // 要求value的内存由外部调用者分配完毕
virtual int32_t create(float** value, size_t num); virtual int32_t Create(float** value, size_t num);
// 从values中选取到select_values中 // 从values中选取到select_values中
virtual int32_t select(float** select_values, const float** values, virtual int32_t Select(float** select_values, const float** values,
size_t num); size_t num);
// 将update_values聚合到一起 // 将update_values聚合到一起
virtual int32_t merge(float** update_values, virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num); const float** other_update_values, size_t num);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it); // virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
virtual int32_t update(float** values, const float** update_values, virtual int32_t Update(float** values, const float** update_values,
size_t num); size_t num);
virtual std::string parse_to_string(const float* value, int param) override; virtual std::string ParseToString(const float* value, int param) override;
virtual int32_t parse_from_string(const std::string& str, float* v) override; virtual int32_t ParseFromString(const std::string& str, float* v) override;
virtual bool create_value(int type, const float* value); virtual bool CreateValue(int type, const float* value);
//这个接口目前只用来取show //这个接口目前只用来取show
virtual float get_field(float* value, const std::string& name) override { virtual float GetField(float* value, const std::string& name) override {
CHECK(name == "show"); CHECK(name == "show");
if (name == "show") { if (name == "show") {
return (float)DownpourCtrDoubleFeatureValue::show(value); return (float)DownpourCtrDoubleFeatureValue::Show(value);
} }
return 0.0; return 0.0;
} }
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int DownpourCtrAccessor::initialize() { int DownpourCtrAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);
...@@ -38,86 +38,77 @@ int DownpourCtrAccessor::initialize() { ...@@ -38,86 +38,77 @@ int DownpourCtrAccessor::initialize() {
} }
// Fills `info` with this accessor's value, pull (select) and push (update)
// dimensions and byte sizes, plus the mf size.
void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) {
  // Stored value.
  info.dim = Dim();
  info.size = Size();
  // Pull value.
  info.select_dim = SelectDim();
  info.select_size = SelectSize();
  // Push value.
  info.update_dim = UpdateDim();
  info.update_size = UpdateSize();
  info.mf_size = MFSize();
}
// Returns the table metric selected by `key`. Keys not listed below yield 0.
size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) {
  switch (key) {
    case DIM:
      return Dim();
    case SIZE:
      return Size();
    case SELECT_DIM:
      return SelectDim();
    case SELECT_SIZE:
      return SelectSize();
    case UPDATE_DIM:
      return UpdateDim();
    case UPDATE_SIZE:
      return UpdateSize();
    case MF_SIZE:
      return MFSize();
    default:
      break;
  }
  return 0;
}
size_t DownpourCtrAccessor::dim() { size_t DownpourCtrAccessor::Dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim(embedx_dim); return DownpourCtrFeatureValue::Dim(embedx_dim);
} }
size_t DownpourCtrAccessor::dim_size(size_t dim) { size_t DownpourCtrAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim_size(dim, embedx_dim); return DownpourCtrFeatureValue::DimSize(dim, embedx_dim);
} }
size_t DownpourCtrAccessor::size() { size_t DownpourCtrAccessor::Size() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::size(embedx_dim); return DownpourCtrFeatureValue::Size(embedx_dim);
} }
size_t DownpourCtrAccessor::mf_size() { size_t DownpourCtrAccessor::MFSize() {
return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum
} }
// pull value // pull value
size_t DownpourCtrAccessor::select_dim() { size_t DownpourCtrAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim; return 3 + embedx_dim;
} }
// Size of one pull-value dimension; `dim` is unused.
size_t DownpourCtrAccessor::SelectDimSize(size_t dim) {
  return sizeof(float);
}
size_t DownpourCtrAccessor::select_size() { size_t DownpourCtrAccessor::SelectSize() { return SelectDim() * sizeof(float); }
return select_dim() * sizeof(float);
}
// push value // push value
size_t DownpourCtrAccessor::update_dim() { size_t DownpourCtrAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim; return 4 + embedx_dim;
} }
// Size of one push-value dimension; `dim` is unused.
size_t DownpourCtrAccessor::UpdateDimSize(size_t dim) {
  return sizeof(float);
}
size_t DownpourCtrAccessor::update_size() { size_t DownpourCtrAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
return update_dim() * sizeof(float);
}
bool DownpourCtrAccessor::shrink(float* value) { bool DownpourCtrAccessor::Shrink(float* value) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
// auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); // auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
...@@ -134,9 +125,9 @@ bool DownpourCtrAccessor::shrink(float* value) { ...@@ -134,9 +125,9 @@ bool DownpourCtrAccessor::shrink(float* value) {
return true; return true;
} }
auto show_right = auto show_right =
DownpourCtrFeatureValue::show(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
auto click_right = auto click_right =
DownpourCtrFeatureValue::click(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
// shrink after // shrink after
auto score = show_click_score(show_right, click_right); auto score = show_click_score(show_right, click_right);
...@@ -175,15 +166,15 @@ bool DownpourCtrAccessor::save_ssd(float* value) { ...@@ -175,15 +166,15 @@ bool DownpourCtrAccessor::save_ssd(float* value) {
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); // auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
// int16_t day_diff = _day_id - unseen_days; // int16_t day_diff = _day_id - unseen_days;
// if (show_click_score(DownpourCtrFeatureValue::show(value), // if (show_click_score(DownpourCtrFeatureValue::Show(value),
// DownpourCtrFeatureValue::click(value)) >= base_threshold // DownpourCtrFeatureValue::Click(value)) >= base_threshold
// && day_diff <= delta_keep_days) { // && day_diff <= delta_keep_days) {
// return DownpourCtrFeatureValue::show(value) > global_cache_threshold; // return DownpourCtrFeatureValue::Show(value) > global_cache_threshold;
// } // }
// return false; // return false;
// } // }
bool DownpourCtrAccessor::save(float* value, int param) { bool DownpourCtrAccessor::Save(float* value, int param) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -206,9 +197,9 @@ bool DownpourCtrAccessor::save(float* value, int param) { ...@@ -206,9 +197,9 @@ bool DownpourCtrAccessor::save(float* value, int param) {
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
auto show_right = auto show_right =
DownpourCtrFeatureValue::show(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
auto click_right = auto click_right =
DownpourCtrFeatureValue::click(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
if (show_click_score(show_right, click_right) >= base_threshold && if (show_click_score(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::delta_score(value) >= delta_threshold && DownpourCtrFeatureValue::delta_score(value) >= delta_threshold &&
...@@ -224,8 +215,8 @@ bool DownpourCtrAccessor::save(float* value, int param) { ...@@ -224,8 +215,8 @@ bool DownpourCtrAccessor::save(float* value, int param) {
} }
// already decayed in shrink // already decayed in shrink
case 3: { case 3: {
// DownpourCtrFeatureValue::show(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate;
// DownpourCtrFeatureValue::click(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate;
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
// DownpourCtrFeatureValue::unseen_days(value)++; // DownpourCtrFeatureValue::unseen_days(value)++;
return true; return true;
...@@ -235,7 +226,7 @@ bool DownpourCtrAccessor::save(float* value, int param) { ...@@ -235,7 +226,7 @@ bool DownpourCtrAccessor::save(float* value, int param) {
}; };
} }
void DownpourCtrAccessor::update_stat_after_save(float* value, int param) { void DownpourCtrAccessor::UpdateStatAfterSave(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -247,9 +238,9 @@ void DownpourCtrAccessor::update_stat_after_save(float* value, int param) { ...@@ -247,9 +238,9 @@ void DownpourCtrAccessor::update_stat_after_save(float* value, int param) {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
auto show_right = auto show_right =
DownpourCtrFeatureValue::show(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
auto click_right = auto click_right =
DownpourCtrFeatureValue::click(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
if (show_click_score(show_right, click_right) >= base_threshold && if (show_click_score(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::delta_score(value) >= delta_threshold && DownpourCtrFeatureValue::delta_score(value) >= delta_threshold &&
...@@ -268,28 +259,28 @@ void DownpourCtrAccessor::update_stat_after_save(float* value, int param) { ...@@ -268,28 +259,28 @@ void DownpourCtrAccessor::update_stat_after_save(float* value, int param) {
}; };
} }
int32_t DownpourCtrAccessor::create(float** values, size_t num) { int32_t DownpourCtrAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[DownpourCtrFeatureValue::unseen_days_index()] = 0; value[DownpourCtrFeatureValue::unseen_days_index()] = 0;
value[DownpourCtrFeatureValue::delta_score_index()] = 0; value[DownpourCtrFeatureValue::delta_score_index()] = 0;
value[DownpourCtrFeatureValue::show_index()] = 0; value[DownpourCtrFeatureValue::ShowIndex()] = 0;
value[DownpourCtrFeatureValue::click_index()] = 0; value[DownpourCtrFeatureValue::ClickIndex()] = 0;
value[DownpourCtrFeatureValue::slot_index()] = -1; value[DownpourCtrFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->init_value(
value + DownpourCtrFeatureValue::embed_w_index(), value + DownpourCtrFeatureValue::Embed_W_Index(),
value + DownpourCtrFeatureValue::embed_g2sum_index(), true); value + DownpourCtrFeatureValue::embed_g2sum_index(), true);
_embedx_sgd_rule->init_value( _embedx_sgd_rule->init_value(
value + DownpourCtrFeatureValue::embedx_w_index(), value + DownpourCtrFeatureValue::Embedx_W_Index(),
value + DownpourCtrFeatureValue::embedx_g2sum_index()); value + DownpourCtrFeatureValue::embedx_g2sum_index());
} }
return 0; return 0;
} }
bool DownpourCtrAccessor::need_extend_mf(float* value) { bool DownpourCtrAccessor::NeedExtendMF(float* value) {
float show = value[DownpourCtrFeatureValue::show_index()]; float show = value[DownpourCtrFeatureValue::ShowIndex()];
float click = value[DownpourCtrFeatureValue::click_index()]; float click = value[DownpourCtrFeatureValue::ClickIndex()];
// float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() // float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff()
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() + float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff(); click * _config.ctr_accessor_param().click_coeff();
...@@ -297,25 +288,25 @@ bool DownpourCtrAccessor::need_extend_mf(float* value) { ...@@ -297,25 +288,25 @@ bool DownpourCtrAccessor::need_extend_mf(float* value) {
return score >= _config.embedx_threshold(); return score >= _config.embedx_threshold();
} }
bool DownpourCtrAccessor::has_mf(size_t size) { bool DownpourCtrAccessor::HasMF(size_t size) {
return size > DownpourCtrFeatureValue::embedx_g2sum_index(); return size > DownpourCtrFeatureValue::embedx_g2sum_index();
} }
// from DownpourCtrFeatureValue to DownpourCtrPullValue // from DownpourCtrFeatureValue to DownpourCtrPullValue
int32_t DownpourCtrAccessor::select(float** select_values, const float** values, int32_t DownpourCtrAccessor::Select(float** select_values, const float** values,
size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item]; float* select_value = select_values[value_item];
float* value = const_cast<float*>(values[value_item]); float* value = const_cast<float*>(values[value_item]);
select_value[DownpourCtrPullValue::show_index()] = select_value[DownpourCtrPullValue::ShowIndex()] =
value[DownpourCtrFeatureValue::show_index()]; value[DownpourCtrFeatureValue::ShowIndex()];
select_value[DownpourCtrPullValue::click_index()] = select_value[DownpourCtrPullValue::ClickIndex()] =
value[DownpourCtrFeatureValue::click_index()]; value[DownpourCtrFeatureValue::ClickIndex()];
select_value[DownpourCtrPullValue::embed_w_index()] = select_value[DownpourCtrPullValue::Embed_W_Index()] =
value[DownpourCtrFeatureValue::embed_w_index()]; value[DownpourCtrFeatureValue::Embed_W_Index()];
memcpy(select_value + DownpourCtrPullValue::embedx_w_index(), memcpy(select_value + DownpourCtrPullValue::Embedx_W_Index(),
value + DownpourCtrFeatureValue::embedx_w_index(), value + DownpourCtrFeatureValue::Embedx_W_Index(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -324,16 +315,16 @@ int32_t DownpourCtrAccessor::select(float** select_values, const float** values, ...@@ -324,16 +315,16 @@ int32_t DownpourCtrAccessor::select(float** select_values, const float** values,
// from DownpourCtrPushValue to DownpourCtrPushValue // from DownpourCtrPushValue to DownpourCtrPushValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t DownpourCtrAccessor::merge(float** update_values, int32_t DownpourCtrAccessor::Merge(float** update_values,
const float** other_update_values, const float** other_update_values,
size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
size_t total_dim = DownpourCtrPushValue::dim(embedx_dim); size_t total_dim = DownpourCtrPushValue::Dim(embedx_dim);
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* other_update_value = other_update_values[value_item]; const float* other_update_value = other_update_values[value_item];
for (auto i = 0u; i < total_dim; ++i) { for (auto i = 0u; i < total_dim; ++i) {
if (i != DownpourCtrPushValue::slot_index()) { if (i != DownpourCtrPushValue::SlotIndex()) {
update_value[i] += other_update_value[i]; update_value[i] += other_update_value[i];
} }
} }
...@@ -344,18 +335,18 @@ int32_t DownpourCtrAccessor::merge(float** update_values, ...@@ -344,18 +335,18 @@ int32_t DownpourCtrAccessor::merge(float** update_values,
// from DownpourCtrPushValue to DownpourCtrFeatureValue // from DownpourCtrPushValue to DownpourCtrFeatureValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t DownpourCtrAccessor::update(float** update_values, int32_t DownpourCtrAccessor::Update(float** update_values,
const float** push_values, size_t num) { const float** push_values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* push_value = push_values[value_item]; const float* push_value = push_values[value_item];
float push_show = push_value[DownpourCtrPushValue::show_index()]; float push_show = push_value[DownpourCtrPushValue::ShowIndex()];
float push_click = push_value[DownpourCtrPushValue::click_index()]; float push_click = push_value[DownpourCtrPushValue::ClickIndex()];
float slot = push_value[DownpourCtrPushValue::slot_index()]; float slot = push_value[DownpourCtrPushValue::SlotIndex()];
update_value[DownpourCtrFeatureValue::show_index()] += push_show; update_value[DownpourCtrFeatureValue::ShowIndex()] += push_show;
update_value[DownpourCtrFeatureValue::click_index()] += push_click; update_value[DownpourCtrFeatureValue::ClickIndex()] += push_click;
update_value[DownpourCtrFeatureValue::slot_index()] = slot; update_value[DownpourCtrFeatureValue::SlotIndex()] = slot;
update_value[DownpourCtrFeatureValue::delta_score_index()] += update_value[DownpourCtrFeatureValue::delta_score_index()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
...@@ -363,25 +354,25 @@ int32_t DownpourCtrAccessor::update(float** update_values, ...@@ -363,25 +354,25 @@ int32_t DownpourCtrAccessor::update(float** update_values,
// push_click * _config.ctr_accessor_param().click_coeff(); // push_click * _config.ctr_accessor_param().click_coeff();
update_value[DownpourCtrFeatureValue::unseen_days_index()] = 0; update_value[DownpourCtrFeatureValue::unseen_days_index()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->update_value(
update_value + DownpourCtrFeatureValue::embed_w_index(), update_value + DownpourCtrFeatureValue::Embed_W_Index(),
update_value + DownpourCtrFeatureValue::embed_g2sum_index(), update_value + DownpourCtrFeatureValue::embed_g2sum_index(),
push_value + DownpourCtrPushValue::embed_g_index(), push_show); push_value + DownpourCtrPushValue::Embed_G_Index(), push_show);
_embedx_sgd_rule->update_value( _embedx_sgd_rule->update_value(
update_value + DownpourCtrFeatureValue::embedx_w_index(), update_value + DownpourCtrFeatureValue::Embedx_W_Index(),
update_value + DownpourCtrFeatureValue::embedx_g2sum_index(), update_value + DownpourCtrFeatureValue::embedx_g2sum_index(),
push_value + DownpourCtrPushValue::embedx_g_index(), push_show); push_value + DownpourCtrPushValue::Embedx_G_Index(), push_show);
} }
return 0; return 0;
} }
bool DownpourCtrAccessor::create_value(int stage, const float* value) { bool DownpourCtrAccessor::CreateValue(int stage, const float* value) {
// stage == 0, pull // stage == 0, pull
// stage == 1, push // stage == 1, push
if (stage == 0) { if (stage == 0) {
return true; return true;
} else if (stage == 1) { } else if (stage == 1) {
auto show = DownpourCtrPushValue::show(const_cast<float*>(value)); auto show = DownpourCtrPushValue::Show(const_cast<float*>(value));
auto click = DownpourCtrPushValue::click(const_cast<float*>(value)); auto click = DownpourCtrPushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
...@@ -404,15 +395,14 @@ float DownpourCtrAccessor::show_click_score(float show, float click) { ...@@ -404,15 +395,14 @@ float DownpourCtrAccessor::show_click_score(float show, float click) {
return (show - click) * nonclk_coeff + click * click_coeff; return (show - click) * nonclk_coeff + click * click_coeff;
} }
std::string DownpourCtrAccessor::parse_to_string(const float* v, std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) {
int param_size) {
thread_local std::ostringstream os; thread_local std::ostringstream os;
os.clear(); os.clear();
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5] << " " << v[6]; << v[5] << " " << v[6];
auto show = DownpourCtrFeatureValue::show(const_cast<float*>(v)); auto show = DownpourCtrFeatureValue::Show(const_cast<float*>(v));
auto click = DownpourCtrFeatureValue::click(const_cast<float*>(v)); auto click = DownpourCtrFeatureValue::Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score >= _config.embedx_threshold() && param_size > 7) { if (score >= _config.embedx_threshold() && param_size > 7) {
os << " " << v[7]; os << " " << v[7];
...@@ -423,22 +413,21 @@ std::string DownpourCtrAccessor::parse_to_string(const float* v, ...@@ -423,22 +413,21 @@ std::string DownpourCtrAccessor::parse_to_string(const float* v,
return os.str(); return os.str();
} }
int DownpourCtrAccessor::parse_from_string(const std::string& str, int DownpourCtrAccessor::ParseFromString(const std::string& str, float* value) {
float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
float data_buff[dim()]; float data_buff[Dim()];
float* data_buff_ptr = data_buff; float* data_buff_ptr = data_buff;
_embedx_sgd_rule->init_value( _embedx_sgd_rule->init_value(
data_buff_ptr + DownpourCtrFeatureValue::embedx_w_index(), data_buff_ptr + DownpourCtrFeatureValue::Embedx_W_Index(),
data_buff_ptr + DownpourCtrFeatureValue::embedx_g2sum_index()); data_buff_ptr + DownpourCtrFeatureValue::embedx_g2sum_index());
auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr); auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
CHECK(str_len >= 6) << "expect more than 6 real:" << str_len; CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
// no slot, embedx // no slot, embedx
int value_dim = dim(); int value_dim = Dim();
int embedx_g2sum_index = DownpourCtrFeatureValue::embedx_g2sum_index(); int embedx_g2sum_index = DownpourCtrFeatureValue::embedx_g2sum_index();
value[DownpourCtrFeatureValue::slot_index()] = -1; value[DownpourCtrFeatureValue::SlotIndex()] = -1;
// other case // other case
if (str_len == (value_dim - 1)) { if (str_len == (value_dim - 1)) {
memcpy(value, data_buff_ptr, (embedx_g2sum_index - 1) * sizeof(float)); memcpy(value, data_buff_ptr, (embedx_g2sum_index - 1) * sizeof(float));
...@@ -494,8 +483,8 @@ void DownpourCtrAccessor::update_time_decay(float* value, ...@@ -494,8 +483,8 @@ void DownpourCtrAccessor::update_time_decay(float* value,
if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) { if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) {
return; return;
} }
DownpourCtrFeatureValue::show(value) *= _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) *= _time_decay_rates[day_diff];
DownpourCtrFeatureValue::click(value) *= _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) *= _time_decay_rates[day_diff];
if (is_update_seen_day) { if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id; DownpourCtrFeatureValue::unseen_days(value) = _day_id;
} }
......
...@@ -42,32 +42,30 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -42,32 +42,30 @@ class DownpourCtrAccessor : public ValueAccessor {
std::vector<float> embedx_w; std::vector<float> embedx_w;
*/ */
static int dim(int embedx_dim) { return 8 + embedx_dim; } static int Dim(int embedx_dim) { return 8 + embedx_dim; }
static int dim_size(size_t dim, int embedx_dim) { return sizeof(float); } static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int unseen_days_index() { return 0; } static int unseen_days_index() { return 0; }
static int delta_score_index() { static int delta_score_index() {
return DownpourCtrFeatureValue::unseen_days_index() + 1; return DownpourCtrFeatureValue::unseen_days_index() + 1;
} }
static int show_index() { static int ShowIndex() {
return DownpourCtrFeatureValue::delta_score_index() + 1; return DownpourCtrFeatureValue::delta_score_index() + 1;
} }
static int click_index() { static int ClickIndex() { return DownpourCtrFeatureValue::ShowIndex() + 1; }
return DownpourCtrFeatureValue::show_index() + 1; static int Embed_W_Index() {
} return DownpourCtrFeatureValue::ClickIndex() + 1;
static int embed_w_index() {
return DownpourCtrFeatureValue::click_index() + 1;
} }
static int embed_g2sum_index() { static int embed_g2sum_index() {
return DownpourCtrFeatureValue::embed_w_index() + 1; return DownpourCtrFeatureValue::Embed_W_Index() + 1;
} }
static int slot_index() { static int SlotIndex() {
return DownpourCtrFeatureValue::embed_g2sum_index() + 1; return DownpourCtrFeatureValue::embed_g2sum_index() + 1;
} }
static int embedx_g2sum_index() { static int embedx_g2sum_index() {
return DownpourCtrFeatureValue::slot_index() + 1; return DownpourCtrFeatureValue::SlotIndex() + 1;
} }
static int embedx_w_index() { static int Embedx_W_Index() {
return DownpourCtrFeatureValue::embedx_g2sum_index() + 1; return DownpourCtrFeatureValue::embedx_g2sum_index() + 1;
} }
static float& unseen_days(float* val) { static float& unseen_days(float* val) {
...@@ -76,17 +74,17 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -76,17 +74,17 @@ class DownpourCtrAccessor : public ValueAccessor {
static float& delta_score(float* val) { static float& delta_score(float* val) {
return val[DownpourCtrFeatureValue::delta_score_index()]; return val[DownpourCtrFeatureValue::delta_score_index()];
} }
static float& show(float* val) { static float& Show(float* val) {
return val[DownpourCtrFeatureValue::show_index()]; return val[DownpourCtrFeatureValue::ShowIndex()];
} }
static float& click(float* val) { static float& Click(float* val) {
return val[DownpourCtrFeatureValue::click_index()]; return val[DownpourCtrFeatureValue::ClickIndex()];
} }
static float& slot(float* val) { static float& Slot(float* val) {
return val[DownpourCtrFeatureValue::slot_index()]; return val[DownpourCtrFeatureValue::SlotIndex()];
} }
static float& embed_w(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrFeatureValue::embed_w_index()]; return val[DownpourCtrFeatureValue::Embed_W_Index()];
} }
static float& embed_g2sum(float* val) { static float& embed_g2sum(float* val) {
return val[DownpourCtrFeatureValue::embed_g2sum_index()]; return val[DownpourCtrFeatureValue::embed_g2sum_index()];
...@@ -94,8 +92,8 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -94,8 +92,8 @@ class DownpourCtrAccessor : public ValueAccessor {
static float& embedx_g2sum(float* val) { static float& embedx_g2sum(float* val) {
return val[DownpourCtrFeatureValue::embedx_g2sum_index()]; return val[DownpourCtrFeatureValue::embedx_g2sum_index()];
} }
static float* embedx_w(float* val) { static float* EmbedxW(float* val) {
return (val + DownpourCtrFeatureValue::embedx_w_index()); return (val + DownpourCtrFeatureValue::Embedx_W_Index());
} }
}; };
...@@ -108,24 +106,24 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -108,24 +106,24 @@ class DownpourCtrAccessor : public ValueAccessor {
std::vector<float> embedx_g; std::vector<float> embedx_g;
*/ */
static int dim(int embedx_dim) { return 4 + embedx_dim; } static int Dim(int embedx_dim) { return 4 + embedx_dim; }
static int dim_size(int dim, int embedx_dim) { return sizeof(float); } static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int slot_index() { return 0; } static int SlotIndex() { return 0; }
static int show_index() { return DownpourCtrPushValue::slot_index() + 1; } static int ShowIndex() { return DownpourCtrPushValue::SlotIndex() + 1; }
static int click_index() { return DownpourCtrPushValue::show_index() + 1; } static int ClickIndex() { return DownpourCtrPushValue::ShowIndex() + 1; }
static int embed_g_index() { static int Embed_G_Index() {
return DownpourCtrPushValue::click_index() + 1; return DownpourCtrPushValue::ClickIndex() + 1;
} }
static int embedx_g_index() { static int Embedx_G_Index() {
return DownpourCtrPushValue::embed_g_index() + 1; return DownpourCtrPushValue::Embed_G_Index() + 1;
} }
static float& slot(float* val) { return val[0]; } static float& Slot(float* val) { return val[0]; }
static float& show(float* val) { return val[1]; } static float& Show(float* val) { return val[1]; }
static float& click(float* val) { return val[2]; } static float& Click(float* val) { return val[2]; }
static float& embed_g(float* val) { return val[3]; } static float& EmbedG(float* val) { return val[3]; }
static float* embedx_g(float* val) { return val + 4; } static float* EmbedxG(float* val) { return val + 4; }
}; };
struct DownpourCtrPullValue { struct DownpourCtrPullValue {
...@@ -136,95 +134,95 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -136,95 +134,95 @@ class DownpourCtrAccessor : public ValueAccessor {
std::vector<float> embedx_w; std::vector<float> embedx_w;
*/ */
static int dim(int embedx_dim) { return 3 + embedx_dim; } static int Dim(int embedx_dim) { return 3 + embedx_dim; }
static int dim_size(size_t dim) { return sizeof(float); } static int DimSize(size_t dim) { return sizeof(float); }
static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int show_index() { return 0; } static int ShowIndex() { return 0; }
static int click_index() { return 1; } static int ClickIndex() { return 1; }
static int embed_w_index() { return 2; } static int Embed_W_Index() { return 2; }
static int embedx_w_index() { return 3; } static int Embedx_W_Index() { return 3; }
static float& show(float* val) { static float& Show(float* val) {
return val[DownpourCtrPullValue::show_index()]; return val[DownpourCtrPullValue::ShowIndex()];
} }
static float& click(float* val) { static float& Click(float* val) {
return val[DownpourCtrPullValue::click_index()]; return val[DownpourCtrPullValue::ClickIndex()];
} }
static float& embed_w(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrPullValue::embed_w_index()]; return val[DownpourCtrPullValue::Embed_W_Index()];
} }
static float* embedx_w(float* val) { static float* EmbedxW(float* val) {
return val + DownpourCtrPullValue::embedx_w_index(); return val + DownpourCtrPullValue::Embedx_W_Index();
} }
}; };
DownpourCtrAccessor() {} DownpourCtrAccessor() {}
virtual ~DownpourCtrAccessor() {} virtual ~DownpourCtrAccessor() {}
virtual int initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key); virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); size_t Dim();
// value各个维度的size // value各个维度的size
virtual size_t dim_size(size_t dim); size_t DimSize(size_t dim);
// value各维度相加总size // value各维度相加总size
virtual size_t size(); size_t Size();
// value中mf动态长度部分总size大小, sparse下生效 // value中mf动态长度部分总size大小, sparse下生效
virtual size_t mf_size(); size_t MFSize();
// pull value维度 // pull value维度
virtual size_t select_dim(); size_t SelectDim();
// pull value各个维度的size // pull value各个维度的size
virtual size_t select_dim_size(size_t dim); size_t SelectDimSize(size_t dim);
// pull value各维度相加总size // pull value各维度相加总size
virtual size_t select_size(); size_t SelectSize();
// push value维度 // push value维度
virtual size_t update_dim(); size_t UpdateDim();
// push value各个维度的size // push value各个维度的size
virtual size_t update_dim_size(size_t dim); size_t UpdateDimSize(size_t dim);
// push value各维度相加总size // push value各维度相加总size
virtual size_t update_size(); size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool shrink(float* value); virtual bool Shrink(float* value);
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
virtual bool save_ssd(float* value); virtual bool save_ssd(float* value);
virtual bool need_extend_mf(float* value); virtual bool NeedExtendMF(float* value);
virtual bool has_mf(size_t size); virtual bool HasMF(size_t size);
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
// param = 0, save all feature // param = 0, save all feature
// param = 1, save delta feature // param = 1, save delta feature
// param = 3, save all feature with time decay // param = 3, save all feature with time decay
virtual bool save(float* value, int param) override; virtual bool Save(float* value, int param) override;
// update delta_score and unseen_days after save // update delta_score and unseen_days after save
virtual void update_stat_after_save(float* value, int param) override; virtual void UpdateStatAfterSave(float* value, int param) override;
// virtual bool save_cache(float* value, int param, double // virtual bool save_cache(float* value, int param, double
// global_cache_threshold) override; // global_cache_threshold) override;
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
// 要求value的内存由外部调用者分配完毕 // 要求value的内存由外部调用者分配完毕
virtual int32_t create(float** value, size_t num); virtual int32_t Create(float** value, size_t num);
// 从values中选取到select_values中 // 从values中选取到select_values中
virtual int32_t select(float** select_values, const float** values, virtual int32_t Select(float** select_values, const float** values,
size_t num); size_t num);
// 将update_values聚合到一起 // 将update_values聚合到一起
virtual int32_t merge(float** update_values, virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num); const float** other_update_values, size_t num);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it); // virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
virtual int32_t update(float** values, const float** update_values, virtual int32_t Update(float** values, const float** update_values,
size_t num); size_t num);
virtual std::string parse_to_string(const float* value, int param) override; virtual std::string ParseToString(const float* value, int param) override;
virtual int32_t parse_from_string(const std::string& str, float* v) override; virtual int32_t ParseFromString(const std::string& str, float* v) override;
virtual bool create_value(int type, const float* value); virtual bool CreateValue(int type, const float* value);
//这个接口目前只用来取show //这个接口目前只用来取show
virtual float get_field(float* value, const std::string& name) override { virtual float GetField(float* value, const std::string& name) override {
CHECK(name == "show"); CHECK(name == "show");
if (name == "show") { if (name == "show") {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
auto show_right = auto show_right =
DownpourCtrFeatureValue::show(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
return (float)show_right; return (float)show_right;
} }
return 0.0; return 0.0;
......
...@@ -99,9 +99,9 @@ int32_t MemorySparseTable::load(const std::string& path, ...@@ -99,9 +99,9 @@ int32_t MemorySparseTable::load(const std::string& path,
channel_config.path = file_list[file_start_idx + i]; channel_config.path = file_list[file_start_idx + i];
VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
<< " into local shard " << i; << " into local shard " << i;
channel_config.converter = _value_accesor->converter(load_param).converter; channel_config.converter = _value_accesor->Converter(load_param).converter;
channel_config.deconverter = channel_config.deconverter =
_value_accesor->converter(load_param).deconverter; _value_accesor->Converter(load_param).deconverter;
bool is_read_failed = false; bool is_read_failed = false;
int retry_num = 0; int retry_num = 0;
...@@ -119,8 +119,7 @@ int32_t MemorySparseTable::load(const std::string& path, ...@@ -119,8 +119,7 @@ int32_t MemorySparseTable::load(const std::string& path,
uint64_t key = std::strtoul(line_data.data(), &end, 10); uint64_t key = std::strtoul(line_data.data(), &end, 10);
auto& value = shard[key]; auto& value = shard[key];
value.resize(feature_value_size); value.resize(feature_value_size);
int parse_size = int parse_size = _value_accesor->ParseFromString(++end, value.data());
_value_accesor->parse_from_string(++end, value.data());
value.resize(parse_size); value.resize(parse_size);
// for debug // for debug
...@@ -196,8 +195,7 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path, ...@@ -196,8 +195,7 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
uint64_t key = std::strtoul(line_data.data(), &end, 10); uint64_t key = std::strtoul(line_data.data(), &end, 10);
auto& value = shard[key]; auto& value = shard[key];
value.resize(feature_value_size); value.resize(feature_value_size);
int parse_size = int parse_size = _value_accesor->ParseFromString(++end, value.data());
_value_accesor->parse_from_string(++end, value.data());
value.resize(parse_size); value.resize(parse_size);
} }
file.close(); file.close();
...@@ -253,9 +251,9 @@ int32_t MemorySparseTable::save(const std::string& dirname, ...@@ -253,9 +251,9 @@ int32_t MemorySparseTable::save(const std::string& dirname,
paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(), paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(),
_shard_idx, file_start_idx + i); _shard_idx, file_start_idx + i);
} }
channel_config.converter = _value_accesor->converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter = channel_config.deconverter =
_value_accesor->converter(save_param).deconverter; _value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false; bool is_write_failed = false;
int feasign_size = 0; int feasign_size = 0;
int retry_num = 0; int retry_num = 0;
...@@ -268,8 +266,8 @@ int32_t MemorySparseTable::save(const std::string& dirname, ...@@ -268,8 +266,8 @@ int32_t MemorySparseTable::save(const std::string& dirname,
auto write_channel = auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto it = shard.begin(); it != shard.end(); ++it) { for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->parse_to_string( std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size()); it.value().data(), it.value().size());
if (0 != if (0 !=
write_channel->write_line(paddle::string::format_string( write_channel->write_line(paddle::string::format_string(
...@@ -302,7 +300,7 @@ int32_t MemorySparseTable::save(const std::string& dirname, ...@@ -302,7 +300,7 @@ int32_t MemorySparseTable::save(const std::string& dirname,
} while (is_write_failed); } while (is_write_failed);
feasign_size_all += feasign_size; feasign_size_all += feasign_size;
for (auto it = shard.begin(); it != shard.end(); ++it) { for (auto it = shard.begin(); it != shard.end(); ++it) {
_value_accesor->update_stat_after_save(it.value().data(), save_param); _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
} }
LOG(INFO) << "MemorySparseTable save prefix success, path: " LOG(INFO) << "MemorySparseTable save prefix success, path: "
<< channel_config.path; << channel_config.path;
...@@ -334,9 +332,9 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname, ...@@ -334,9 +332,9 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
std::ofstream os; std::ofstream os;
os.open(file_name); os.open(file_name);
for (auto it = shard.begin(); it != shard.end(); ++it) { for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->parse_to_string( std::string format_value =
it.value().data(), it.value().size()); _value_accesor->ParseToString(it.value().data(), it.value().size());
std::string out_line = paddle::string::format_string( std::string out_line = paddle::string::format_string(
"%lu %s\n", it.key(), format_value.c_str()); "%lu %s\n", it.key(), format_value.c_str());
// VLOG(2) << out_line.c_str(); // VLOG(2) << out_line.c_str();
...@@ -370,7 +368,7 @@ int64_t MemorySparseTable::local_mf_size() { ...@@ -370,7 +368,7 @@ int64_t MemorySparseTable::local_mf_size() {
auto& local_shard = _local_shards[shard_id]; auto& local_shard = _local_shards[shard_id];
for (auto it = local_shard.begin(); it != local_shard.end(); for (auto it = local_shard.begin(); it != local_shard.end();
++it) { ++it) {
if (_value_accesor->has_mf(it.value().size())) { if (_value_accesor->HasMF(it.value().size())) {
size_arr[shard_id] += 1; size_arr[shard_id] += 1;
} }
} }
...@@ -453,7 +451,7 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, ...@@ -453,7 +451,7 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
auto& feature_value = local_shard[key]; auto& feature_value = local_shard[key];
feature_value.resize(data_size); feature_value.resize(data_size);
float* data_ptr = feature_value.data(); float* data_ptr = feature_value.data();
_value_accesor->create(&data_buffer_ptr, 1); _value_accesor->Create(&data_buffer_ptr, 1);
memcpy(data_ptr, data_buffer_ptr, memcpy(data_ptr, data_buffer_ptr,
data_size * sizeof(float)); data_size * sizeof(float));
} }
...@@ -467,7 +465,7 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, ...@@ -467,7 +465,7 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
} }
auto offset = keys[i].second; auto offset = keys[i].second;
float* select_data = pull_values + select_value_size * offset; float* select_data = pull_values + select_value_size * offset;
_value_accesor->select(&select_data, _value_accesor->Select(&select_data,
(const float**)&data_buffer_ptr, 1); (const float**)&data_buffer_ptr, 1);
} }
...@@ -484,8 +482,8 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, ...@@ -484,8 +482,8 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
const uint64_t* keys, size_t num) { const uint64_t* keys, size_t num) {
CostTimer timer("pscore_sparse_select_all"); CostTimer timer("pscore_sparse_select_all");
size_t value_size = _value_accesor->size() / sizeof(float); size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_size = _value_accesor->mf_size() / sizeof(float); size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
...@@ -514,7 +512,7 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, ...@@ -514,7 +512,7 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
auto& feature_value = local_shard[key]; auto& feature_value = local_shard[key];
feature_value.resize(data_size); feature_value.resize(data_size);
float* data_ptr = feature_value.data(); float* data_ptr = feature_value.data();
_value_accesor->create(&data_buffer_ptr, 1); _value_accesor->Create(&data_buffer_ptr, 1);
memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float)); memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
ret = &feature_value; ret = &feature_value;
} else { } else {
...@@ -564,13 +562,13 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, ...@@ -564,13 +562,13 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
auto itr = local_shard.find(key); auto itr = local_shard.find(key);
if (itr == local_shard.end()) { if (itr == local_shard.end()) {
if (FLAGS_pserver_enable_create_feasign_randomly && if (FLAGS_pserver_enable_create_feasign_randomly &&
!_value_accesor->create_value(1, update_data)) { !_value_accesor->CreateValue(1, update_data)) {
continue; continue;
} }
auto value_size = value_col - mf_value_col; auto value_size = value_col - mf_value_col;
auto& feature_value = local_shard[key]; auto& feature_value = local_shard[key];
feature_value.resize(value_size); feature_value.resize(value_size);
_value_accesor->create(&data_buffer_ptr, 1); _value_accesor->Create(&data_buffer_ptr, 1);
memcpy(feature_value.data(), data_buffer_ptr, memcpy(feature_value.data(), data_buffer_ptr,
value_size * sizeof(float)); value_size * sizeof(float));
itr = local_shard.find(key); itr = local_shard.find(key);
...@@ -581,16 +579,16 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, ...@@ -581,16 +579,16 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
size_t value_size = feature_value.size(); size_t value_size = feature_value.size();
if (value_size == value_col) { // 已拓展到最大size, 则就地update if (value_size == value_col) { // 已拓展到最大size, 则就地update
_value_accesor->update(&value_data, &update_data, 1); _value_accesor->Update(&value_data, &update_data, 1);
} else { } else {
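// Copy into the buffer area and update there, then write back; MF data that is not needed is dropped on write-back // Copy into the buffer area and update there, then write back; MF data that is not needed is dropped on write-back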
memcpy(data_buffer_ptr, value_data, value_size * sizeof(float)); memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
_value_accesor->update(&data_buffer_ptr, &update_data, 1); _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
if (_value_accesor->need_extend_mf(data_buffer)) { if (_value_accesor->NeedExtendMF(data_buffer)) {
feature_value.resize(value_col); feature_value.resize(value_col);
value_data = feature_value.data(); value_data = feature_value.data();
_value_accesor->create(&value_data, 1); _value_accesor->Create(&value_data, 1);
} }
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float)); memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
} }
...@@ -641,13 +639,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, ...@@ -641,13 +639,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
auto itr = local_shard.find(key); auto itr = local_shard.find(key);
if (itr == local_shard.end()) { if (itr == local_shard.end()) {
if (FLAGS_pserver_enable_create_feasign_randomly && if (FLAGS_pserver_enable_create_feasign_randomly &&
!_value_accesor->create_value(1, update_data)) { !_value_accesor->CreateValue(1, update_data)) {
continue; continue;
} }
auto value_size = value_col - mf_value_col; auto value_size = value_col - mf_value_col;
auto& feature_value = local_shard[key]; auto& feature_value = local_shard[key];
feature_value.resize(value_size); feature_value.resize(value_size);
_value_accesor->create(&data_buffer_ptr, 1); _value_accesor->Create(&data_buffer_ptr, 1);
memcpy(feature_value.data(), data_buffer_ptr, memcpy(feature_value.data(), data_buffer_ptr,
value_size * sizeof(float)); value_size * sizeof(float));
itr = local_shard.find(key); itr = local_shard.find(key);
...@@ -656,15 +654,15 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, ...@@ -656,15 +654,15 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
float* value_data = feature_value.data(); float* value_data = feature_value.data();
size_t value_size = feature_value.size(); size_t value_size = feature_value.size();
if (value_size == value_col) { // 已拓展到最大size, 则就地update if (value_size == value_col) { // 已拓展到最大size, 则就地update
_value_accesor->update(&value_data, &update_data, 1); _value_accesor->Update(&value_data, &update_data, 1);
} else { } else {
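// Copy into the buffer area and update there, then write back; MF data that is not needed is dropped on write-back // Copy into the buffer area and update there, then write back; MF data that is not needed is dropped on write-back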
memcpy(data_buffer_ptr, value_data, value_size * sizeof(float)); memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
_value_accesor->update(&data_buffer_ptr, &update_data, 1); _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
if (_value_accesor->need_extend_mf(data_buffer)) { if (_value_accesor->NeedExtendMF(data_buffer)) {
feature_value.resize(value_col); feature_value.resize(value_col);
value_data = feature_value.data(); value_data = feature_value.data();
_value_accesor->create(&value_data, 1); _value_accesor->Create(&value_data, 1);
} }
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float)); memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
} }
...@@ -688,7 +686,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) { ...@@ -688,7 +686,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) {
// shrink // shrink
auto& shard = _local_shards[shard_id]; auto& shard = _local_shards[shard_id];
for (auto it = shard.begin(); it != shard.end();) { for (auto it = shard.begin(); it != shard.end();) {
if (_value_accesor->shrink(it.value().data())) { if (_value_accesor->Shrink(it.value().data())) {
it = shard.erase(it); it = shard.erase(it);
} else { } else {
++it; ++it;
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int SparseAccessor::initialize() { int SparseAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);
...@@ -39,73 +39,72 @@ int SparseAccessor::initialize() { ...@@ -39,73 +39,72 @@ int SparseAccessor::initialize() {
} }
// Fills `info` with this accessor's dimension and byte-size figures, one
// field per getter call.
void SparseAccessor::SetTableInfo(AccessorInfo& info) { void SparseAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim(); info.dim = Dim();
info.size = size(); info.size = Size();
info.select_dim = select_dim(); info.select_dim = SelectDim();
info.select_size = select_size(); info.select_size = SelectSize();
info.update_dim = update_dim(); info.update_dim = UpdateDim();
info.update_size = update_size(); info.update_size = UpdateSize();
info.mf_size = mf_size(); info.mf_size = MFSize();
// Only the pre-rename (left) text carries the fea_dim assignment; the
// renamed (right) text has no counterpart for this line.
info.fea_dim = fea_dim();
} }
// Returns the single figure selected by `key` by calling the matching getter.
// Falls through to `return 0` when no case returns.
size_t SparseAccessor::GetTableInfo(InfoKey key) { size_t SparseAccessor::GetTableInfo(InfoKey key) {
switch (key) { switch (key) {
case DIM: case DIM:
return dim(); return Dim();
case SIZE: case SIZE:
return size(); return Size();
case SELECT_DIM: case SELECT_DIM:
return select_dim(); return SelectDim();
case SELECT_SIZE: case SELECT_SIZE:
return select_size(); return SelectSize();
case UPDATE_DIM: case UPDATE_DIM:
return update_dim(); return UpdateDim();
case UPDATE_SIZE: case UPDATE_SIZE:
return update_size(); return UpdateSize();
case MF_SIZE: case MF_SIZE:
return mf_size(); return MFSize();
// The pre-rename (left) text answers FEA_DIM via fea_dim(); the renamed
// (right) text has a `default` branch returning 0 in its place.
case FEA_DIM: default:
return fea_dim(); return 0;
} }
return 0; return 0;
} }
// Forwards to sparse_feature_value's own dimension getter.
size_t SparseAccessor::dim() { return sparse_feature_value.dim(); } size_t SparseAccessor::Dim() { return sparse_feature_value.Dim(); }
// Forwards `dim`, together with the configured embedx dimension, to
// sparse_feature_value's per-field size getter.
size_t SparseAccessor::dim_size(size_t dim) { size_t SparseAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return sparse_feature_value.dim_size(dim, embedx_dim); return sparse_feature_value.DimSize(dim, embedx_dim);
} }
// Forwards to sparse_feature_value's own size getter.
size_t SparseAccessor::size() { return sparse_feature_value.size(); } size_t SparseAccessor::Size() { return sparse_feature_value.Size(); }
// Returns (configured embedx_dim + sparse_feature_value.embedx_sgd_dim)
// multiplied by sizeof(float).
size_t SparseAccessor::mf_size() { size_t SparseAccessor::MFSize() {
return (_config.embedx_dim() + sparse_feature_value.embedx_sgd_dim) * return (_config.embedx_dim() + sparse_feature_value.embedx_sgd_dim) *
sizeof(float); // embedx embedx_g2sum sizeof(float); // embedx embedx_g2sum
} }
// pull value // pull value
// Returns 1 + the configured embedx dimension.
size_t SparseAccessor::select_dim() { size_t SparseAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 1 + embedx_dim; return 1 + embedx_dim;
} }
// Returns sizeof(float) regardless of `dim`; the argument is not read.
size_t SparseAccessor::select_dim_size(size_t dim) { return sizeof(float); } size_t SparseAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
// Returns the select dimension multiplied by sizeof(float).
size_t SparseAccessor::select_size() { return select_dim() * sizeof(float); } size_t SparseAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value // push value
// Returns 4 + the configured embedx dimension.
size_t SparseAccessor::update_dim() { size_t SparseAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim; return 4 + embedx_dim;
} }
// Returns sizeof(float) regardless of `dim`; the argument is not read.
size_t SparseAccessor::update_dim_size(size_t dim) { return sizeof(float); } size_t SparseAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
// Returns the update dimension multiplied by sizeof(float).
size_t SparseAccessor::update_size() { return update_dim() * sizeof(float); } size_t SparseAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool SparseAccessor::shrink(float* value) { bool SparseAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delete_after_unseen_days = auto delete_after_unseen_days =
...@@ -113,12 +112,12 @@ bool SparseAccessor::shrink(float* value) { ...@@ -113,12 +112,12 @@ bool SparseAccessor::shrink(float* value) {
auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first // time_decay first
sparse_feature_value.show(value) *= _show_click_decay_rate; sparse_feature_value.Show(value) *= _show_click_decay_rate;
sparse_feature_value.click(value) *= _show_click_decay_rate; sparse_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after // shrink after
auto score = show_click_score(sparse_feature_value.show(value), auto score = show_click_score(sparse_feature_value.Show(value),
sparse_feature_value.click(value)); sparse_feature_value.Click(value));
auto unseen_days = sparse_feature_value.unseen_days(value); auto unseen_days = sparse_feature_value.unseen_days(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) { if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true; return true;
...@@ -126,7 +125,7 @@ bool SparseAccessor::shrink(float* value) { ...@@ -126,7 +125,7 @@ bool SparseAccessor::shrink(float* value) {
return false; return false;
} }
bool SparseAccessor::save(float* value, int param) { bool SparseAccessor::Save(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -142,8 +141,8 @@ bool SparseAccessor::save(float* value, int param) { ...@@ -142,8 +141,8 @@ bool SparseAccessor::save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
if (show_click_score(sparse_feature_value.show(value), if (show_click_score(sparse_feature_value.Show(value),
sparse_feature_value.click(value)) >= sparse_feature_value.Click(value)) >=
base_threshold && base_threshold &&
sparse_feature_value.delta_score(value) >= delta_threshold && sparse_feature_value.delta_score(value) >= delta_threshold &&
sparse_feature_value.unseen_days(value) <= delta_keep_days) { sparse_feature_value.unseen_days(value) <= delta_keep_days) {
...@@ -171,7 +170,7 @@ bool SparseAccessor::save(float* value, int param) { ...@@ -171,7 +170,7 @@ bool SparseAccessor::save(float* value, int param) {
} }
} }
void SparseAccessor::update_stat_after_save(float* value, int param) { void SparseAccessor::UpdateStatAfterSave(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
...@@ -180,8 +179,8 @@ void SparseAccessor::update_stat_after_save(float* value, int param) { ...@@ -180,8 +179,8 @@ void SparseAccessor::update_stat_after_save(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
if (show_click_score(sparse_feature_value.show(value), if (show_click_score(sparse_feature_value.Show(value),
sparse_feature_value.click(value)) >= sparse_feature_value.Click(value)) >=
base_threshold && base_threshold &&
sparse_feature_value.delta_score(value) >= delta_threshold && sparse_feature_value.delta_score(value) >= delta_threshold &&
sparse_feature_value.unseen_days(value) <= delta_keep_days) { sparse_feature_value.unseen_days(value) <= delta_keep_days) {
...@@ -198,48 +197,48 @@ void SparseAccessor::update_stat_after_save(float* value, int param) { ...@@ -198,48 +197,48 @@ void SparseAccessor::update_stat_after_save(float* value, int param) {
} }
} }
// Initialises `num` feature values in place; `values[i]` points at the float
// buffer of the i-th value. For each value the unseen-days, delta-score, show
// and click fields are set to 0 and the slot field to -1, then the embed and
// embedx SGD rules initialise their weight / g2sum fields. Always returns 0.
int32_t SparseAccessor::create(float** values, size_t num) { int32_t SparseAccessor::Create(float** values, size_t num) {
// NOTE(review): embedx_dim is read here but never used in this body.
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[sparse_feature_value.unseen_days_index()] = 0; value[sparse_feature_value.unseen_days_index()] = 0;
value[sparse_feature_value.delta_score_index()] = 0; value[sparse_feature_value.delta_score_index()] = 0;
value[sparse_feature_value.show_index()] = 0; value[sparse_feature_value.ShowIndex()] = 0;
value[sparse_feature_value.click_index()] = 0; value[sparse_feature_value.ClickIndex()] = 0;
value[sparse_feature_value.slot_index()] = -1; value[sparse_feature_value.SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->init_value(
value + sparse_feature_value.embed_w_index(), value + sparse_feature_value.Embed_W_Index(),
value + sparse_feature_value.embed_g2sum_index()); value + sparse_feature_value.embed_g2sum_index());
// The embedx rule is given an extra `false` argument that the embed rule
// call above omits; its meaning is defined by init_value (not visible
// here) - TODO confirm.
_embedx_sgd_rule->init_value( _embedx_sgd_rule->init_value(
value + sparse_feature_value.embedx_w_index(), value + sparse_feature_value.Embedx_W_Index(),
value + sparse_feature_value.embedx_g2sum_index(), false); value + sparse_feature_value.embedx_g2sum_index(), false);
} }
return 0; return 0;
} }
// Reads the show and click fields of `value`, computes
//   score = (show - click) * nonclk_coeff + click * click_coeff
// from the ctr_accessor_param coefficients, and returns whether the score
// reaches the configured embedx_threshold.
bool SparseAccessor::need_extend_mf(float* value) { bool SparseAccessor::NeedExtendMF(float* value) {
float show = value[sparse_feature_value.show_index()]; float show = value[sparse_feature_value.ShowIndex()];
float click = value[sparse_feature_value.click_index()]; float click = value[sparse_feature_value.ClickIndex()];
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() + float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff(); click * _config.ctr_accessor_param().click_coeff();
return score >= _config.embedx_threshold(); return score >= _config.embedx_threshold();
} }
// Returns true when `size` is strictly greater than
// sparse_feature_value.embedx_g2sum_index().
// NOTE(review): this compares a size_t against that getter's return value;
// presumably `size` is a count of floats rather than bytes - confirm with
// callers.
bool SparseAccessor::has_mf(size_t size) { bool SparseAccessor::HasMF(size_t size) {
return size > sparse_feature_value.embedx_g2sum_index(); return size > sparse_feature_value.embedx_g2sum_index();
} }
// from SparseFeatureValue to SparsePullValue // from SparseFeatureValue to SparsePullValue
int32_t SparseAccessor::select(float** select_values, const float** values, int32_t SparseAccessor::Select(float** select_values, const float** values,
size_t num) { size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item]; float* select_value = select_values[value_item];
const float* value = values[value_item]; const float* value = values[value_item];
select_value[SparsePullValue::embed_w_index()] = select_value[SparsePullValue::Embed_W_Index()] =
value[sparse_feature_value.embed_w_index()]; value[sparse_feature_value.Embed_W_Index()];
memcpy(select_value + SparsePullValue::embedx_w_index(), memcpy(select_value + SparsePullValue::Embedx_W_Index(),
value + sparse_feature_value.embedx_w_index(), value + sparse_feature_value.Embedx_W_Index(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -248,15 +247,15 @@ int32_t SparseAccessor::select(float** select_values, const float** values, ...@@ -248,15 +247,15 @@ int32_t SparseAccessor::select(float** select_values, const float** values,
// from SparsePushValue to SparsePushValue // from SparsePushValue to SparsePushValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
int32_t SparseAccessor::merge(float** update_values, int32_t SparseAccessor::Merge(float** update_values,
const float** other_update_values, size_t num) { const float** other_update_values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
size_t total_dim = SparsePushValue::dim(embedx_dim); size_t total_dim = SparsePushValue::Dim(embedx_dim);
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* other_update_value = other_update_values[value_item]; const float* other_update_value = other_update_values[value_item];
for (auto i = 0u; i < total_dim; ++i) { for (auto i = 0u; i < total_dim; ++i) {
if (i != SparsePushValue::slot_index()) { if (i != SparsePushValue::SlotIndex()) {
update_value[i] += other_update_value[i]; update_value[i] += other_update_value[i];
} }
} }
...@@ -267,43 +266,43 @@ int32_t SparseAccessor::merge(float** update_values, ...@@ -267,43 +266,43 @@ int32_t SparseAccessor::merge(float** update_values,
// from SparsePushValue to SparseFeatureValue // from SparsePushValue to SparseFeatureValue
// first dim: item // first dim: item
// second dim: field num // second dim: field num
// Applies `num` pushed updates to stored feature values in place.
// `update_values[i]` is the stored value to modify and `push_values[i]` the
// pushed record for it. For each pair: the pushed show and click are added to
// the stored counters, the stored slot is overwritten, delta_score is
// increased by the coefficient-weighted show/click score, unseen_days is reset
// to 0, and the embed / embedx SGD rules are run with the pushed gradients.
// Always returns 0.
int32_t SparseAccessor::update(float** update_values, const float** push_values, int32_t SparseAccessor::Update(float** update_values, const float** push_values,
size_t num) { size_t num) {
// NOTE(review): embedx_dim is read here but never used in this body.
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item]; float* update_value = update_values[value_item];
const float* push_value = push_values[value_item]; const float* push_value = push_values[value_item];
float push_show = push_value[SparsePushValue::show_index()]; float push_show = push_value[SparsePushValue::ShowIndex()];
float push_click = push_value[SparsePushValue::click_index()]; float push_click = push_value[SparsePushValue::ClickIndex()];
float slot = push_value[SparsePushValue::slot_index()]; float slot = push_value[SparsePushValue::SlotIndex()];
update_value[sparse_feature_value.show_index()] += push_show; update_value[sparse_feature_value.ShowIndex()] += push_show;
update_value[sparse_feature_value.click_index()] += push_click; update_value[sparse_feature_value.ClickIndex()] += push_click;
update_value[sparse_feature_value.slot_index()] = slot; update_value[sparse_feature_value.SlotIndex()] = slot;
// Same weighting as the show/click score: non-click shows times
// nonclk_coeff plus clicks times click_coeff.
update_value[sparse_feature_value.delta_score_index()] += update_value[sparse_feature_value.delta_score_index()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
update_value[sparse_feature_value.unseen_days_index()] = 0; update_value[sparse_feature_value.unseen_days_index()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->update_value(
update_value + sparse_feature_value.embed_w_index(), update_value + sparse_feature_value.Embed_W_Index(),
update_value + sparse_feature_value.embed_g2sum_index(), update_value + sparse_feature_value.embed_g2sum_index(),
push_value + SparsePushValue::embed_g_index()); push_value + SparsePushValue::Embed_G_Index());
_embedx_sgd_rule->update_value( _embedx_sgd_rule->update_value(
update_value + sparse_feature_value.embedx_w_index(), update_value + sparse_feature_value.Embedx_W_Index(),
update_value + sparse_feature_value.embedx_g2sum_index(), update_value + sparse_feature_value.embedx_g2sum_index(),
push_value + SparsePushValue::embedx_g_index()); push_value + SparsePushValue::Embedx_G_Index());
} }
return 0; return 0;
} }
bool SparseAccessor::create_value(int stage, const float* value) { bool SparseAccessor::CreateValue(int stage, const float* value) {
// stage == 0, pull // stage == 0, pull
// stage == 1, push // stage == 1, push
if (stage == 0) { if (stage == 0) {
return true; return true;
} else if (stage == 1) { } else if (stage == 1) {
// operation // operation
auto show = SparsePushValue::show(const_cast<float*>(value)); auto show = SparsePushValue::Show(const_cast<float*>(value));
auto click = SparsePushValue::click(const_cast<float*>(value)); auto click = SparsePushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
...@@ -324,34 +323,34 @@ float SparseAccessor::show_click_score(float show, float click) { ...@@ -324,34 +323,34 @@ float SparseAccessor::show_click_score(float show, float click) {
return (show - click) * nonclk_coeff + click * click_coeff; return (show - click) * nonclk_coeff + click * click_coeff;
} }
// Serialises the feature value `v` as space-separated text and returns it.
// Output order: v[0]..v[5], then every index from embed_g2sum_index() up to
// (not including) the embedx weight index, then - only when the show/click
// score reaches embedx_threshold AND `param` is greater than the embedx weight
// index - every index from the embedx weight index up to the feature dim.
// NOTE(review): `param` is presumably the number of floats actually stored in
// `v` - confirm with callers.
std::string SparseAccessor::parse_to_string(const float* v, int param) { std::string SparseAccessor::ParseToString(const float* v, int param) {
// One stream per thread, reused across calls; clear() resets the state flags
// and str("") empties the buffer before writing.
thread_local std::ostringstream os; thread_local std::ostringstream os;
os.clear(); os.clear();
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5]; << v[5];
for (int i = sparse_feature_value.embed_g2sum_index(); for (int i = sparse_feature_value.embed_g2sum_index();
i < sparse_feature_value.embedx_w_index(); i++) { i < sparse_feature_value.Embedx_W_Index(); i++) {
os << " " << v[i]; os << " " << v[i];
} }
// const_cast is needed because the show/click accessors take a non-const
// float*; the values are only read here.
auto show = sparse_feature_value.show(const_cast<float*>(v)); auto show = sparse_feature_value.Show(const_cast<float*>(v));
auto click = sparse_feature_value.click(const_cast<float*>(v)); auto click = sparse_feature_value.Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score >= _config.embedx_threshold() && if (score >= _config.embedx_threshold() &&
param > sparse_feature_value.embedx_w_index()) { param > sparse_feature_value.Embedx_W_Index()) {
for (auto i = sparse_feature_value.embedx_w_index(); for (auto i = sparse_feature_value.Embedx_W_Index();
i < sparse_feature_value.dim(); ++i) { i < sparse_feature_value.Dim(); ++i) {
os << " " << v[i]; os << " " << v[i];
} }
} }
return os.str(); return os.str();
} }
int SparseAccessor::parse_from_string(const std::string& str, float* value) { int SparseAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
_embedx_sgd_rule->init_value( _embedx_sgd_rule->init_value(
value + sparse_feature_value.embedx_w_index(), value + sparse_feature_value.Embedx_W_Index(),
value + sparse_feature_value.embedx_g2sum_index()); value + sparse_feature_value.embedx_g2sum_index());
auto ret = paddle::string::str_to_float(str.data(), value); auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret; CHECK(ret >= 6) << "expect more than 6 real:" << ret;
......
...@@ -40,27 +40,27 @@ class SparseAccessor : public ValueAccessor { ...@@ -40,27 +40,27 @@ class SparseAccessor : public ValueAccessor {
std::<vector>float embedx_g2sum; std::<vector>float embedx_g2sum;
*/ */
// Total number of fields: 6 fixed fields plus embed_sgd_dim, embedx_sgd_dim
// and embedx_dim.
int dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } int Dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
// Size of one field: sizeof(float), independent of both arguments.
int dim_size(size_t dim, int embedx_dim) { return sizeof(float); } int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
// Total size: field count multiplied by sizeof(float).
int size() { return dim() * sizeof(float); } int Size() { return Dim() * sizeof(float); }
// Field offsets within one value, each defined relative to the previous one:
// slot = 0, unseen_days = 1, delta_score = 2, show = 3, click = 4,
// embed_w = 5, embed_g2sum = 6, embedx_w = 6 + embed_sgd_dim,
// embedx_g2sum = embedx_w + embedx_dim.
int slot_index() { return 0; } int SlotIndex() { return 0; }
int unseen_days_index() { return slot_index() + 1; } int unseen_days_index() { return SlotIndex() + 1; }
int delta_score_index() { return unseen_days_index() + 1; } int delta_score_index() { return unseen_days_index() + 1; }
int show_index() { return delta_score_index() + 1; } int ShowIndex() { return delta_score_index() + 1; }
int click_index() { return show_index() + 1; } int ClickIndex() { return ShowIndex() + 1; }
int embed_w_index() { return click_index() + 1; } int Embed_W_Index() { return ClickIndex() + 1; }
int embed_g2sum_index() { return embed_w_index() + 1; } int embed_g2sum_index() { return Embed_W_Index() + 1; }
int embedx_w_index() { return embed_g2sum_index() + embed_sgd_dim; } int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; }
int embedx_g2sum_index() { return embedx_w_index() + embedx_dim; } int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; }
float& unseen_days(float* val) { return val[unseen_days_index()]; } float& unseen_days(float* val) { return val[unseen_days_index()]; }
float& delta_score(float* val) { return val[delta_score_index()]; } float& delta_score(float* val) { return val[delta_score_index()]; }
float& show(float* val) { return val[show_index()]; } float& Show(float* val) { return val[ShowIndex()]; }
float& click(float* val) { return val[click_index()]; } float& Click(float* val) { return val[ClickIndex()]; }
float& slot(float* val) { return val[slot_index()]; } float& Slot(float* val) { return val[SlotIndex()]; }
float& embed_w(float* val) { return val[embed_w_index()]; } float& EmbedW(float* val) { return val[Embed_W_Index()]; }
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
float& embedx_w(float* val) { return val[embedx_w_index()]; } float& EmbedxW(float* val) { return val[Embedx_W_Index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
int embed_sgd_dim; int embed_sgd_dim;
...@@ -77,29 +77,25 @@ class SparseAccessor : public ValueAccessor { ...@@ -77,29 +77,25 @@ class SparseAccessor : public ValueAccessor {
std::vector<float> embedx_g; std::vector<float> embedx_g;
*/ */
static int dim(int embedx_dim) { return 4 + embedx_dim; } static int Dim(int embedx_dim) { return 4 + embedx_dim; }
static int dim_size(int dim, int embedx_dim) { return sizeof(float); } static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int slot_index() { return 0; } static int SlotIndex() { return 0; }
static int show_index() { return SparsePushValue::slot_index() + 1; } static int ShowIndex() { return SparsePushValue::SlotIndex() + 1; }
static int click_index() { return SparsePushValue::show_index() + 1; } static int ClickIndex() { return SparsePushValue::ShowIndex() + 1; }
static int embed_g_index() { return SparsePushValue::click_index() + 1; } static int Embed_G_Index() { return SparsePushValue::ClickIndex() + 1; }
static int embedx_g_index() { return SparsePushValue::embed_g_index() + 1; } static int Embedx_G_Index() { return SparsePushValue::Embed_G_Index() + 1; }
static float& slot(float* val) { static float& Slot(float* val) { return val[SparsePushValue::SlotIndex()]; }
return val[SparsePushValue::slot_index()]; static float& Show(float* val) { return val[SparsePushValue::ShowIndex()]; }
static float& Click(float* val) {
return val[SparsePushValue::ClickIndex()];
} }
static float& show(float* val) { static float& EmbedG(float* val) {
return val[SparsePushValue::show_index()]; return val[SparsePushValue::Embed_G_Index()];
} }
static float& click(float* val) { static float* EmbedxG(float* val) {
return val[SparsePushValue::click_index()]; return val + SparsePushValue::Embedx_G_Index();
}
static float& embed_g(float* val) {
return val[SparsePushValue::embed_g_index()];
}
static float* embedx_g(float* val) {
return val + SparsePushValue::embedx_g_index();
} }
}; };
...@@ -109,82 +105,82 @@ class SparseAccessor : public ValueAccessor { ...@@ -109,82 +105,82 @@ class SparseAccessor : public ValueAccessor {
std::vector<float> embedx_w; std::vector<float> embedx_w;
*/ */
static int dim(int embedx_dim) { return 1 + embedx_dim; } static int Dim(int embedx_dim) { return 1 + embedx_dim; }
static int dim_size(size_t dim) { return sizeof(float); } static int DimSize(size_t dim) { return sizeof(float); }
static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int embed_w_index() { return 0; } static int Embed_W_Index() { return 0; }
static int embedx_w_index() { return 1; } static int Embedx_W_Index() { return 1; }
static float& embed_w(float* val) { static float& EmbedW(float* val) {
return val[SparsePullValue::embed_w_index()]; return val[SparsePullValue::Embed_W_Index()];
} }
static float* embedx_w(float* val) { static float* EmbedxW(float* val) {
return val + SparsePullValue::embedx_w_index(); return val + SparsePullValue::Embedx_W_Index();
} }
}; };
SparseAccessor() {} SparseAccessor() {}
virtual int initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key); virtual size_t GetTableInfo(InfoKey key);
virtual ~SparseAccessor() {} virtual ~SparseAccessor() {}
// value维度 // value维度
virtual size_t dim(); size_t Dim();
// value各个维度的size // value各个维度的size
virtual size_t dim_size(size_t dim); size_t DimSize(size_t dim);
// value各维度相加总size // value各维度相加总size
virtual size_t size(); size_t Size();
// value中mf动态长度部分总size大小, sparse下生效 // value中mf动态长度部分总size大小, sparse下生效
virtual size_t mf_size(); size_t MFSize();
// pull value维度 // pull value维度
virtual size_t select_dim(); size_t SelectDim();
// pull value各个维度的size // pull value各个维度的size
virtual size_t select_dim_size(size_t dim); size_t SelectDimSize(size_t dim);
// pull value各维度相加总size // pull value各维度相加总size
virtual size_t select_size(); size_t SelectSize();
// push value维度 // push value维度
virtual size_t update_dim(); size_t UpdateDim();
// push value各个维度的size // push value各个维度的size
virtual size_t update_dim_size(size_t dim); size_t UpdateDimSize(size_t dim);
// push value各维度相加总size // push value各维度相加总size
virtual size_t update_size(); size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool shrink(float* value); virtual bool Shrink(float* value);
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
// virtual bool save_ssd(float* value); // virtual bool save_ssd(float* value);
virtual bool need_extend_mf(float* value); virtual bool NeedExtendMF(float* value);
virtual bool has_mf(size_t size); virtual bool HasMF(size_t size);
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
// param = 0, save all feature // param = 0, save all feature
// param = 1, save delta feature // param = 1, save delta feature
// param = 2, save xbox base feature // param = 2, save xbox base feature
bool save(float* value, int param) override; bool Save(float* value, int param) override;
// update delta_score and unseen_days after save // update delta_score and unseen_days after save
void update_stat_after_save(float* value, int param) override; void UpdateStatAfterSave(float* value, int param) override;
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
// 要求value的内存由外部调用者分配完毕 // 要求value的内存由外部调用者分配完毕
virtual int32_t create(float** value, size_t num); virtual int32_t Create(float** value, size_t num);
// 从values中选取到select_values中 // 从values中选取到select_values中
virtual int32_t select(float** select_values, const float** values, virtual int32_t Select(float** select_values, const float** values,
size_t num); size_t num);
// 将update_values聚合到一起 // 将update_values聚合到一起
virtual int32_t merge(float** update_values, virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num); const float** other_update_values, size_t num);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it); // virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
virtual int32_t update(float** values, const float** update_values, virtual int32_t Update(float** values, const float** update_values,
size_t num); size_t num);
std::string parse_to_string(const float* value, int param) override; std::string ParseToString(const float* value, int param) override;
int32_t parse_from_string(const std::string& str, float* v) override; int32_t ParseFromString(const std::string& str, float* v) override;
virtual bool create_value(int type, const float* value); virtual bool CreateValue(int type, const float* value);
// 这个接口目前只用来取show // 这个接口目前只用来取show
float get_field(float* value, const std::string& name) override { float GetField(float* value, const std::string& name) override {
// CHECK(name == "show"); // CHECK(name == "show");
if (name == "show") { if (name == "show") {
return sparse_feature_value.show(value); return sparse_feature_value.Show(value);
} }
return 0.0; return 0.0;
} }
......
...@@ -97,7 +97,7 @@ int32_t Table::initialize_accessor() { ...@@ -97,7 +97,7 @@ int32_t Table::initialize_accessor() {
<< ", accessor_name:" << _config.accessor().accessor_class(); << ", accessor_name:" << _config.accessor().accessor_class();
return -1; return -1;
} }
if (accessor->configure(_config.accessor()) || accessor->initialize() != 0) { if (accessor->Configure(_config.accessor()) || accessor->Initialize() != 0) {
LOG(ERROR) << " accessor initialize failed, table_id:" << _config.table_id() LOG(ERROR) << " accessor initialize failed, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class(); << ", accessor_name:" << _config.accessor().accessor_class();
return -1; return -1;
......
...@@ -18,86 +18,70 @@ ...@@ -18,86 +18,70 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int CommMergeAccessor::initialize() { return 0; } int CommMergeAccessor::Initialize() { return 0; }
void CommMergeAccessor::SetTableInfo(AccessorInfo &info) { void CommMergeAccessor::SetTableInfo(AccessorInfo &info) {
info.dim = dim(); info.select_dim = SelectDim();
info.size = size(); info.select_size = SelectSize();
info.select_dim = select_dim(); info.update_dim = UpdateDim();
info.select_size = select_size(); info.update_size = UpdateSize();
info.update_dim = update_dim();
info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim(); info.fea_dim = fea_dim();
} }
size_t CommMergeAccessor::GetTableInfo(InfoKey key) { size_t CommMergeAccessor::GetTableInfo(InfoKey key) {
switch (key) { switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM: case SELECT_DIM:
return select_dim(); return SelectDim();
case SELECT_SIZE: case SELECT_SIZE:
return select_size(); return SelectSize();
case UPDATE_DIM: case UPDATE_DIM:
return update_dim(); return UpdateDim();
case UPDATE_SIZE: case UPDATE_SIZE:
return update_size(); return UpdateSize();
case MF_SIZE:
return mf_size();
case FEA_DIM: case FEA_DIM:
return fea_dim(); return fea_dim();
default:
return 0;
} }
return 0; return 0;
} }
// value 维度
size_t CommMergeAccessor::dim() { return 0; }
// value 各个维度的size
size_t CommMergeAccessor::dim_size(size_t dim) { return 0; }
// value 各维度相加总size
size_t CommMergeAccessor::size() { return 0; }
// pull value 维度 // pull value 维度
size_t CommMergeAccessor::select_dim() { return _config.embedx_dim(); } size_t CommMergeAccessor::SelectDim() { return _config.embedx_dim(); }
// pull value 各个维度的size // pull value 各个维度的size
size_t CommMergeAccessor::select_dim_size(size_t dim) { return sizeof(float); } size_t CommMergeAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
// pull value 各维度相加总size // pull value 各维度相加总size
size_t CommMergeAccessor::select_size() { return select_dim() * sizeof(float); } size_t CommMergeAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value 维度 // push value 维度
size_t CommMergeAccessor::update_dim() { return _config.embedx_dim(); } size_t CommMergeAccessor::UpdateDim() { return _config.embedx_dim(); }
// push value 各个维度的size // push value 各个维度的size
size_t CommMergeAccessor::update_dim_size(size_t dim) { return sizeof(float); } size_t CommMergeAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
// push value 各维度相加总size // push value 各维度相加总size
size_t CommMergeAccessor::update_size() { return update_dim() * sizeof(float); } size_t CommMergeAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
// 判断该value 是否进行shrink // 判断该value 是否进行shrink
bool CommMergeAccessor::shrink(float * /*value*/) { return false; } bool CommMergeAccessor::Shrink(float * /*value*/) { return false; }
// 判断该value 是否在save阶段dump, // 判断该value 是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
bool CommMergeAccessor::save(float * /*value*/, int /*param*/) { return true; } bool CommMergeAccessor::Save(float * /*value*/, int /*param*/) { return true; }
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
int32_t CommMergeAccessor::create(float **value, size_t num) { return 0; } int32_t CommMergeAccessor::Create(float **value, size_t num) { return 0; }
// 从values中选取到select_values中 // 从values中选取到select_values中
int32_t CommMergeAccessor::select(float **select_values, const float **values, int32_t CommMergeAccessor::Select(float **select_values, const float **values,
size_t num) { size_t num) {
return 0; return 0;
} }
// 将update_values聚合到一起 // 将update_values聚合到一起
int32_t CommMergeAccessor::merge(float **update_values, int32_t CommMergeAccessor::Merge(float **update_values,
const float **other_update_values, const float **other_update_values,
size_t num) { size_t num) {
Eigen::Map<Eigen::MatrixXf> u_mat(update_values[0], 1, num); Eigen::Map<Eigen::MatrixXf> u_mat(update_values[0], 1, num);
...@@ -109,12 +93,12 @@ int32_t CommMergeAccessor::merge(float **update_values, ...@@ -109,12 +93,12 @@ int32_t CommMergeAccessor::merge(float **update_values,
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// int32_t merge(float** update_values, iterator it); // int32_t merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
int32_t CommMergeAccessor::update(float **values, const float **update_values, int32_t CommMergeAccessor::Update(float **values, const float **update_values,
size_t num) { size_t num) {
return 0; return 0;
} }
int CommMergeAccessor::set_weight(float **values, const float **update_values, int CommMergeAccessor::SetWeight(float **values, const float **update_values,
size_t num) { size_t num) {
return 0; return 0;
} }
......
...@@ -29,53 +29,49 @@ class CommMergeAccessor : public ValueAccessor { ...@@ -29,53 +29,49 @@ class CommMergeAccessor : public ValueAccessor {
public: public:
CommMergeAccessor() {} CommMergeAccessor() {}
virtual ~CommMergeAccessor() {} virtual ~CommMergeAccessor() {}
virtual int initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo &info); virtual void SetTableInfo(AccessorInfo &info);
virtual size_t GetTableInfo(InfoKey key); virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim();
// value各个维度的size
virtual size_t dim_size(size_t dim);
// value各维度相加总size
virtual size_t size();
// pull value维度 // pull value维度
virtual size_t select_dim(); size_t SelectDim();
// pull value各个维度的size // pull value各个维度的size
virtual size_t select_dim_size(size_t dim); size_t SelectDimSize(size_t dim);
// pull value各维度相加总size // pull value各维度相加总size
virtual size_t select_size(); size_t SelectSize();
// push value维度 // push value维度
virtual size_t update_dim(); size_t UpdateDim();
// push value各个维度的size // push value各个维度的size
virtual size_t update_dim_size(size_t dim); size_t UpdateDimSize(size_t dim);
// push value各维度相加总size // push value各维度相加总size
virtual size_t update_size(); size_t UpdateSize();
size_t fea_dim() { return _config.fea_dim(); }
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool shrink(float * /*value*/); virtual bool Shrink(float * /*value*/);
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model // param作为参数用于标识save阶段,如downpour的xbox与batch_model
virtual bool save(float * /*value*/, int /*param*/); virtual bool Save(float * /*value*/, int /*param*/);
// keys不存在时,为values生成随机值 // keys不存在时,为values生成随机值
virtual int32_t create(float **value, size_t num); virtual int32_t Create(float **value, size_t num);
// 从values中选取到select_values中 // 从values中选取到select_values中
virtual int32_t select(float **select_values, const float **values, virtual int32_t Select(float **select_values, const float **values,
size_t num); size_t num);
// 将update_values聚合到一起 // 将update_values聚合到一起
virtual int32_t merge(float **update_values, virtual int32_t Merge(float **update_values,
const float **other_update_values, size_t num); const float **other_update_values, size_t num);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key // 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it); // virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中 // 将update_values更新应用到values中
virtual int32_t update(float **values, const float **update_values, virtual int32_t Update(float **values, const float **update_values,
size_t num); size_t num);
virtual int set_weight(float **values, const float **update_values, virtual int SetWeight(float **values, const float **update_values,
size_t num); size_t num);
virtual std::string parse_to_string(const float *value, int param) { virtual std::string ParseToString(const float *value, int param) {
return ""; return "";
} }
virtual int parse_from_string(const std::string &str, float *v) { return 0; } virtual int ParseFromString(const std::string &str, float *v) { return 0; }
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -67,49 +67,49 @@ TableAccessorParameter gen_param() { ...@@ -67,49 +67,49 @@ TableAccessorParameter gen_param() {
TEST(downpour_feature_value_accessor_test, test_shrink) { TEST(downpour_feature_value_accessor_test, test_shrink) {
TableAccessorParameter parameter = gen_param(); TableAccessorParameter parameter = gen_param();
CtrCommonAccessor* acc = new CtrCommonAccessor(); CtrCommonAccessor* acc = new CtrCommonAccessor();
ASSERT_EQ(acc->configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
VLOG(3) << "size of struct: " << acc->common_feature_value.embed_sgd_dim VLOG(3) << "size of struct: " << acc->common_feature_value.embed_sgd_dim
<< " " << acc->common_feature_value.embedx_dim << " " << " " << acc->common_feature_value.embedx_dim << " "
<< acc->common_feature_value.embedx_sgd_dim << " " << acc->common_feature_value.embedx_sgd_dim << " "
<< acc->common_feature_value.dim() << "\n"; << acc->common_feature_value.Dim() << "\n";
float* value = new float[acc->dim()]; float* value = new float[acc->Dim()];
for (auto i = 0u; i < acc->dim(); ++i) { for (auto i = 0u; i < acc->Dim(); ++i) {
value[i] = i * 1.0; value[i] = i * 1.0;
} }
ASSERT_TRUE(!acc->shrink(value)); ASSERT_TRUE(!acc->Shrink(value));
// set unseen_days too long // set unseen_days too long
value[1] = 1000; value[1] = 1000;
// set delta score too small // set delta score too small
value[2] = 0.001; value[2] = 0.001;
ASSERT_TRUE(acc->shrink(value)); ASSERT_TRUE(acc->Shrink(value));
} }
TEST(downpour_feature_value_accessor_test, test_save) { TEST(downpour_feature_value_accessor_test, test_save) {
TableAccessorParameter parameter = gen_param(); TableAccessorParameter parameter = gen_param();
CtrCommonAccessor* acc = new CtrCommonAccessor(); CtrCommonAccessor* acc = new CtrCommonAccessor();
ASSERT_EQ(acc->configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
float* value = new float[acc->dim()]; float* value = new float[acc->Dim()];
for (auto i = 0u; i < acc->dim(); ++i) { for (auto i = 0u; i < acc->Dim(); ++i) {
value[i] = i * 1.0; value[i] = i * 1.0;
} }
// save all feature // save all feature
ASSERT_TRUE(acc->save(value, 0)); ASSERT_TRUE(acc->Save(value, 0));
// save delta feature // save delta feature
ASSERT_TRUE(acc->save(value, 1)); ASSERT_TRUE(acc->Save(value, 1));
// save base feature with time decay // save base feature with time decay
ASSERT_TRUE(acc->save(value, 2)); ASSERT_TRUE(acc->Save(value, 2));
VLOG(3) << "test_save:"; VLOG(3) << "test_save:";
for (auto i = 0u; i < acc->dim(); ++i) { for (auto i = 0u; i < acc->Dim(); ++i) {
VLOG(3) << value[i]; VLOG(3) << value[i];
} }
} }
...@@ -117,8 +117,8 @@ TEST(downpour_feature_value_accessor_test, test_save) { ...@@ -117,8 +117,8 @@ TEST(downpour_feature_value_accessor_test, test_save) {
TEST(downpour_feature_value_accessor_test, test_create) { TEST(downpour_feature_value_accessor_test, test_create) {
TableAccessorParameter parameter = gen_param(); TableAccessorParameter parameter = gen_param();
CtrCommonAccessor* acc = new CtrCommonAccessor(); CtrCommonAccessor* acc = new CtrCommonAccessor();
ASSERT_EQ(acc->configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
const int field_size = 7 + 8; const int field_size = 7 + 8;
const int item_size = 10; const int item_size = 10;
...@@ -127,7 +127,7 @@ TEST(downpour_feature_value_accessor_test, test_create) { ...@@ -127,7 +127,7 @@ TEST(downpour_feature_value_accessor_test, test_create) {
for (auto i = 0u; i < item_size; ++i) { for (auto i = 0u; i < item_size; ++i) {
value[i] = new float[field_size]; value[i] = new float[field_size];
} }
ASSERT_EQ(acc->create(value, item_size), 0); ASSERT_EQ(acc->Create(value, item_size), 0);
for (auto i = 0u; i < item_size; ++i) { for (auto i = 0u; i < item_size; ++i) {
for (auto j = 0u; j < field_size; ++j) { for (auto j = 0u; j < field_size; ++j) {
...@@ -141,11 +141,11 @@ TEST(downpour_feature_value_accessor_test, test_create) { ...@@ -141,11 +141,11 @@ TEST(downpour_feature_value_accessor_test, test_create) {
TEST(downpour_feature_value_accessor_test, test_update) { TEST(downpour_feature_value_accessor_test, test_update) {
TableAccessorParameter parameter = gen_param(); TableAccessorParameter parameter = gen_param();
CtrCommonAccessor* acc = new CtrCommonAccessor(); CtrCommonAccessor* acc = new CtrCommonAccessor();
ASSERT_EQ(acc->configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
VLOG(3) << "dim: " << acc->common_feature_value.dim() << "\n"; VLOG(3) << "dim: " << acc->common_feature_value.Dim() << "\n";
VLOG(3) << "update_dim: " << acc->update_dim() << "\n"; VLOG(3) << "update_dim: " << acc->GetTableInfo(UPDATE_DIM) << "\n";
const int field_size = 7 + 8; const int field_size = 7 + 8;
const int item_size = 10; const int item_size = 10;
...@@ -162,8 +162,8 @@ TEST(downpour_feature_value_accessor_test, test_update) { ...@@ -162,8 +162,8 @@ TEST(downpour_feature_value_accessor_test, test_update) {
typedef const float* const_float_ptr; typedef const float* const_float_ptr;
const_float_ptr* grad = new const_float_ptr[item_size]; const_float_ptr* grad = new const_float_ptr[item_size];
for (auto i = 0u; i < item_size; ++i) { for (auto i = 0u; i < item_size; ++i) {
float* p = new float[acc->update_dim()]; float* p = new float[acc->GetTableInfo(UPDATE_DIM)];
for (auto j = 0u; j < acc->update_dim(); ++j) { for (auto j = 0u; j < acc->GetTableInfo(UPDATE_DIM); ++j) {
p[j] = i; p[j] = i;
} }
grad[i] = p; grad[i] = p;
...@@ -251,14 +251,14 @@ TEST(downpour_feature_value_accessor_test, test_update) { ...@@ -251,14 +251,14 @@ TEST(downpour_feature_value_accessor_test, test_update) {
acc->_embedx_sgd_rule->update_value(&v.embedx_w[0], &v.embedx_g2sum[0], acc->_embedx_sgd_rule->update_value(&v.embedx_w[0], &v.embedx_g2sum[0],
&push_v.embedx_g[0]); &push_v.embedx_g[0]);
float* ptr = new float[acc->dim()]; float* ptr = new float[acc->Dim()];
v.to_array(ptr, parameter.embedx_dim()); v.to_array(ptr, parameter.embedx_dim());
exp_value.push_back(ptr); exp_value.push_back(ptr);
} }
acc->update(value, grad, item_size); acc->Update(value, grad, item_size);
for (auto i = 0u; i < item_size; ++i) { for (auto i = 0u; i < item_size; ++i) {
for (auto j = 0u; j < acc->dim(); ++j) { for (auto j = 0u; j < acc->Dim(); ++j) {
VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " "; VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " ";
ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]); ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]);
} }
...@@ -268,8 +268,8 @@ TEST(downpour_feature_value_accessor_test, test_update) { ...@@ -268,8 +268,8 @@ TEST(downpour_feature_value_accessor_test, test_update) {
TEST(downpour_feature_value_accessor_test, test_show_click_score) { TEST(downpour_feature_value_accessor_test, test_show_click_score) {
TableAccessorParameter parameter = gen_param(); TableAccessorParameter parameter = gen_param();
CtrCommonAccessor* acc = new CtrCommonAccessor(); CtrCommonAccessor* acc = new CtrCommonAccessor();
ASSERT_EQ(acc->configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
float show = 10; float show = 10;
float click = 6; float click = 6;
...@@ -279,8 +279,8 @@ TEST(downpour_feature_value_accessor_test, test_show_click_score) { ...@@ -279,8 +279,8 @@ TEST(downpour_feature_value_accessor_test, test_show_click_score) {
TEST(downpour_feature_value_accessor_test, test_string_related) { TEST(downpour_feature_value_accessor_test, test_string_related) {
TableAccessorParameter parameter = gen_param(); TableAccessorParameter parameter = gen_param();
CtrCommonAccessor* acc = new CtrCommonAccessor(); CtrCommonAccessor* acc = new CtrCommonAccessor();
ASSERT_EQ(acc->configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
const int field_size = 15; const int field_size = 15;
float* value = new float[field_size]; float* value = new float[field_size];
...@@ -288,12 +288,12 @@ TEST(downpour_feature_value_accessor_test, test_string_related) { ...@@ -288,12 +288,12 @@ TEST(downpour_feature_value_accessor_test, test_string_related) {
value[i] = i; value[i] = i;
} }
auto str = acc->parse_to_string(value, 0); auto str = acc->ParseToString(value, 0);
VLOG(3) << str << std::endl; VLOG(3) << str << std::endl;
str = "0 1 2 3 4 5 6"; str = "0 1 2 3 4 5 6";
ASSERT_NE(acc->parse_from_string(str, value), 0); ASSERT_NE(acc->ParseFromString(str, value), 0);
// make sure init_zero=true // make sure init_zero=true
for (auto i = 7; i < 15; ++i) { for (auto i = 7; i < 15; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册