未验证 提交 f4826137 编写于 作者: W wangguanqun 提交者: GitHub

Delete function in accessor and update function name in accessor and sgd (#41292)

* delete function

* fix bug

* update name

* fix bug in strategy
上级 a9d66025
......@@ -525,12 +525,12 @@ std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
sizeof(uint32_t));
keys->resize(shard_nums);
values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM));
values->resize(shard_nums * accessor->GetAccessorInfo().update_dim);
io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT
sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward(
(void *)(values->data()), // NOLINT
shard_nums * accessor->GetTableInfo(UPDATE_SIZE));
shard_nums * accessor->GetAccessorInfo().update_size);
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
......@@ -573,7 +573,7 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
auto kvs = ids[shard_idx];
auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
uint32_t value_size = accessor->GetAccessorInfo().update_size;
// 发送RPC请求
auto *push_request = closure->request(shard_idx);
push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
......@@ -581,14 +581,13 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
memcpy(push_data_ptr, value_ptr[i], value_size);
push_data_ptr += value_size;
}
PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
......@@ -603,11 +602,9 @@ std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num,
size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = GetTableAccessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM);
auto select_size = accessor->GetTableInfo(SELECT_SIZE);
auto fea_dim = accessor->GetAccessorInfo().fea_dim;
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
// callback 将各shard结果,顺序填入region
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num,
......@@ -617,7 +614,7 @@ std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num,
size_t region_data_idx = 0; // 当前填充的region内data偏移
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
size_t shard_data_size =
num_per_shard * accessor->GetTableInfo(SELECT_SIZE);
num_per_shard * accessor->GetAccessorInfo().select_size;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
ret = -1;
......@@ -681,12 +678,13 @@ std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = GetTableAccessor(table_id);
auto accessor_info = accessor->GetAccessorInfo();
size_t request_call_num = _server_channels.size();
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE);
DenseDimPerShard(accessor_info.fea_dim, request_call_num);
size_t shard_data_size = num_per_shard * accessor_info.update_size;
size_t current_region_idx = 0;
size_t current_region_data_idx = 0;
for (size_t i = 0; i < request_call_num; ++i) {
......@@ -793,7 +791,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
uint32_t value_size = accessor->GetAccessorInfo().update_size;
// 发送RPC请求
auto *push_request = closure->request(shard_idx);
......@@ -802,15 +800,14 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
memcpy(push_data_ptr, value_ptr[i], value_size);
push_data_ptr += value_size;
}
PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
......@@ -831,7 +828,7 @@ std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
std::future<int> fut = promise->get_future();
auto *accessor = GetTableAccessor(table_id);
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
......@@ -910,7 +907,7 @@ std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
size_t value_size = accessor->GetAccessorInfo().select_size;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
......@@ -1023,8 +1020,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
}
auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
size_t value_size = accessor->GetAccessorInfo().select_size;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
......@@ -1147,7 +1143,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) {
auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
size_t value_size = accessor->GetAccessorInfo().update_size;
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
......@@ -1307,7 +1303,7 @@ std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
shard_kv_data.kv_num = 0;
continue;
}
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
uint32_t value_size = accessor->GetAccessorInfo().update_size;
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
shard_kv_data.value_list[kv_idx].assign(
......@@ -1453,7 +1449,7 @@ void BrpcPsClient::PushSparseTaskConsume() {
void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
const float *another_data) {
size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
size_t col_num = accessor->GetAccessorInfo().update_dim;
float *merge_data_shell[col_num];
const float *another_data_shell[col_num];
for (int i = 0; i < col_num; ++i) {
......@@ -1469,7 +1465,7 @@ int BrpcPsClient::PushSparseAsyncShardMerge(
ValueAccessor *accessor) {
size_t merged_kv_count = 0;
uint64_t min_key = UINT64_MAX;
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
uint32_t value_size = accessor->GetAccessorInfo().update_size;
thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
sorted_kv_list.clear();
......@@ -1575,9 +1571,8 @@ int BrpcPsClient::PushSparseAsyncShardPush(
push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
int update_size = accessor->GetTableInfo(UPDATE_SIZE);
push_data->resize(merged_kv_count *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
int update_size = accessor->GetAccessorInfo().update_size;
push_data->resize(merged_kv_count * (sizeof(uint64_t) + update_size));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, merged_key_list.data(),
merged_kv_count * sizeof(uint64_t));
......@@ -1586,8 +1581,8 @@ int BrpcPsClient::PushSparseAsyncShardPush(
const char *task_data_ptr = merged_value_list[i].data();
memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT
accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
update_size);
push_data_ptr += update_size;
}
PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
......@@ -1602,8 +1597,8 @@ std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = GetTableAccessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM);
int update_dim = accessor->GetTableInfo(UPDATE_DIM);
int fea_dim = accessor->GetAccessorInfo().fea_dim;
int update_dim = accessor->GetAccessorInfo().update_dim;
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
auto parse_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_parse");
......@@ -1621,13 +1616,9 @@ std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
auto dense_data = std::make_shared<std::vector<float>>();
auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
// 将region数据拷贝到转置矩阵中
async_task->data()->resize(num_per_shard * request_call_num *
accessor->GetTableInfo(UPDATE_DIM));
async_task->data()->resize(num_per_shard * request_call_num * update_dim);
float *data = async_task->data()->data();
size_t data_size = async_task->data()->size();
uint32_t pos = 0;
......@@ -1757,7 +1748,7 @@ void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) {
......
......@@ -205,7 +205,7 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
}
auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) /
res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
sizeof(float));
TableContext table_context;
......@@ -384,7 +384,7 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM);
auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size);
......
......@@ -99,7 +99,8 @@ int32_t PsLocalClient::Initialize() {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1);
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
std::vector<float> region_buffer;
region_buffer.resize(num_per_shard);
......@@ -145,8 +146,8 @@ int32_t PsLocalClient::Initialize() {
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0);
region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
......@@ -179,8 +180,8 @@ int32_t PsLocalClient::Initialize() {
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1));
region_buffer.resize(
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
......
......@@ -46,27 +46,24 @@ struct DataConverter {
};
// Size/shape summary of the values managed by a ValueAccessor.
// "dim" fields count float slots; "size" fields are byte totals.
// Every field defaults to 0 so that a field an accessor does not fill in
// (e.g. fea_dim for a sparse accessor) reads as 0 instead of garbage.
struct AccessorInfo {
  // Number of dimensions of the stored value.
  size_t dim = 0;
  // Total byte size of the stored value (sum over all dimensions).
  size_t size = 0;
  // Number of dimensions of a pull (select) value.
  size_t select_dim = 0;
  // Total byte size of a pull (select) value.
  size_t select_size = 0;
  // Number of dimensions of a push (update) value.
  size_t update_dim = 0;
  // Total byte size of a push (update) value.
  size_t update_size = 0;
  // Total byte size of the dynamic-length mf part of the value;
  // only meaningful for sparse tables.
  size_t mf_size = 0;
  // Total value dimension; only meaningful for dense tables.
  size_t fea_dim = 0;
};
// Selector passed to ValueAccessor::GetTableInfo to pick which size or
// dimension figure to return. Enumerators are numbered consecutively from 0.
enum InfoKey {
  DIM = 0,      // dimension count of the stored value
  SIZE,         // byte size of the stored value
  SELECT_SIZE,  // byte size of a pull (select) value
  SELECT_DIM,   // dimension count of a pull (select) value
  UPDATE_SIZE,  // byte size of a push (update) value
  UPDATE_DIM,   // dimension count of a push (update) value
  MF_SIZE,      // byte size of the dynamic-length mf part
  FEA_DIM       // total value dimension (used by dense tables)
};
class ValueAccessor {
public:
ValueAccessor() {}
......@@ -90,8 +87,7 @@ class ValueAccessor {
}
virtual int Initialize() = 0;
virtual void SetTableInfo(AccessorInfo& info) = 0;
virtual size_t GetTableInfo(InfoKey key) = 0;
virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }
virtual bool NeedExtendMF(float* value) { return false; }
virtual bool HasMF(size_t size) { return false; }
......
......@@ -220,7 +220,8 @@ int32_t CommonDenseTable::Load(const std::string& path,
}
size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
size_t dim_num_per_shard = _table_info.fea_dim / _shard_num + 1;
size_t dim_num_per_shard =
_value_accesor->GetAccessorInfo().fea_dim / _shard_num + 1;
size_t start_dim_idx = dim_num_per_shard * _shard_idx;
size_t start_file_idx = start_dim_idx / dim_num_per_file;
size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
......
......@@ -23,87 +23,35 @@ namespace distributed {
// Creates and configures the embed / embedx SGD rules from _config, records
// the resulting layout dimensions in common_feature_value, caches the
// show/click decay rate, and finally fills in the accessor info.
// Each rule is configured exactly once through the CamelCase API
// (LoadConfig / Dim). Always returns 0.
int CtrCommonAccessor::Initialize() {
  // Rule for the embed weight: configured with dimension 1.
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);

  // Rule for the embedx weights: configured with the embedx dimension.
  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());

  // The feature-value layout depends on how many slots each rule needs.
  common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
  common_feature_value.embedx_dim = _config.embedx_dim();
  common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();

  // Must run after the dimensions above are set, since it reads them.
  InitAccessorInfo();
  return 0;
}
// Copies this accessor's dimension and byte-size figures into `info`.
// info.fea_dim is not written here.
void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) {
  // Stored value.
  const size_t value_dim = Dim();
  const size_t value_bytes = Size();
  // Pull (select) value.
  const size_t pull_dim = SelectDim();
  const size_t pull_bytes = SelectSize();
  // Push (update) value.
  const size_t push_dim = UpdateDim();
  const size_t push_bytes = UpdateSize();
  // Dynamic-length mf part.
  const size_t mf_bytes = MFSize();

  info.dim = value_dim;
  info.size = value_bytes;
  info.select_dim = pull_dim;
  info.select_size = pull_bytes;
  info.update_dim = push_dim;
  info.update_size = push_bytes;
  info.mf_size = mf_bytes;
}
// Returns the figure selected by `key`. Any key without a dedicated branch
// below (FEA_DIM included) yields 0.
size_t CtrCommonAccessor::GetTableInfo(InfoKey key) {
  if (key == DIM) return Dim();
  if (key == SIZE) return Size();
  if (key == SELECT_DIM) return SelectDim();
  if (key == SELECT_SIZE) return SelectSize();
  if (key == UPDATE_DIM) return UpdateDim();
  if (key == UPDATE_SIZE) return UpdateSize();
  if (key == MF_SIZE) return MFSize();
  return 0;
}
// Returns the dimension count reported by common_feature_value.
size_t CtrCommonAccessor::Dim() { return common_feature_value.Dim(); }
// Returns the size of dimension `dim`, as reported by common_feature_value
// for the configured embedx dimension.
size_t CtrCommonAccessor::DimSize(size_t dim) {
  return common_feature_value.DimSize(dim, _config.embedx_dim());
}
void CtrCommonAccessor::InitAccessorInfo() {
_accessor_info.dim = common_feature_value.Dim();
_accessor_info.size = common_feature_value.Size();
size_t CtrCommonAccessor::Size() { return common_feature_value.Size(); }
size_t CtrCommonAccessor::MFSize() {
return (_config.embedx_dim() + common_feature_value.embedx_sgd_dim) *
sizeof(float); // embedx embedx_g2sum
}
// pull value
size_t CtrCommonAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim;
}
size_t CtrCommonAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
size_t CtrCommonAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value
size_t CtrCommonAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim;
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
(embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
}
size_t CtrCommonAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
size_t CtrCommonAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool CtrCommonAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......@@ -116,9 +64,9 @@ bool CtrCommonAccessor::Shrink(float* value) {
common_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after
auto score = show_click_score(common_feature_value.Show(value),
common_feature_value.Click(value));
auto unseen_days = common_feature_value.unseen_days(value);
auto score = ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value));
auto unseen_days = common_feature_value.UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true;
}
......@@ -141,14 +89,13 @@ bool CtrCommonAccessor::Save(float* value, int param) {
case 1:
// save xbox base
case 2: {
if (show_click_score(common_feature_value.Show(value),
common_feature_value.Click(value)) >=
base_threshold &&
common_feature_value.delta_score(value) >= delta_threshold &&
common_feature_value.unseen_days(value) <= delta_keep_days) {
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry
if (param == 2) {
common_feature_value.delta_score(value) = 0;
common_feature_value.DeltaScore(value) = 0;
}
return true;
} else {
......@@ -158,7 +105,7 @@ bool CtrCommonAccessor::Save(float* value, int param) {
// already decayed in shrink
case 3: {
// do this after save, because it must not be modified when retry
// common_feature_value.unseen_days(value)++;
// common_feature_value.UnseenDays(value)++;
return true;
}
// save revert batch_model
......@@ -179,17 +126,16 @@ void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) {
}
switch (param) {
case 1: {
if (show_click_score(common_feature_value.Show(value),
common_feature_value.Click(value)) >=
base_threshold &&
common_feature_value.delta_score(value) >= delta_threshold &&
common_feature_value.unseen_days(value) <= delta_keep_days) {
common_feature_value.delta_score(value) = 0;
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
common_feature_value.DeltaScore(value) = 0;
}
}
return;
case 3: {
common_feature_value.unseen_days(value)++;
common_feature_value.UnseenDays(value)++;
}
return;
default:
......@@ -201,17 +147,16 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
value[common_feature_value.unseen_days_index()] = 0;
value[common_feature_value.delta_score_index()] = 0;
value[common_feature_value.UnseenDaysIndex()] = 0;
value[common_feature_value.DeltaScoreIndex()] = 0;
value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1;
_embed_sgd_rule->init_value(
value + common_feature_value.Embed_W_Index(),
value + common_feature_value.embed_g2sum_index());
_embedx_sgd_rule->init_value(
value + common_feature_value.Embedx_W_Index(),
value + common_feature_value.embedx_g2sum_index(), false);
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex());
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(),
false);
}
return 0;
}
......@@ -225,7 +170,7 @@ bool CtrCommonAccessor::NeedExtendMF(float* value) {
}
bool CtrCommonAccessor::HasMF(size_t size) {
return size > common_feature_value.embedx_g2sum_index();
return size > common_feature_value.EmbedxG2SumIndex();
}
// from CommonFeatureValue to CtrCommonPullValue
......@@ -239,10 +184,10 @@ int32_t CtrCommonAccessor::Select(float** select_values, const float** values,
value[common_feature_value.ShowIndex()];
select_value[CtrCommonPullValue::ClickIndex()] =
value[common_feature_value.ClickIndex()];
select_value[CtrCommonPullValue::Embed_W_Index()] =
value[common_feature_value.Embed_W_Index()];
memcpy(select_value + CtrCommonPullValue::Embedx_W_Index(),
value + common_feature_value.Embedx_W_Index(),
select_value[CtrCommonPullValue::EmbedWIndex()] =
value[common_feature_value.EmbedWIndex()];
memcpy(select_value + CtrCommonPullValue::EmbedxWIndex(),
value + common_feature_value.EmbedxWIndex(),
embedx_dim * sizeof(float));
}
return 0;
......@@ -283,18 +228,18 @@ int32_t CtrCommonAccessor::Update(float** update_values,
update_value[common_feature_value.ShowIndex()] += push_show;
update_value[common_feature_value.ClickIndex()] += push_click;
update_value[common_feature_value.SlotIndex()] = slot;
update_value[common_feature_value.delta_score_index()] +=
update_value[common_feature_value.DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff();
update_value[common_feature_value.unseen_days_index()] = 0;
_embed_sgd_rule->update_value(
update_value + common_feature_value.Embed_W_Index(),
update_value + common_feature_value.embed_g2sum_index(),
push_value + CtrCommonPushValue::Embed_G_Index());
_embedx_sgd_rule->update_value(
update_value + common_feature_value.Embedx_W_Index(),
update_value + common_feature_value.embedx_g2sum_index(),
push_value + CtrCommonPushValue::Embedx_G_Index());
update_value[common_feature_value.UnseenDaysIndex()] = 0;
_embed_sgd_rule->UpdateValue(
update_value + common_feature_value.EmbedWIndex(),
update_value + common_feature_value.EmbedG2SumIndex(),
push_value + CtrCommonPushValue::EmbedGIndex());
_embedx_sgd_rule->UpdateValue(
update_value + common_feature_value.EmbedxWIndex(),
update_value + common_feature_value.EmbedxG2SumIndex(),
push_value + CtrCommonPushValue::EmbedxGIndex());
}
return 0;
}
......@@ -308,7 +253,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
// operation
auto show = CtrCommonPushValue::Show(const_cast<float*>(value));
auto click = CtrCommonPushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score <= 0) {
return false;
}
......@@ -322,7 +267,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
}
}
float CtrCommonAccessor::show_click_score(float show, float click) {
float CtrCommonAccessor::ShowClickScore(float show, float click) {
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff;
......@@ -334,16 +279,16 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5];
for (int i = common_feature_value.embed_g2sum_index();
i < common_feature_value.Embedx_W_Index(); i++) {
for (int i = common_feature_value.EmbedG2SumIndex();
i < common_feature_value.EmbedxWIndex(); i++) {
os << " " << v[i];
}
auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() &&
param > common_feature_value.Embedx_W_Index()) {
for (auto i = common_feature_value.Embedx_W_Index();
param > common_feature_value.EmbedxWIndex()) {
for (auto i = common_feature_value.EmbedxWIndex();
i < common_feature_value.Dim(); ++i) {
os << " " << v[i];
}
......@@ -354,9 +299,8 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim();
_embedx_sgd_rule->init_value(
value + common_feature_value.Embedx_W_Index(),
value + common_feature_value.embedx_g2sum_index());
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex());
auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret;
return ret;
......
......@@ -44,24 +44,24 @@ class CtrCommonAccessor : public ValueAccessor {
int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
int Size() { return Dim() * sizeof(float); }
int SlotIndex() { return 0; }
int unseen_days_index() { return SlotIndex() + 1; }
int delta_score_index() { return unseen_days_index() + 1; }
int ShowIndex() { return delta_score_index() + 1; }
int UnseenDaysIndex() { return SlotIndex() + 1; }
int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
int ShowIndex() { return DeltaScoreIndex() + 1; }
int ClickIndex() { return ShowIndex() + 1; }
int Embed_W_Index() { return ClickIndex() + 1; }
int embed_g2sum_index() { return Embed_W_Index() + 1; }
int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; }
int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; }
int EmbedWIndex() { return ClickIndex() + 1; }
int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; }
float& unseen_days(float* val) { return val[unseen_days_index()]; }
float& delta_score(float* val) { return val[delta_score_index()]; }
float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
float& Show(float* val) { return val[ShowIndex()]; }
float& Click(float* val) { return val[ClickIndex()]; }
float& Slot(float* val) { return val[SlotIndex()]; }
float& EmbedW(float* val) { return val[Embed_W_Index()]; }
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
float& EmbedxW(float* val) { return val[Embedx_W_Index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
float& EmbedW(float* val) { return val[EmbedWIndex()]; }
float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
int embed_sgd_dim;
int embedx_dim;
......@@ -84,10 +84,8 @@ class CtrCommonAccessor : public ValueAccessor {
static int SlotIndex() { return 0; }
static int ShowIndex() { return CtrCommonPushValue::SlotIndex() + 1; }
static int ClickIndex() { return CtrCommonPushValue::ShowIndex() + 1; }
static int Embed_G_Index() { return CtrCommonPushValue::ClickIndex() + 1; }
static int Embedx_G_Index() {
return CtrCommonPushValue::Embed_G_Index() + 1;
}
static int EmbedGIndex() { return CtrCommonPushValue::ClickIndex() + 1; }
static int EmbedxGIndex() { return CtrCommonPushValue::EmbedGIndex() + 1; }
static float& Slot(float* val) {
return val[CtrCommonPushValue::SlotIndex()];
}
......@@ -98,10 +96,10 @@ class CtrCommonAccessor : public ValueAccessor {
return val[CtrCommonPushValue::ClickIndex()];
}
static float& EmbedG(float* val) {
return val[CtrCommonPushValue::Embed_G_Index()];
return val[CtrCommonPushValue::EmbedGIndex()];
}
static float* EmbedxG(float* val) {
return val + CtrCommonPushValue::Embedx_G_Index();
return val + CtrCommonPushValue::EmbedxGIndex();
}
};
......@@ -118,8 +116,8 @@ class CtrCommonAccessor : public ValueAccessor {
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; }
static int Embedx_W_Index() { return 3; }
static int EmbedWIndex() { return 2; }
static int EmbedxWIndex() { return 3; }
static float& Show(float* val) {
return val[CtrCommonPullValue::ShowIndex()];
}
......@@ -127,38 +125,17 @@ class CtrCommonAccessor : public ValueAccessor {
return val[CtrCommonPullValue::ClickIndex()];
}
static float& EmbedW(float* val) {
return val[CtrCommonPullValue::Embed_W_Index()];
return val[CtrCommonPullValue::EmbedWIndex()];
}
static float* EmbedxW(float* val) {
return val + CtrCommonPullValue::Embedx_W_Index();
return val + CtrCommonPullValue::EmbedxWIndex();
}
};
CtrCommonAccessor() {}
virtual int Initialize();
virtual ~CtrCommonAccessor() {}
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度
size_t Dim();
// value各个维度的size
size_t DimSize(size_t dim);
// value各维度相加总size
size_t Size();
// value中mf动态长度部分总size大小, sparse下生效
size_t MFSize();
// pull value维度
size_t SelectDim();
// pull value各个维度的size
size_t SelectDimSize(size_t dim);
// pull value各维度相加总size
size_t SelectSize();
// push value维度
size_t UpdateDim();
// push value各个维度的size
size_t UpdateDimSize(size_t dim);
// push value各维度相加总size
size_t UpdateSize();
virtual int Initialize();
// 初始化AccessorInfo
virtual void InitAccessorInfo();
// 判断该value是否进行shrink
virtual bool Shrink(float* value);
// 判断该value是否保存到ssd
......@@ -202,7 +179,7 @@ class CtrCommonAccessor : public ValueAccessor {
}
private:
// float show_click_score(float show, float click);
// float ShowClickScore(float show, float click);
// SparseValueSGDRule* _embed_sgd_rule;
// SparseValueSGDRule* _embedx_sgd_rule;
......@@ -213,7 +190,7 @@ class CtrCommonAccessor : public ValueAccessor {
public: // TODO(zhaocaibei123): it should be private, but we make it public
// for unit test
CtrCommonFeatureValue common_feature_value;
float show_click_score(float show, float click);
float ShowClickScore(float show, float click);
SparseValueSGDRule* _embed_sgd_rule;
SparseValueSGDRule* _embedx_sgd_rule;
};
......
......@@ -23,89 +23,32 @@ namespace distributed {
// Creates and configures the embed / embedx SGD rules from _config, caches
// the show/click decay rate and the ssd unseen-day threshold, and finally
// fills in the accessor info. Each rule is configured exactly once through
// the CamelCase LoadConfig API. Always returns 0.
int DownpourCtrDoubleAccessor::Initialize() {
  // Rule for the embed weight: configured with dimension 1.
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);

  // Rule for the embedx weights: configured with the embedx dimension.
  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());

  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
  _ssd_unseenday_threshold =
      _config.ctr_accessor_param().ssd_unseenday_threshold();

  InitAccessorInfo();
  return 0;
}
// Fills every field of `info` from the matching getter on this accessor
// (Dim, Size, SelectDim, SelectSize, UpdateDim, UpdateSize, MFSize).
void DownpourCtrDoubleAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = Dim();
info.size = Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
// Returns the single table metric selected by `key` by dispatching to the
// matching getter. Any key not listed below yields 0.
size_t DownpourCtrDoubleAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return Dim();
case SIZE:
return Size();
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case MF_SIZE:
return MFSize();
default:
return 0;
}
// Not reached: every switch path above returns.
return 0;
}
// Dimension count of a stored value for the configured embedx_dim; delegates
// to DownpourCtrDoubleFeatureValue::Dim.
size_t DownpourCtrDoubleAccessor::Dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::Dim(embedx_dim);
}
// Size of dimension `dim` of a stored value for the configured embedx_dim;
// delegates to DownpourCtrDoubleFeatureValue::DimSize.
size_t DownpourCtrDoubleAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::DimSize(dim, embedx_dim);
}
// Total size of a stored value for the configured embedx_dim; delegates to
// DownpourCtrDoubleFeatureValue::Size.
// NOTE(review): the signature line of InitAccessorInfo() is interleaved with
// this function in the listing; its assignments appear further down, after
// UpdateSize(). Presumably Size() is being replaced by InitAccessorInfo() -
// confirm against the real file.
size_t DownpourCtrDoubleAccessor::Size() {
void DownpourCtrDoubleAccessor::InitAccessorInfo() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::Size(embedx_dim);
}
// Size in bytes of the mf part of a value: (embedx_dim + 1) floats.
size_t DownpourCtrDoubleAccessor::MFSize() {
return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum
}
// pull value
// Dimension count of a pull value: 3 fixed entries plus embedx_dim.
size_t DownpourCtrDoubleAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim;
}
// Size of one pull-value dimension. `dim` is ignored: every dimension is
// sizeof(float).
size_t DownpourCtrDoubleAccessor::SelectDimSize(size_t dim) {
return sizeof(float);
}
// Total size in bytes of a pull value: SelectDim() floats.
size_t DownpourCtrDoubleAccessor::SelectSize() {
return SelectDim() * sizeof(float);
}
// push value
// Dimension count of a push value: 4 fixed entries plus embedx_dim.
size_t DownpourCtrDoubleAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim;
}
// Size of one push-value dimension. `dim` is ignored: every dimension is
// sizeof(float).
size_t DownpourCtrDoubleAccessor::UpdateDimSize(size_t dim) {
return sizeof(float);
}
// Total size in bytes of a push value: UpdateDim() floats.
// NOTE(review): the _accessor_info assignments after the return are
// unreachable as written and use `embedx_dim`, which is not declared in this
// function. They look like the body of InitAccessorInfo() interleaved into
// this listing - confirm against the real file. Those assignments store the
// same quantities the individual getters compute: dim/size from
// DownpourCtrDoubleFeatureValue, select_dim = 3 + embedx_dim,
// update_dim = 4 + embedx_dim, the *_size fields as dim * sizeof(float), and
// mf_size = (embedx_dim + 1) * sizeof(float).
size_t DownpourCtrDoubleAccessor::UpdateSize() {
return UpdateDim() * sizeof(float);
_accessor_info.dim = DownpourCtrDoubleFeatureValue::Dim(embedx_dim);
_accessor_info.size = DownpourCtrDoubleFeatureValue::Size(embedx_dim);
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size = (embedx_dim + 1) * sizeof(float);
}
bool DownpourCtrDoubleAccessor::Shrink(float* value) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......@@ -119,16 +62,16 @@ bool DownpourCtrDoubleAccessor::Shrink(float* value) {
DownpourCtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate;
DownpourCtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate;
// shrink after
auto score = show_click_score(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value));
auto unseen_days = DownpourCtrDoubleFeatureValue::unseen_days(value);
auto score = ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value));
auto unseen_days = DownpourCtrDoubleFeatureValue::UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true;
}
return false;
}
bool DownpourCtrDoubleAccessor::save_ssd(float* value) {
if (DownpourCtrDoubleFeatureValue::unseen_days(value) >
if (DownpourCtrDoubleFeatureValue::UnseenDays(value) >
_ssd_unseenday_threshold) {
return true;
}
......@@ -138,9 +81,9 @@ bool DownpourCtrDoubleAccessor::save_ssd(float* value) {
// float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value),
// if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
// DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold
// && DownpourCtrDoubleFeatureValue::unseen_days(value) <=
// && DownpourCtrDoubleFeatureValue::UnseenDays(value) <=
// delta_keep_days) {
// return DownpourCtrDoubleFeatureValue::Show(value) >
// global_cache_threshold;
......@@ -166,16 +109,14 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
case 1:
// save xbox base
case 2: {
if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >=
if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >=
base_threshold &&
DownpourCtrDoubleFeatureValue::delta_score(value) >=
delta_threshold &&
DownpourCtrDoubleFeatureValue::unseen_days(value) <=
delta_keep_days) {
DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry
if (param == 2) {
DownpourCtrDoubleFeatureValue::delta_score(value) = 0;
DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0;
}
return true;
} else {
......@@ -187,7 +128,7 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
// DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate;
// DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate;
// do this after save, because it must not be modified when retry
// DownpourCtrDoubleFeatureValue::unseen_days(value)++;
// DownpourCtrDoubleFeatureValue::UnseenDays(value)++;
return true;
}
default:
......@@ -204,19 +145,17 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
}
switch (param) {
case 1: {
if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >=
if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >=
base_threshold &&
DownpourCtrDoubleFeatureValue::delta_score(value) >=
delta_threshold &&
DownpourCtrDoubleFeatureValue::unseen_days(value) <=
delta_keep_days) {
DownpourCtrDoubleFeatureValue::delta_score(value) = 0;
DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0;
}
}
return;
case 3: {
DownpourCtrDoubleFeatureValue::unseen_days(value)++;
DownpourCtrDoubleFeatureValue::UnseenDays(value)++;
}
return;
default:
......@@ -228,17 +167,17 @@ int32_t DownpourCtrDoubleAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
value[DownpourCtrDoubleFeatureValue::unseen_days_index()] = 0;
value[DownpourCtrDoubleFeatureValue::delta_score_index()] = 0;
value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()) = 0;
value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->init_value(
value + DownpourCtrDoubleFeatureValue::Embed_W_Index(),
value + DownpourCtrDoubleFeatureValue::embed_g2sum_index());
_embedx_sgd_rule->init_value(
value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(),
value + DownpourCtrDoubleFeatureValue::embedx_g2sum_index(), false);
_embed_sgd_rule->InitValue(
value + DownpourCtrDoubleFeatureValue::EmbedWIndex(),
value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex());
_embedx_sgd_rule->InitValue(
value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(), false);
}
return 0;
}
......@@ -264,10 +203,10 @@ int32_t DownpourCtrDoubleAccessor::Select(float** select_values,
(float)*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex());
select_value[DownpourCtrDoublePullValue::ClickIndex()] =
(float)*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex());
select_value[DownpourCtrDoublePullValue::Embed_W_Index()] =
value[DownpourCtrDoubleFeatureValue::Embed_W_Index()];
memcpy(select_value + DownpourCtrDoublePullValue::Embedx_W_Index(),
value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(),
select_value[DownpourCtrDoublePullValue::EmbedWIndex()] =
value[DownpourCtrDoubleFeatureValue::EmbedWIndex()];
memcpy(select_value + DownpourCtrDoublePullValue::EmbedxWIndex(),
value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
embedx_dim * sizeof(float));
}
return 0;
......@@ -316,20 +255,20 @@ int32_t DownpourCtrDoubleAccessor::Update(float** update_values,
*(double*)(update_value + DownpourCtrDoubleFeatureValue::ClickIndex()) +=
(double)push_click;
update_value[DownpourCtrDoubleFeatureValue::SlotIndex()] = slot;
update_value[DownpourCtrDoubleFeatureValue::delta_score_index()] +=
update_value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff();
update_value[DownpourCtrDoubleFeatureValue::unseen_days_index()] = 0;
_embed_sgd_rule->update_value(
update_value + DownpourCtrDoubleFeatureValue::Embed_W_Index(),
update_value + DownpourCtrDoubleFeatureValue::embed_g2sum_index(),
push_value + DownpourCtrDoublePushValue::Embed_G_Index(), push_show);
_embedx_sgd_rule->update_value(
update_value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(),
update_value + DownpourCtrDoubleFeatureValue::embedx_g2sum_index(),
push_value + DownpourCtrDoublePushValue::Embedx_G_Index(), push_show);
update_value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
_embed_sgd_rule->UpdateValue(
update_value + DownpourCtrDoubleFeatureValue::EmbedWIndex(),
update_value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex(),
push_value + DownpourCtrDoublePushValue::EmbedGIndex(), push_show);
_embedx_sgd_rule->UpdateValue(
update_value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
update_value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(),
push_value + DownpourCtrDoublePushValue::EmbedxGIndex(), push_show);
}
return 0;
}
......@@ -341,7 +280,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) {
} else if (stage == 1) {
auto show = DownpourCtrDoublePushValue::Show(const_cast<float*>(value));
auto click = DownpourCtrDoublePushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score <= 0) {
return false;
}
......@@ -354,7 +293,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) {
return true;
}
}
double DownpourCtrDoubleAccessor::show_click_score(double show, double click) {
double DownpourCtrDoubleAccessor::ShowClickScore(double show, double click) {
// auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
// auto click_coeff = _config.ctr_accessor_param().click_coeff();
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
......@@ -371,7 +310,7 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v,
<< v[8];
auto show = DownpourCtrDoubleFeatureValue::Show(const_cast<float*>(v));
auto click = DownpourCtrDoubleFeatureValue::Click(const_cast<float*>(v));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && param_size > 9) {
os << " " << v[9];
for (auto i = 0; i < _config.embedx_dim(); ++i) {
......@@ -383,19 +322,19 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v,
int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str,
float* value) {
int embedx_dim = _config.embedx_dim();
float data_buff[Dim() + 2];
float data_buff[_accessor_info.dim + 2];
float* data_buff_ptr = data_buff;
_embedx_sgd_rule->init_value(
data_buff_ptr + DownpourCtrDoubleFeatureValue::Embedx_W_Index(),
data_buff_ptr + DownpourCtrDoubleFeatureValue::embedx_g2sum_index());
_embedx_sgd_rule->InitValue(
data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex());
auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
int show_index = DownpourCtrDoubleFeatureValue::ShowIndex();
int click_index = DownpourCtrDoubleFeatureValue::ClickIndex();
int embed_w_index = DownpourCtrDoubleFeatureValue::Embed_W_Index();
int embed_w_index = DownpourCtrDoubleFeatureValue::EmbedWIndex();
// no slot, embedx
int value_dim = Dim();
int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::embedx_g2sum_index();
int value_dim = _accessor_info.dim;
int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex();
value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1;
// other case
if (str_len == (value_dim - 1)) {
......@@ -405,9 +344,8 @@ int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str,
*(double*)(value + show_index) = (double)data_buff_ptr[2];
*(double*)(value + click_index) = (double)data_buff_ptr[3];
// copy others
value[DownpourCtrDoubleFeatureValue::Embed_W_Index()] = data_buff_ptr[4];
value[DownpourCtrDoubleFeatureValue::embed_g2sum_index()] =
data_buff_ptr[5];
value[DownpourCtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4];
value[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5];
memcpy(value + embedx_g2sum_index, data_buff_ptr + 6,
(embedx_dim + 1) * sizeof(float));
} else {
......
......@@ -43,38 +43,38 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// Size in bytes of a stored feature value: (Dim(embedx_dim) + 2) floats.
// NOTE(review): presumably the extra 2 floats account for show and click
// being stored as doubles (two float slots each) - confirm against Dim().
static int Size(int embedx_dim) {
return (Dim(embedx_dim) + 2) * sizeof(float);
}
// Float-slot offsets of each field inside a stored feature value. Each index
// is defined relative to the previous one, giving this layout:
//   unseen days  0
//   delta score  1
//   show         2   (double, occupies two float slots)
//   click        4   (double, occupies two float slots)
//   embed_w      6
//   embed_g2sum  7
//   slot         8
//   embedx_g2sum 9
//   embedx_w     10
// NOTE(review): snake_case and CamelCase spellings of the same helper are
// interleaved here (e.g. unseen_days_index / UnseenDaysIndex), and several
// bodies list two return statements, the second unreachable as written.
// Presumably a rename with both versions kept in this listing - confirm
// which spelling is live.
static int unseen_days_index() { return 0; }
static int delta_score_index() {
return DownpourCtrDoubleFeatureValue::unseen_days_index() + 1;
static int UnseenDaysIndex() { return 0; }
static int DeltaScoreIndex() {
return DownpourCtrDoubleFeatureValue::UnseenDaysIndex() + 1;
}
static int ShowIndex() {
return DownpourCtrDoubleFeatureValue::delta_score_index() + 1;
return DownpourCtrDoubleFeatureValue::DeltaScoreIndex() + 1;
}
// show is double
static int ClickIndex() {
return DownpourCtrDoubleFeatureValue::ShowIndex() + 2;
}
// click is double
static int Embed_W_Index() {
static int EmbedWIndex() {
return DownpourCtrDoubleFeatureValue::ClickIndex() + 2;
}
static int embed_g2sum_index() {
return DownpourCtrDoubleFeatureValue::Embed_W_Index() + 1;
static int EmbedG2SumIndex() {
return DownpourCtrDoubleFeatureValue::EmbedWIndex() + 1;
}
static int SlotIndex() {
return DownpourCtrDoubleFeatureValue::embed_g2sum_index() + 1;
return DownpourCtrDoubleFeatureValue::EmbedG2SumIndex() + 1;
}
static int embedx_g2sum_index() {
static int EmbedxG2SumIndex() {
return DownpourCtrDoubleFeatureValue::SlotIndex() + 1;
}
static int Embedx_W_Index() {
return DownpourCtrDoubleFeatureValue::embedx_g2sum_index() + 1;
static int EmbedxWIndex() {
return DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex() + 1;
}
// Returns a mutable reference to the unseen-days slot of `val`.
// NOTE(review): the unseen_days and UnseenDays spellings are interleaved in
// this listing - presumably a rename; confirm which one is live.
static float& unseen_days(float* val) {
return val[DownpourCtrDoubleFeatureValue::unseen_days_index()];
static float& UnseenDays(float* val) {
return val[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()];
}
// Returns a mutable reference to the delta-score slot of `val`.
// NOTE(review): the delta_score and DeltaScore spellings are interleaved in
// this listing - presumably a rename; confirm which one is live.
static float& delta_score(float* val) {
return val[DownpourCtrDoubleFeatureValue::delta_score_index()];
static float& DeltaScore(float* val) {
return val[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()];
}
static double& Show(float* val) {
return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0];
......@@ -86,16 +86,16 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
return val[DownpourCtrDoubleFeatureValue::SlotIndex()];
}
// Returns a mutable reference to the embed-weight slot of a stored value.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from Embed_W_Index to EmbedWIndex - confirm.
static float& EmbedW(float* val) {
return val[DownpourCtrDoubleFeatureValue::Embed_W_Index()];
return val[DownpourCtrDoubleFeatureValue::EmbedWIndex()];
}
// Returns a mutable reference to the embed g2sum slot of a stored value.
// NOTE(review): the embed_g2sum and EmbedG2Sum spellings are interleaved in
// this listing - presumably a rename; confirm which one is live.
static float& embed_g2sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::embed_g2sum_index()];
static float& EmbedG2Sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()];
}
// Returns a mutable reference to the embedx g2sum slot of a stored value.
// NOTE(review): the embedx_g2sum and EmbedxG2Sum spellings are interleaved in
// this listing - presumably a rename; confirm which one is live.
static float& embedx_g2sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::embedx_g2sum_index()];
static float& EmbedxG2Sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex()];
}
// Returns a pointer into `val` at the embedx-weight offset of a stored value.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from Embedx_W_Index to EmbedxWIndex -
// confirm.
static float* EmbedxW(float* val) {
return (val + DownpourCtrDoubleFeatureValue::Embedx_W_Index());
return (val + DownpourCtrDoubleFeatureValue::EmbedxWIndex());
}
};
struct DownpourCtrDoublePushValue {
......@@ -116,11 +116,11 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// Offset of click in a push value: the slot right after show.
static int ClickIndex() {
return DownpourCtrDoublePushValue::ShowIndex() + 1;
}
// Offset of the embed gradient in a push value: the slot right after click.
// NOTE(review): both the Embed_G_Index and EmbedGIndex signature lines are
// listed - presumably a rename; confirm which one is live.
static int Embed_G_Index() {
static int EmbedGIndex() {
return DownpourCtrDoublePushValue::ClickIndex() + 1;
}
// Offset of the first embedx gradient in a push value: the slot right after
// the embed gradient.
// NOTE(review): the Embedx_G_Index and EmbedxGIndex versions are interleaved
// in this listing - presumably a rename; confirm which one is live.
static int Embedx_G_Index() {
return DownpourCtrDoublePushValue::Embed_G_Index() + 1;
static int EmbedxGIndex() {
return DownpourCtrDoublePushValue::EmbedGIndex() + 1;
}
static float& Slot(float* val) {
return val[DownpourCtrDoublePushValue::SlotIndex()];
......@@ -132,10 +132,10 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
return val[DownpourCtrDoublePushValue::ClickIndex()];
}
// Returns a mutable reference to the embed-gradient slot of a push value.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from Embed_G_Index to EmbedGIndex - confirm.
static float& EmbedG(float* val) {
return val[DownpourCtrDoublePushValue::Embed_G_Index()];
return val[DownpourCtrDoublePushValue::EmbedGIndex()];
}
// Returns a pointer into `val` at the embedx-gradient offset of a push value.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from Embedx_G_Index to EmbedxGIndex -
// confirm.
static float* EmbedxG(float* val) {
return val + DownpourCtrDoublePushValue::Embedx_G_Index();
return val + DownpourCtrDoublePushValue::EmbedxGIndex();
}
};
struct DownpourCtrDoublePullValue {
......@@ -150,8 +150,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// Size in bytes of a pull value: Dim(embedx_dim) floats.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
// Fixed float-slot offsets inside a pull value:
// show = 0, click = 1, embed_w = 2, embedx_w starts at 3.
// NOTE(review): both the underscore (Embed_W_Index, Embedx_W_Index) and
// CamelCase (EmbedWIndex, EmbedxWIndex) spellings are listed with identical
// values - presumably a rename; confirm which one is live.
static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; }
static int Embedx_W_Index() { return 3; }
static int EmbedWIndex() { return 2; }
static int EmbedxWIndex() { return 3; }
// Returns a mutable reference to the show slot of a pull value.
static float& Show(float* val) {
return val[DownpourCtrDoublePullValue::ShowIndex()];
}
......@@ -159,37 +159,17 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
return val[DownpourCtrDoublePullValue::ClickIndex()];
}
// Returns a mutable reference to the embed-weight slot of a pull value.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from Embed_W_Index to EmbedWIndex - confirm.
static float& EmbedW(float* val) {
return val[DownpourCtrDoublePullValue::Embed_W_Index()];
return val[DownpourCtrDoublePullValue::EmbedWIndex()];
}
// Returns a pointer into `val` at the embedx-weight offset of a pull value.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from Embedx_W_Index to EmbedxWIndex -
// confirm.
static float* EmbedxW(float* val) {
return val + DownpourCtrDoublePullValue::Embedx_W_Index();
return val + DownpourCtrDoublePullValue::EmbedxWIndex();
}
};
// Construction / destruction and the size-query interface of the accessor.
// NOTE(review): both the per-field getters (Dim(), Size(), ...) and
// InitAccessorInfo() are declared in this listing - presumably the getters
// are being replaced by InitAccessorInfo(); confirm.
DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {}
virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// Number of dimensions of the stored value
size_t Dim();
// Size of each dimension of the stored value
size_t DimSize(size_t dim);
// Total size of the stored value (sum over all dimensions)
size_t Size();
// Total size of the dynamic-length mf part of the value; only applies to
// sparse tables
size_t MFSize();
// Number of dimensions of the pull value
size_t SelectDim();
// Size of each dimension of the pull value
size_t SelectDimSize(size_t dim);
// Total size of the pull value (sum over all dimensions)
size_t SelectSize();
// Number of dimensions of the push value
size_t UpdateDim();
// Size of each dimension of the push value
size_t UpdateDimSize(size_t dim);
// Total size of the push value (sum over all dimensions)
size_t UpdateSize();
// Initialize AccessorInfo
virtual void InitAccessorInfo();
// Decide whether this value should be shrunk
virtual bool Shrink(float* value);
virtual bool NeedExtendMF(float* value);
......@@ -235,7 +215,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embed_w)
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embedx_w)
private:
double show_click_score(double show, double click);
double ShowClickScore(double show, double click);
private:
SparseValueSGDRule* _embed_sgd_rule;
......
......@@ -23,91 +23,32 @@ namespace distributed {
// Sets up the accessor from _config:
//   - creates the embed and embedx SGD rules by the names given in
//     embed_sgd_param / embedx_sgd_param and loads their configs (dimension 1
//     for embed, _config.embedx_dim() for embedx),
//   - caches show_click_decay_rate and ssd_unseenday_threshold,
//   - calls set_time_decay_rates() and then InitAccessorInfo().
// Always returns 0.
// NOTE(review): each rule is configured through both load_config(...) and
// LoadConfig(...) in this listing - looks like a method rename with both
// spellings kept; confirm which one exists on SparseValueSGDRule.
// NOTE(review): the result of CREATE_PSCORE_CLASS is dereferenced without a
// null check - confirm it cannot return null for an unknown rule name.
int DownpourCtrAccessor::Initialize() {
auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);
_embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(),
_config.embedx_dim());
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim());
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold();
set_time_decay_rates();
InitAccessorInfo();
return 0;
}
// Fills every field of `info` from the matching getter on this accessor
// (Dim, Size, SelectDim, SelectSize, UpdateDim, UpdateSize, MFSize).
void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = Dim();
info.size = Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
// Returns the single table metric selected by `key` by dispatching to the
// matching getter. Any key not listed below yields 0.
size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return Dim();
case SIZE:
return Size();
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case MF_SIZE:
return MFSize();
default:
return 0;
}
// Not reached: every switch path above returns.
return 0;
}
// Dimension count of a stored value for the configured embedx_dim; delegates
// to DownpourCtrFeatureValue::Dim.
size_t DownpourCtrAccessor::Dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::Dim(embedx_dim);
}
// Size of dimension `dim` of a stored value for the configured embedx_dim;
// delegates to DownpourCtrFeatureValue::DimSize.
size_t DownpourCtrAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::DimSize(dim, embedx_dim);
}
// NOTE(review): two definitions are interleaved in this listing - Size(),
// which returns DownpourCtrFeatureValue::Size(embedx_dim), and
// InitAccessorInfo(), which fills _accessor_info. As written the assignments
// follow a return statement; presumably Size() is being replaced by
// InitAccessorInfo() - confirm against the real file.
//
// The _accessor_info assignments store, for the configured embedx_dim:
//   dim / size        from DownpourCtrFeatureValue::Dim / ::Size
//   select_dim        3 + embedx_dim, select_size = select_dim floats
//   update_dim        4 + embedx_dim, update_size = update_dim floats
//   mf_size           (embedx_dim + 1) floats
size_t DownpourCtrAccessor::Size() {
void DownpourCtrAccessor::InitAccessorInfo() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::Size(embedx_dim);
_accessor_info.dim = DownpourCtrFeatureValue::Dim(embedx_dim);
_accessor_info.size = DownpourCtrFeatureValue::Size(embedx_dim);
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size = (embedx_dim + 1) * sizeof(float);
}
// Size in bytes of the mf part of a value: (embedx_dim + 1) floats.
size_t DownpourCtrAccessor::MFSize() {
return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum
}
// pull value
// Dimension count of a pull value: 3 fixed entries plus embedx_dim.
size_t DownpourCtrAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim;
}
// Size of one pull-value dimension. `dim` is ignored: every dimension is
// sizeof(float).
size_t DownpourCtrAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
// Total size in bytes of a pull value: SelectDim() floats.
size_t DownpourCtrAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value
// Dimension count of a push value: 4 fixed entries plus embedx_dim.
size_t DownpourCtrAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim;
}
// Size of one push-value dimension. `dim` is ignored: every dimension is
// sizeof(float).
size_t DownpourCtrAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
// Total size in bytes of a push value: UpdateDim() floats.
size_t DownpourCtrAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool DownpourCtrAccessor::Shrink(float* value) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......@@ -119,7 +60,7 @@ bool DownpourCtrAccessor::Shrink(float* value) {
auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days;
if (day_diff < 0 || day_diff > delete_after_unseen_days) {
return true;
......@@ -130,7 +71,7 @@ bool DownpourCtrAccessor::Shrink(float* value) {
DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
// shrink after
auto score = show_click_score(show_right, click_right);
auto score = ShowClickScore(show_right, click_right);
if (score < delete_threshold) {
return true;
}
......@@ -145,7 +86,7 @@ bool DownpourCtrAccessor::save_ssd(float* value) {
if (_day_id == 0) {
return true;
}
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
if (unseen_days == 0) {
return false;
}
......@@ -164,9 +105,9 @@ bool DownpourCtrAccessor::save_ssd(float* value) {
// float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
// auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
// int16_t day_diff = _day_id - unseen_days;
// if (show_click_score(DownpourCtrFeatureValue::Show(value),
// if (ShowClickScore(DownpourCtrFeatureValue::Show(value),
// DownpourCtrFeatureValue::Click(value)) >= base_threshold
// && day_diff <= delta_keep_days) {
// return DownpourCtrFeatureValue::Show(value) > global_cache_threshold;
......@@ -193,7 +134,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) {
case 1:
// save xbox base
case 2: {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days;
auto show_right =
......@@ -201,12 +142,12 @@ bool DownpourCtrAccessor::Save(float* value, int param) {
auto click_right =
DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
if (show_click_score(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::delta_score(value) >= delta_threshold &&
if (ShowClickScore(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::DeltaScore(value) >= delta_threshold &&
day_diff <= delta_keep_days) {
// do this after save, because it must not be modified when retry
if (param == 2) {
DownpourCtrFeatureValue::delta_score(value) = 0;
DownpourCtrFeatureValue::DeltaScore(value) = 0;
}
return true;
} else {
......@@ -218,7 +159,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) {
// DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate;
// DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate;
// do this after save, because it must not be modified when retry
// DownpourCtrFeatureValue::unseen_days(value)++;
// DownpourCtrFeatureValue::UnseenDays(value)++;
return true;
}
default:
......@@ -235,23 +176,23 @@ void DownpourCtrAccessor::UpdateStatAfterSave(float* value, int param) {
}
switch (param) {
case 1: {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days;
auto show_right =
DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
auto click_right =
DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
if (show_click_score(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::delta_score(value) >= delta_threshold &&
if (ShowClickScore(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::DeltaScore(value) >= delta_threshold &&
day_diff <= delta_keep_days) {
DownpourCtrFeatureValue::delta_score(value) = 0;
DownpourCtrFeatureValue::DeltaScore(value) = 0;
}
}
return;
// case 3:
// {
// DownpourCtrFeatureValue::unseen_days(value)++;
// DownpourCtrFeatureValue::UnseenDays(value)++;
// }
// return;
default:
......@@ -263,17 +204,17 @@ int32_t DownpourCtrAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
value[DownpourCtrFeatureValue::unseen_days_index()] = 0;
value[DownpourCtrFeatureValue::delta_score_index()] = 0;
value[DownpourCtrFeatureValue::UnseenDaysIndex()] = 0;
value[DownpourCtrFeatureValue::DeltaScoreIndex()] = 0;
value[DownpourCtrFeatureValue::ShowIndex()] = 0;
value[DownpourCtrFeatureValue::ClickIndex()] = 0;
value[DownpourCtrFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->init_value(
value + DownpourCtrFeatureValue::Embed_W_Index(),
value + DownpourCtrFeatureValue::embed_g2sum_index(), true);
_embedx_sgd_rule->init_value(
value + DownpourCtrFeatureValue::Embedx_W_Index(),
value + DownpourCtrFeatureValue::embedx_g2sum_index());
_embed_sgd_rule->InitValue(
value + DownpourCtrFeatureValue::EmbedWIndex(),
value + DownpourCtrFeatureValue::EmbedG2SumIndex(), true);
_embedx_sgd_rule->InitValue(
value + DownpourCtrFeatureValue::EmbedxWIndex(),
value + DownpourCtrFeatureValue::EmbedxG2SumIndex());
}
return 0;
}
......@@ -289,7 +230,7 @@ bool DownpourCtrAccessor::NeedExtendMF(float* value) {
}
// Returns true when `size` is strictly greater than the embedx g2sum index of
// DownpourCtrFeatureValue.
// NOTE(review): presumably `size` is a count of floats rather than bytes,
// since it is compared against a slot index - confirm against callers.
// NOTE(review): two return statements are listed; the second is unreachable
// as written. Presumably a rename from embedx_g2sum_index to
// EmbedxG2SumIndex - confirm which one is live.
bool DownpourCtrAccessor::HasMF(size_t size) {
return size > DownpourCtrFeatureValue::embedx_g2sum_index();
return size > DownpourCtrFeatureValue::EmbedxG2SumIndex();
}
// from DownpourCtrFeatureValue to DownpourCtrPullValue
......@@ -303,10 +244,10 @@ int32_t DownpourCtrAccessor::Select(float** select_values, const float** values,
value[DownpourCtrFeatureValue::ShowIndex()];
select_value[DownpourCtrPullValue::ClickIndex()] =
value[DownpourCtrFeatureValue::ClickIndex()];
select_value[DownpourCtrPullValue::Embed_W_Index()] =
value[DownpourCtrFeatureValue::Embed_W_Index()];
memcpy(select_value + DownpourCtrPullValue::Embedx_W_Index(),
value + DownpourCtrFeatureValue::Embedx_W_Index(),
select_value[DownpourCtrPullValue::EmbedWIndex()] =
value[DownpourCtrFeatureValue::EmbedWIndex()];
memcpy(select_value + DownpourCtrPullValue::EmbedxWIndex(),
value + DownpourCtrFeatureValue::EmbedxWIndex(),
embedx_dim * sizeof(float));
}
return 0;
......@@ -347,20 +288,20 @@ int32_t DownpourCtrAccessor::Update(float** update_values,
update_value[DownpourCtrFeatureValue::ShowIndex()] += push_show;
update_value[DownpourCtrFeatureValue::ClickIndex()] += push_click;
update_value[DownpourCtrFeatureValue::SlotIndex()] = slot;
update_value[DownpourCtrFeatureValue::delta_score_index()] +=
update_value[DownpourCtrFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff();
update_value[DownpourCtrFeatureValue::unseen_days_index()] = 0;
_embed_sgd_rule->update_value(
update_value + DownpourCtrFeatureValue::Embed_W_Index(),
update_value + DownpourCtrFeatureValue::embed_g2sum_index(),
push_value + DownpourCtrPushValue::Embed_G_Index(), push_show);
_embedx_sgd_rule->update_value(
update_value + DownpourCtrFeatureValue::Embedx_W_Index(),
update_value + DownpourCtrFeatureValue::embedx_g2sum_index(),
push_value + DownpourCtrPushValue::Embedx_G_Index(), push_show);
update_value[DownpourCtrFeatureValue::UnseenDaysIndex()] = 0;
_embed_sgd_rule->UpdateValue(
update_value + DownpourCtrFeatureValue::EmbedWIndex(),
update_value + DownpourCtrFeatureValue::EmbedG2SumIndex(),
push_value + DownpourCtrPushValue::EmbedGIndex(), push_show);
_embedx_sgd_rule->UpdateValue(
update_value + DownpourCtrFeatureValue::EmbedxWIndex(),
update_value + DownpourCtrFeatureValue::EmbedxG2SumIndex(),
push_value + DownpourCtrPushValue::EmbedxGIndex(), push_show);
}
return 0;
}
......@@ -373,7 +314,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) {
} else if (stage == 1) {
auto show = DownpourCtrPushValue::Show(const_cast<float*>(value));
auto click = DownpourCtrPushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score <= 0) {
return false;
}
......@@ -387,7 +328,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) {
}
}
float DownpourCtrAccessor::show_click_score(float show, float click) {
float DownpourCtrAccessor::ShowClickScore(float show, float click) {
// auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
// auto click_coeff = _config.ctr_accessor_param().click_coeff();
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
......@@ -403,7 +344,7 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) {
<< v[5] << " " << v[6];
auto show = DownpourCtrFeatureValue::Show(const_cast<float*>(v));
auto click = DownpourCtrFeatureValue::Click(const_cast<float*>(v));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && param_size > 7) {
os << " " << v[7];
for (auto i = 0; i < _config.embedx_dim(); ++i) {
......@@ -415,18 +356,18 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) {
int DownpourCtrAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim();
float data_buff[Dim()];
float data_buff[_accessor_info.dim];
float* data_buff_ptr = data_buff;
_embedx_sgd_rule->init_value(
data_buff_ptr + DownpourCtrFeatureValue::Embedx_W_Index(),
data_buff_ptr + DownpourCtrFeatureValue::embedx_g2sum_index());
_embedx_sgd_rule->InitValue(
data_buff_ptr + DownpourCtrFeatureValue::EmbedxWIndex(),
data_buff_ptr + DownpourCtrFeatureValue::EmbedxG2SumIndex());
auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
// no slot, embedx
int value_dim = Dim();
int embedx_g2sum_index = DownpourCtrFeatureValue::embedx_g2sum_index();
int value_dim = _accessor_info.dim;
int embedx_g2sum_index = DownpourCtrFeatureValue::EmbedxG2SumIndex();
value[DownpourCtrFeatureValue::SlotIndex()] = -1;
// other case
if (str_len == (value_dim - 1)) {
......@@ -459,25 +400,25 @@ void DownpourCtrAccessor::update_time_decay(float* value,
if (_day_id == 0) {
return;
}
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
if (unseen_days == 0) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
return;
}
// for the origin load (unseenday = 0 -15)
if (unseen_days < _config.ctr_accessor_param().delete_after_unseen_days()) {
// pull
if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
return;
// save 舍弃原始的unseenday,都变为上一天出现,保证show/click不被重复decay
} else {
DownpourCtrFeatureValue::unseen_days(value) = _day_id - 1;
DownpourCtrFeatureValue::UnseenDays(value) = _day_id - 1;
}
}
int16_t day_diff = _day_id - unseen_days;
if (day_diff < 0) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
return;
}
if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) {
......@@ -486,7 +427,7 @@ void DownpourCtrAccessor::update_time_decay(float* value,
DownpourCtrFeatureValue::Show(value) *= _time_decay_rates[day_diff];
DownpourCtrFeatureValue::Click(value) *= _time_decay_rates[day_diff];
if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
}
}
......
......@@ -45,34 +45,34 @@ class DownpourCtrAccessor : public ValueAccessor {
// Index helpers describing how a feature value is laid out as an array of
// floats. Following the "+ 1" chain below, the slots are:
//   0 unseen_days, 1 delta_score, 2 show, 3 click, 4 embed_w,
//   5 embed_g2sum, 6 slot, 7 embedx_g2sum, 8 embedx_w (first element).
// Dim() therefore counts 8 fixed floats plus embedx_dim embedx weights.
// NOTE(review): each snake_case helper is shown together with its renamed
// CamelCase counterpart (old line first, new line second).
static int Dim(int embedx_dim) { return 8 + embedx_dim; }
// Every dimension is one float; both arguments are ignored.
static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
// Total byte size of a value with the given embedx_dim.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int unseen_days_index() { return 0; }
static int delta_score_index() {
return DownpourCtrFeatureValue::unseen_days_index() + 1;
static int UnseenDaysIndex() { return 0; }
static int DeltaScoreIndex() {
return DownpourCtrFeatureValue::UnseenDaysIndex() + 1;
}
static int ShowIndex() {
return DownpourCtrFeatureValue::delta_score_index() + 1;
return DownpourCtrFeatureValue::DeltaScoreIndex() + 1;
}
static int ClickIndex() { return DownpourCtrFeatureValue::ShowIndex() + 1; }
static int Embed_W_Index() {
static int EmbedWIndex() {
return DownpourCtrFeatureValue::ClickIndex() + 1;
}
static int embed_g2sum_index() {
return DownpourCtrFeatureValue::Embed_W_Index() + 1;
static int EmbedG2SumIndex() {
return DownpourCtrFeatureValue::EmbedWIndex() + 1;
}
static int SlotIndex() {
return DownpourCtrFeatureValue::embed_g2sum_index() + 1;
return DownpourCtrFeatureValue::EmbedG2SumIndex() + 1;
}
// In this layout the embedx g2sum slot comes before the embedx weights.
static int embedx_g2sum_index() {
static int EmbedxG2SumIndex() {
return DownpourCtrFeatureValue::SlotIndex() + 1;
}
static int Embedx_W_Index() {
return DownpourCtrFeatureValue::embedx_g2sum_index() + 1;
static int EmbedxWIndex() {
return DownpourCtrFeatureValue::EmbedxG2SumIndex() + 1;
}
// Reference accessors: return the float stored at the matching index of val.
static float& unseen_days(float* val) {
return val[DownpourCtrFeatureValue::unseen_days_index()];
static float& UnseenDays(float* val) {
return val[DownpourCtrFeatureValue::UnseenDaysIndex()];
}
static float& delta_score(float* val) {
return val[DownpourCtrFeatureValue::delta_score_index()];
static float& DeltaScore(float* val) {
return val[DownpourCtrFeatureValue::DeltaScoreIndex()];
}
static float& Show(float* val) {
return val[DownpourCtrFeatureValue::ShowIndex()];
......@@ -84,16 +84,16 @@ class DownpourCtrAccessor : public ValueAccessor {
return val[DownpourCtrFeatureValue::SlotIndex()];
}
static float& EmbedW(float* val) {
return val[DownpourCtrFeatureValue::Embed_W_Index()];
return val[DownpourCtrFeatureValue::EmbedWIndex()];
}
static float& embed_g2sum(float* val) {
return val[DownpourCtrFeatureValue::embed_g2sum_index()];
static float& EmbedG2Sum(float* val) {
return val[DownpourCtrFeatureValue::EmbedG2SumIndex()];
}
static float& embedx_g2sum(float* val) {
return val[DownpourCtrFeatureValue::embedx_g2sum_index()];
static float& EmbedxG2Sum(float* val) {
return val[DownpourCtrFeatureValue::EmbedxG2SumIndex()];
}
static float* EmbedxW(float* val) {
return (val + DownpourCtrFeatureValue::Embedx_W_Index());
return (val + DownpourCtrFeatureValue::EmbedxWIndex());
}
};
......@@ -113,11 +113,9 @@ class DownpourCtrAccessor : public ValueAccessor {
static int SlotIndex() { return 0; }
static int ShowIndex() { return DownpourCtrPushValue::SlotIndex() + 1; }
static int ClickIndex() { return DownpourCtrPushValue::ShowIndex() + 1; }
static int Embed_G_Index() {
return DownpourCtrPushValue::ClickIndex() + 1;
}
static int Embedx_G_Index() {
return DownpourCtrPushValue::Embed_G_Index() + 1;
static int EmbedGIndex() { return DownpourCtrPushValue::ClickIndex() + 1; }
static int EmbedxGIndex() {
return DownpourCtrPushValue::EmbedGIndex() + 1;
}
static float& Slot(float* val) { return val[0]; }
static float& Show(float* val) { return val[1]; }
......@@ -139,8 +137,8 @@ class DownpourCtrAccessor : public ValueAccessor {
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; }
static int Embedx_W_Index() { return 3; }
static int EmbedWIndex() { return 2; }
static int EmbedxWIndex() { return 3; }
static float& Show(float* val) {
return val[DownpourCtrPullValue::ShowIndex()];
}
......@@ -148,38 +146,18 @@ class DownpourCtrAccessor : public ValueAccessor {
return val[DownpourCtrPullValue::ClickIndex()];
}
static float& EmbedW(float* val) {
return val[DownpourCtrPullValue::Embed_W_Index()];
return val[DownpourCtrPullValue::EmbedWIndex()];
}
static float* EmbedxW(float* val) {
return val + DownpourCtrPullValue::Embedx_W_Index();
return val + DownpourCtrPullValue::EmbedxWIndex();
}
};
DownpourCtrAccessor() {}
virtual ~DownpourCtrAccessor() {}
virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// number of dimensions in a value
size_t Dim();
// size of each dimension of a value
size_t DimSize(size_t dim);
// total size of a value, summed over all dimensions
size_t Size();
// total size of the dynamic-length mf part of a value; only takes effect
// for sparse
size_t MFSize();
// number of dimensions in a pull value
size_t SelectDim();
// size of each dimension of a pull value
size_t SelectDimSize(size_t dim);
// total size of a pull value, summed over all dimensions
size_t SelectSize();
// number of dimensions in a push value
size_t UpdateDim();
// size of each dimension of a push value
size_t UpdateDimSize(size_t dim);
// total size of a push value, summed over all dimensions
size_t UpdateSize();
// initialize AccessorInfo
virtual void InitAccessorInfo();
// decide whether this value should be shrunk
virtual bool Shrink(float* value);
// 判断该value是否保存到ssd
......@@ -219,7 +197,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual float GetField(float* value, const std::string& name) override {
CHECK(name == "show");
if (name == "show") {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days;
auto show_right =
DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
......@@ -238,7 +216,7 @@ class DownpourCtrAccessor : public ValueAccessor {
bool test_func() { return false; }
private:
float show_click_score(float show, float click);
float ShowClickScore(float show, float click);
void set_time_decay_rates();
private:
......
......@@ -89,7 +89,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float);
_value_accesor->GetAccessorInfo().size / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num);
......@@ -174,7 +174,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float);
_value_accesor->GetAccessorInfo().size / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num);
......@@ -415,10 +415,12 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
CostTimer timer("pserver_sparse_select_all");
std::vector<std::future<int>> tasks(_real_local_shard_num);
const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
const size_t value_size =
_value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t select_value_size =
_value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float);
_value_accesor->GetAccessorInfo().select_size / sizeof(float);
// std::atomic<uint32_t> missed_keys{0};
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
......@@ -482,8 +484,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) {
CostTimer timer("pscore_sparse_select_all");
size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
......@@ -541,10 +544,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
task_keys[shard_id].push_back({keys[i], i});
}
const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
const size_t value_col =
_value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_col =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
_value_accesor->GetAccessorInfo().update_size / sizeof(float);
for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
......@@ -619,10 +624,11 @@ int32_t MemorySparseTable::_PushSparse(const uint64_t* keys,
task_keys[shard_id].push_back({keys[i], i});
}
size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_col =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
_value_accesor->GetAccessorInfo().update_size / sizeof(float);
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
......
......@@ -23,87 +23,35 @@ namespace distributed {
// Sets up the accessor from _config: creates the embed and embedx SGD rules
// by the names given in the config, loads their parameters, records the
// rule/embedding dimensions in sparse_feature_value, caches the show/click
// decay rate and finally calls InitAccessorInfo(). Always returns 0.
// NOTE(review): each snake_case call (load_config, dim) is shown together
// with its renamed CamelCase counterpart (LoadConfig, Dim) on the next line.
int SparseAccessor::Initialize() {
// The embed rule is configured with an embedding dimension of 1.
// NOTE(review): the pointer returned by CREATE_PSCORE_CLASS is used without
// a check; confirm the macro cannot yield null for an unknown name.
auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);
_embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
// The embedx rule uses the configured embedx_dim.
name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(),
_config.embedx_dim());
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim());
// These three fields determine the index layout of sparse_feature_value.
sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->dim();
sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
sparse_feature_value.embedx_dim = _config.embedx_dim();
sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim();
sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
// Called last, after the dimensions above have been assigned.
InitAccessorInfo();
return 0;
}
// Fills every field of `info` from the corresponding per-quantity getter
// (Dim, Size, SelectDim, SelectSize, UpdateDim, UpdateSize, MFSize).
void SparseAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = Dim();
info.size = Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
// Returns the single table quantity selected by `key` by dispatching to the
// matching getter; any key not listed below yields 0.
size_t SparseAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return Dim();
case SIZE:
return Size();
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case MF_SIZE:
return MFSize();
default:
return 0;
}
// Not reachable: every switch branch above already returns.
return 0;
}
// NOTE(review): this span shows the removed per-quantity getters (Dim,
// DimSize, Size, MFSize) interleaved with the body of InitAccessorInfo(),
// which stores the same quantities in _accessor_info.
size_t SparseAccessor::Dim() { return sparse_feature_value.Dim(); }
size_t SparseAccessor::DimSize(size_t dim) {
// Populates _accessor_info: dim/size come from sparse_feature_value, the
// select (pull) and update (push) dimensions are derived from embedx_dim,
// and each *_size is its dimension count times sizeof(float).
void SparseAccessor::InitAccessorInfo() {
_accessor_info.dim = sparse_feature_value.Dim();
_accessor_info.size = sparse_feature_value.Size();
auto embedx_dim = _config.embedx_dim();
return sparse_feature_value.DimSize(dim, embedx_dim);
}
size_t SparseAccessor::Size() { return sparse_feature_value.Size(); }
size_t SparseAccessor::MFSize() {
return (_config.embedx_dim() + sparse_feature_value.embedx_sgd_dim) *
sizeof(float); // embedx embedx_g2sum
// Pull value: 1 fixed float plus the embedx weights.
_accessor_info.select_dim = 1 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
// NOTE(review): the lone ';' below is an empty statement and can be removed.
;
// Push value: 4 fixed floats plus the embedx gradients.
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
// Bytes taken by the embedx weights plus the embedx SGD state.
_accessor_info.mf_size =
(embedx_dim + sparse_feature_value.embedx_sgd_dim) * sizeof(float);
}
// pull value
// Number of floats in a pull value: 1 fixed float plus embedx_dim.
size_t SparseAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim();
return 1 + embedx_dim;
}
// Every pull-value dimension is one float; `dim` is ignored.
size_t SparseAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
// Total byte size of a pull value.
size_t SparseAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value
// Number of floats in a push value: 4 fixed floats plus embedx_dim.
size_t SparseAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim;
}
// Every push-value dimension is one float; `dim` is ignored.
size_t SparseAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
// Total byte size of a push value.
size_t SparseAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool SparseAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......@@ -116,9 +64,9 @@ bool SparseAccessor::Shrink(float* value) {
sparse_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after
auto score = show_click_score(sparse_feature_value.Show(value),
sparse_feature_value.Click(value));
auto unseen_days = sparse_feature_value.unseen_days(value);
auto score = ShowClickScore(sparse_feature_value.Show(value),
sparse_feature_value.Click(value));
auto unseen_days = sparse_feature_value.UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true;
}
......@@ -141,14 +89,13 @@ bool SparseAccessor::Save(float* value, int param) {
case 1:
// save xbox base
case 2: {
if (show_click_score(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)) >=
base_threshold &&
sparse_feature_value.delta_score(value) >= delta_threshold &&
sparse_feature_value.unseen_days(value) <= delta_keep_days) {
if (ShowClickScore(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)) >= base_threshold &&
sparse_feature_value.DeltaScore(value) >= delta_threshold &&
sparse_feature_value.UnseenDays(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry
if (param == 2) {
sparse_feature_value.delta_score(value) = 0;
sparse_feature_value.DeltaScore(value) = 0;
}
return true;
} else {
......@@ -158,7 +105,7 @@ bool SparseAccessor::Save(float* value, int param) {
// already decayed in shrink
case 3: {
// do this after save, because it must not be modified when retry
// sparse_feature_value.unseen_days(value)++;
// sparse_feature_value.UnseenDays(value)++;
return true;
}
// save revert batch_model
......@@ -179,17 +126,16 @@ void SparseAccessor::UpdateStatAfterSave(float* value, int param) {
}
switch (param) {
case 1: {
if (show_click_score(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)) >=
base_threshold &&
sparse_feature_value.delta_score(value) >= delta_threshold &&
sparse_feature_value.unseen_days(value) <= delta_keep_days) {
sparse_feature_value.delta_score(value) = 0;
if (ShowClickScore(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)) >= base_threshold &&
sparse_feature_value.DeltaScore(value) >= delta_threshold &&
sparse_feature_value.UnseenDays(value) <= delta_keep_days) {
sparse_feature_value.DeltaScore(value) = 0;
}
}
return;
case 3: {
sparse_feature_value.unseen_days(value)++;
sparse_feature_value.UnseenDays(value)++;
}
return;
default:
......@@ -201,17 +147,16 @@ int32_t SparseAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
value[sparse_feature_value.unseen_days_index()] = 0;
value[sparse_feature_value.delta_score_index()] = 0;
value[sparse_feature_value.UnseenDaysIndex()] = 0;
value[sparse_feature_value.DeltaScoreIndex()] = 0;
value[sparse_feature_value.ShowIndex()] = 0;
value[sparse_feature_value.ClickIndex()] = 0;
value[sparse_feature_value.SlotIndex()] = -1;
_embed_sgd_rule->init_value(
value + sparse_feature_value.Embed_W_Index(),
value + sparse_feature_value.embed_g2sum_index());
_embedx_sgd_rule->init_value(
value + sparse_feature_value.Embedx_W_Index(),
value + sparse_feature_value.embedx_g2sum_index(), false);
_embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(),
value + sparse_feature_value.EmbedG2SumIndex());
_embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
value + sparse_feature_value.EmbedxG2SumIndex(),
false);
}
return 0;
}
......@@ -225,7 +170,7 @@ bool SparseAccessor::NeedExtendMF(float* value) {
}
// Returns true when `size` is strictly greater than the embedx g2sum index
// of sparse_feature_value.
// NOTE(review): two return statements appear back to back (snake_case
// accessor, then its CamelCase rename); only one of them should remain.
// NOTE(review): `size` is size_t while the index helper returns int, so the
// comparison mixes signed and unsigned operands.
bool SparseAccessor::HasMF(size_t size) {
return size > sparse_feature_value.embedx_g2sum_index();
return size > sparse_feature_value.EmbedxG2SumIndex();
}
// from SparseFeatureValue to SparsePullValue
......@@ -235,10 +180,10 @@ int32_t SparseAccessor::Select(float** select_values, const float** values,
for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item];
const float* value = values[value_item];
select_value[SparsePullValue::Embed_W_Index()] =
value[sparse_feature_value.Embed_W_Index()];
memcpy(select_value + SparsePullValue::Embedx_W_Index(),
value + sparse_feature_value.Embedx_W_Index(),
select_value[SparsePullValue::EmbedWIndex()] =
value[sparse_feature_value.EmbedWIndex()];
memcpy(select_value + SparsePullValue::EmbedxWIndex(),
value + sparse_feature_value.EmbedxWIndex(),
embedx_dim * sizeof(float));
}
return 0;
......@@ -278,18 +223,18 @@ int32_t SparseAccessor::Update(float** update_values, const float** push_values,
update_value[sparse_feature_value.ShowIndex()] += push_show;
update_value[sparse_feature_value.ClickIndex()] += push_click;
update_value[sparse_feature_value.SlotIndex()] = slot;
update_value[sparse_feature_value.delta_score_index()] +=
update_value[sparse_feature_value.DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff();
update_value[sparse_feature_value.unseen_days_index()] = 0;
_embed_sgd_rule->update_value(
update_value + sparse_feature_value.Embed_W_Index(),
update_value + sparse_feature_value.embed_g2sum_index(),
push_value + SparsePushValue::Embed_G_Index());
_embedx_sgd_rule->update_value(
update_value + sparse_feature_value.Embedx_W_Index(),
update_value + sparse_feature_value.embedx_g2sum_index(),
push_value + SparsePushValue::Embedx_G_Index());
update_value[sparse_feature_value.UnseenDaysIndex()] = 0;
_embed_sgd_rule->UpdateValue(
update_value + sparse_feature_value.EmbedWIndex(),
update_value + sparse_feature_value.EmbedG2SumIndex(),
push_value + SparsePushValue::EmbedGIndex());
_embedx_sgd_rule->UpdateValue(
update_value + sparse_feature_value.EmbedxWIndex(),
update_value + sparse_feature_value.EmbedxG2SumIndex(),
push_value + SparsePushValue::EmbedxGIndex());
}
return 0;
}
......@@ -303,7 +248,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) {
// operation
auto show = SparsePushValue::Show(const_cast<float*>(value));
auto click = SparsePushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score <= 0) {
return false;
}
......@@ -317,7 +262,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) {
}
}
float SparseAccessor::show_click_score(float show, float click) {
float SparseAccessor::ShowClickScore(float show, float click) {
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff;
......@@ -329,16 +274,16 @@ std::string SparseAccessor::ParseToString(const float* v, int param) {
os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5];
for (int i = sparse_feature_value.embed_g2sum_index();
i < sparse_feature_value.Embedx_W_Index(); i++) {
for (int i = sparse_feature_value.EmbedG2SumIndex();
i < sparse_feature_value.EmbedxWIndex(); i++) {
os << " " << v[i];
}
auto show = sparse_feature_value.Show(const_cast<float*>(v));
auto click = sparse_feature_value.Click(const_cast<float*>(v));
auto score = show_click_score(show, click);
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() &&
param > sparse_feature_value.Embedx_W_Index()) {
for (auto i = sparse_feature_value.Embedx_W_Index();
param > sparse_feature_value.EmbedxWIndex()) {
for (auto i = sparse_feature_value.EmbedxWIndex();
i < sparse_feature_value.Dim(); ++i) {
os << " " << v[i];
}
......@@ -349,9 +294,8 @@ std::string SparseAccessor::ParseToString(const float* v, int param) {
int SparseAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim();
_embedx_sgd_rule->init_value(
value + sparse_feature_value.Embedx_W_Index(),
value + sparse_feature_value.embedx_g2sum_index());
_embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
value + sparse_feature_value.EmbedxG2SumIndex());
auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret;
return ret;
......
......@@ -44,24 +44,24 @@ class SparseAccessor : public ValueAccessor {
// Index helpers describing how a sparse feature value is laid out as an
// array of floats. Following the chain below, the slots are:
//   0 slot, 1 unseen_days, 2 delta_score, 3 show, 4 click, 5 embed_w,
//   6 embed_g2sum (embed_sgd_dim floats),
//   6 + embed_sgd_dim: embedx_w (embedx_dim floats),
//   then embedx_g2sum.
// These are non-static members because the last two offsets depend on the
// embed_sgd_dim and embedx_dim fields.
// NOTE(review): each snake_case helper is shown together with its renamed
// CamelCase counterpart.
// Every dimension is one float; both arguments are ignored.
int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
// Total byte size of a value.
int Size() { return Dim() * sizeof(float); }
int SlotIndex() { return 0; }
int unseen_days_index() { return SlotIndex() + 1; }
int delta_score_index() { return unseen_days_index() + 1; }
int ShowIndex() { return delta_score_index() + 1; }
int UnseenDaysIndex() { return SlotIndex() + 1; }
int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
int ShowIndex() { return DeltaScoreIndex() + 1; }
int ClickIndex() { return ShowIndex() + 1; }
int Embed_W_Index() { return ClickIndex() + 1; }
int embed_g2sum_index() { return Embed_W_Index() + 1; }
int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; }
int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; }
int EmbedWIndex() { return ClickIndex() + 1; }
int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; }
// Reference accessors: return the float stored at the matching index of val.
float& unseen_days(float* val) { return val[unseen_days_index()]; }
float& delta_score(float* val) { return val[delta_score_index()]; }
float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
float& Show(float* val) { return val[ShowIndex()]; }
float& Click(float* val) { return val[ClickIndex()]; }
float& Slot(float* val) { return val[SlotIndex()]; }
float& EmbedW(float* val) { return val[Embed_W_Index()]; }
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
// EmbedxW yields a reference to the first embedx weight only.
float& EmbedxW(float* val) { return val[Embedx_W_Index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
float& EmbedW(float* val) { return val[EmbedWIndex()]; }
float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
int embed_sgd_dim;
int embedx_dim;
......@@ -84,18 +84,18 @@ class SparseAccessor : public ValueAccessor {
// Index helpers for a push value laid out as floats:
//   0 slot, 1 show, 2 click, 3 embed_g, 4 embedx_g (first element).
// NOTE(review): each Embed_G/Embedx_G helper is shown together with its
// renamed counterpart (EmbedGIndex / EmbedxGIndex).
static int SlotIndex() { return 0; }
static int ShowIndex() { return SparsePushValue::SlotIndex() + 1; }
static int ClickIndex() { return SparsePushValue::ShowIndex() + 1; }
static int Embed_G_Index() { return SparsePushValue::ClickIndex() + 1; }
static int Embedx_G_Index() { return SparsePushValue::Embed_G_Index() + 1; }
static int EmbedGIndex() { return SparsePushValue::ClickIndex() + 1; }
static int EmbedxGIndex() { return SparsePushValue::EmbedGIndex() + 1; }
// Reference accessors for the scalar slots.
static float& Slot(float* val) { return val[SparsePushValue::SlotIndex()]; }
static float& Show(float* val) { return val[SparsePushValue::ShowIndex()]; }
static float& Click(float* val) {
return val[SparsePushValue::ClickIndex()];
}
static float& EmbedG(float* val) {
return val[SparsePushValue::Embed_G_Index()];
return val[SparsePushValue::EmbedGIndex()];
}
// Returns a pointer to the start of the embedx gradients rather than a
// reference to a single float.
static float* EmbedxG(float* val) {
return val + SparsePushValue::Embedx_G_Index();
return val + SparsePushValue::EmbedxGIndex();
}
};
......@@ -108,41 +108,21 @@ class SparseAccessor : public ValueAccessor {
// Helpers for a pull value laid out as floats: slot 0 holds embed_w and the
// embedx weights start at slot 1, giving 1 + embedx_dim floats in total.
// NOTE(review): each Embed_W/Embedx_W helper is shown together with its
// renamed counterpart (EmbedWIndex / EmbedxWIndex).
static int Dim(int embedx_dim) { return 1 + embedx_dim; }
// Every dimension is one float; `dim` is ignored.
static int DimSize(size_t dim) { return sizeof(float); }
// Total byte size of a pull value with the given embedx_dim.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int Embed_W_Index() { return 0; }
static int Embedx_W_Index() { return 1; }
static int EmbedWIndex() { return 0; }
static int EmbedxWIndex() { return 1; }
static float& EmbedW(float* val) {
return val[SparsePullValue::Embed_W_Index()];
return val[SparsePullValue::EmbedWIndex()];
}
// Returns a pointer to the start of the embedx weights.
static float* EmbedxW(float* val) {
return val + SparsePullValue::Embedx_W_Index();
return val + SparsePullValue::EmbedxWIndex();
}
};
SparseAccessor() {}
virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
virtual ~SparseAccessor() {}
// number of dimensions in a value
size_t Dim();
// size of each dimension of a value
size_t DimSize(size_t dim);
// total size of a value, summed over all dimensions
size_t Size();
// total size of the dynamic-length mf part of a value; only takes effect
// for sparse
size_t MFSize();
// number of dimensions in a pull value
size_t SelectDim();
// size of each dimension of a pull value
size_t SelectDimSize(size_t dim);
// total size of a pull value, summed over all dimensions
size_t SelectSize();
// number of dimensions in a push value
size_t UpdateDim();
// size of each dimension of a push value
size_t UpdateDimSize(size_t dim);
// total size of a push value, summed over all dimensions
size_t UpdateSize();
virtual int Initialize();
// initialize AccessorInfo
virtual void InitAccessorInfo();
// decide whether this value should be shrunk
virtual bool Shrink(float* value);
// 判断该value是否保存到ssd
......@@ -186,7 +166,7 @@ class SparseAccessor : public ValueAccessor {
}
private:
// float show_click_score(float show, float click);
// float ShowClickScore(float show, float click);
// SparseValueSGDRule* _embed_sgd_rule;
// SparseValueSGDRule* _embedx_sgd_rule;
......@@ -197,7 +177,7 @@ class SparseAccessor : public ValueAccessor {
public: // TODO(zhaocaibei123): it should be private, but we make it public
// for unit test
SparseFeatureValue sparse_feature_value;
float show_click_score(float show, float click);
float ShowClickScore(float show, float click);
SparseValueSGDRule* _embed_sgd_rule;
SparseValueSGDRule* _embedx_sgd_rule;
};
......
......@@ -21,8 +21,8 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace paddle {
namespace distributed {
void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto naive_param = param.naive();
learning_rate_ = naive_param.learning_rate();
......@@ -39,17 +39,16 @@ void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
}
}
// Plain gradient step: for each of the _embedding_dim weights, subtracts
// learning_rate_ * push_value[i] from w[i] and then passes the weight to the
// bound helper. `sgd` and `scale` are not used in this body.
// NOTE(review): the signature and the bound call are each shown twice, the
// snake_case form followed by its CamelCase rename.
void SparseNaiveSGDRule::update_value_work(float* w, float* sgd,
const float* push_value,
float scale) {
void SparseNaiveSGDRule::UpdateValueWork(float* w, float* sgd,
const float* push_value, float scale) {
for (size_t i = 0; i < _embedding_dim; ++i) {
w[i] -= learning_rate_ * push_value[i];
// presumably clamps w[i] into the configured weight bounds — TODO confirm
bound_value(w[i]);
BoundValue(w[i]);
}
}
void SparseNaiveSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
void SparseNaiveSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) {
if (zero_init) {
for (size_t i = 0; i < _embedding_dim; ++i) {
value[i] = 0;
......@@ -60,12 +59,12 @@ void SparseNaiveSGDRule::init_value_work(float* value, float* sgd,
(local_uniform_real_distribution<float>()(local_random_engine()) * 2 -
1) *
_initial_range;
bound_value(value[i]);
BoundValue(value[i]);
}
}
}
void SparseAdaGradSGDRule::load_config(
const SparseCommonSGDRuleParameter& param, size_t emb_dim) {
void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate();
......@@ -84,42 +83,42 @@ void SparseAdaGradSGDRule::load_config(
}
}
// AdaGrad-style step that keeps ONE g2sum slot shared by all dimensions.
// Each gradient is divided by `scale`; the weight then moves by
//   learning_rate_ * scaled_grad * sqrt(_initial_g2sum / (_initial_g2sum + g2sum)).
// g2sum is only updated after the loop, so every dimension in this call uses
// the same pre-update g2sum; the mean of the squared scaled gradients is
// then added to it.
// NOTE(review): the signature, the g2sum lookup and the bound call are each
// shown twice, the snake_case form followed by its CamelCase rename.
// NOTE(review): `scale` is used as a divisor with no zero check — confirm
// callers never pass 0.
void SparseAdaGradSGDRule::update_value_work(float* w, float* sgd,
const float* grad, float scale) {
float& g2sum = sgd[g2sum_index()];
void SparseAdaGradSGDRule::UpdateValueWork(float* w, float* sgd,
const float* grad, float scale) {
float& g2sum = sgd[G2SumIndex()];
double add_g2sum = 0;
for (int i = 0; i < _embedding_dim; i++) {
double scaled_grad = grad[i] / scale;
w[i] -= learning_rate_ * scaled_grad *
sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
// presumably clamps w[i] into the configured weight bounds — TODO confirm
bound_value(w[i]);
BoundValue(w[i]);
add_g2sum += scaled_grad * scaled_grad;
}
// Accumulate the per-dimension average of the squared gradients.
g2sum += add_g2sum / _embedding_dim;
}
// Initializes _embedding_dim weights and resets the single shared g2sum slot.
// With zero_init the weights are set to 0; otherwise each weight is
// (r * 2 - 1) * _initial_range, where r is drawn from
// local_uniform_real_distribution<double> (assumed to yield values in
// [0, 1) — TODO confirm). Either way the weight is passed to the bound helper.
// NOTE(review): the signature, the bound calls and the g2sum reset are each
// shown twice, the snake_case form followed by its CamelCase rename.
void SparseAdaGradSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
void SparseAdaGradSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
bound_value(value[i]);
BoundValue(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
bound_value(value[i]);
BoundValue(value[i]);
}
}
// One g2sum slot for the whole embedding, reset once after the loop.
sgd[g2sum_index()] = 0;
sgd[G2SumIndex()] = 0;
}
void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate();
......@@ -138,38 +137,38 @@ void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
}
}
// AdaGrad-style update with ONE squared-gradient accumulator PER dimension
// (sgd[G2SumIndex() + i] belongs to w[i]).
//   w     : weights, _embedding_dim floats, updated in place and clamped.
//   sgd   : optimizer state holding _embedding_dim accumulators.
//   grad  : gradients, _embedding_dim floats.
//   scale : each gradient is divided by this before use.
// NOTE(review): this text looks like a rendered diff -- snake_case lines
// appear to be the pre-rename versions of the PascalCase lines after them.
void StdAdaGradSGDRule::update_value_work(float* w, float* sgd,
const float* grad, float scale) {
void StdAdaGradSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
float scale) {
for (int i = 0; i < _embedding_dim; i++) {
float& g2sum = sgd[g2sum_index() + i];
float& g2sum = sgd[G2SumIndex() + i];
double scaled_grad = grad[i] / scale;
// The step uses the accumulator value from BEFORE this gradient is added.
w[i] -= learning_rate_ * scaled_grad *
sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
bound_value(w[i]);
BoundValue(w[i]);
g2sum += scaled_grad * scaled_grad;
}
}
// Initializes the _embedding_dim weights in `value` and zeroes the matching
// per-dimension accumulator in `sgd` for each of them.
//   zero_init == true  : weight is set to 0.0 and then clamped.
//   zero_init == false : weight is set to (r * 2 - 1) * _initial_range, where
//                        r is a draw from the project's uniform real
//                        distribution, and then clamped.
// NOTE(review): r is presumably in [0, 1) -- confirm against
// local_uniform_real_distribution.
// NOTE(review): this text looks like a rendered diff -- snake_case lines
// appear to be the pre-rename versions of the PascalCase lines after them.
void StdAdaGradSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
void StdAdaGradSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
bound_value(value[i]);
BoundValue(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
bound_value(value[i]);
BoundValue(value[i]);
}
// Accumulator i pairs with weight i, so it is reset inside the same loop.
sgd[g2sum_index() + i] = 0;
sgd[G2SumIndex() + i] = 0;
}
}
void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto adam_param = param.adam();
learning_rate_ = adam_param.learning_rate();
......@@ -189,12 +188,12 @@ void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
}
}
void SparseAdamSGDRule::update_value_work(float* w, float* sgd,
const float* grad, float scale) {
float* gsum = sgd + gsum_index();
float* g2sum = sgd + g2sum_index();
float* beta1_pow = sgd + beta1_pow_index();
float* beta2_pow = sgd + beta2_pow_index();
void SparseAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
float scale) {
float* gsum = sgd + GSumIndex();
float* g2sum = sgd + G2SumIndex();
float* beta1_pow = sgd + Beta1PowIndex();
float* beta2_pow = sgd + Beta2PowIndex();
const float* g = grad;
float lr = learning_rate_;
......@@ -209,35 +208,35 @@ void SparseAdamSGDRule::update_value_work(float* w, float* sgd,
g2sum[i] =
_beta2_decay_rate * g2sum[i] + (1 - _beta2_decay_rate) * g[i] * g[i];
w[i] = w[i] - lr * (gsum[i] / (sqrt(g2sum[i]) + _ada_epsilon));
bound_value(w[i]);
BoundValue(w[i]);
}
// update beta_pow_decay
(*beta1_pow) *= _beta1_decay_rate;
(*beta2_pow) *= _beta2_decay_rate;
}
// Initializes the _embedding_dim weights in `value` and the Adam state in
// `sgd`.
//   zero_init == true  : weight is set to 0.0 and then clamped.
//   zero_init == false : weight is set to (r * 2 - 1) * _initial_range, where
//                        r is a draw from the project's uniform real
//                        distribution, and then clamped.
// State after the call: every slot in [GSumIndex(), Beta1PowIndex()) is 0, the
// beta1_pow slot holds _beta1_decay_rate and the beta2_pow slot holds
// _beta2_decay_rate.
// NOTE(review): r is presumably in [0, 1) -- confirm against
// local_uniform_real_distribution.
// NOTE(review): this text looks like a rendered diff -- snake_case lines
// appear to be the pre-rename versions of the PascalCase lines after them.
void SparseAdamSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
void SparseAdamSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
bound_value(value[i]);
BoundValue(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
bound_value(value[i]);
BoundValue(value[i]);
}
}
// init rule gsum and g2sum
// A single loop clears both moment arrays: the index helpers declared on the
// class place g2sum right after gsum and beta1_pow right after g2sum.
for (int i = gsum_index(); i < beta1_pow_index(); i++) {
for (int i = GSumIndex(); i < Beta1PowIndex(); i++) {
sgd[i] = 0.0;
}
// init beta1_pow and beta2_pow
// The running powers start at the decay rates themselves (first power).
*(sgd + beta1_pow_index()) = _beta1_decay_rate;
*(sgd + beta2_pow_index()) = _beta2_decay_rate;
*(sgd + Beta1PowIndex()) = _beta1_decay_rate;
*(sgd + Beta2PowIndex()) = _beta2_decay_rate;
}
} // namespace distributed
} // namespace paddle
......@@ -28,33 +28,33 @@ class SparseValueSGDRule {
public:
SparseValueSGDRule() {}
virtual ~SparseValueSGDRule() {}
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
_name = param.name();
}
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale) = 0;
virtual void init_value_work(float* value, float* sgd, bool zero_init) = 0;
virtual size_t dim() = 0;
const std::string& get_name() const { return _name; }
void init_value(float* value, float* sgd, bool zero_init = true) {
init_value_work(value, sgd, zero_init);
virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale) = 0;
virtual void InitValueWork(float* value, float* sgd, bool zero_init) = 0;
virtual size_t Dim() = 0;
const std::string& GetName() const { return _name; }
void InitValue(float* value, float* sgd, bool zero_init = true) {
InitValueWork(value, sgd, zero_init);
}
void update_value(float* w, float* sgd, const float* push_value,
float scale = 1) {
update_value_work(w, sgd, push_value, scale);
void UpdateValue(float* w, float* sgd, const float* push_value,
float scale = 1) {
UpdateValueWork(w, sgd, push_value, scale);
}
// Clamps `w` in place to the closed range [_min_bound, _max_bound].
// The tests are written as negated comparisons on purpose-looking grounds:
// any value for which `w >= _min_bound` is false is replaced by _min_bound.
// For floating-point T that includes NaN (every ordered comparison with NaN
// is false), so a NaN weight is reset to _min_bound rather than kept.
// NOTE(review): this text looks like a rendered diff -- bound_value appears to
// be the pre-rename signature line of BoundValue.
template <class T>
void bound_value(T& w) { // NOLINT
void BoundValue(T& w) { // NOLINT
if (!(w >= _min_bound)) {
w = (T)_min_bound;
} else if (!(w <= _max_bound)) {
w = (T)_max_bound;
}
}
float& min_bound() { return _min_bound; }
float& max_bound() { return _max_bound; }
float& MinBound() { return _min_bound; }
float& MaxBound() { return _max_bound; }
protected:
float _min_bound;
......@@ -70,12 +70,12 @@ REGISTER_PSCORE_REGISTERER(SparseValueSGDRule);
class SparseNaiveSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return 0; }
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale);
virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t Dim() { return 0; }
private:
float learning_rate_;
......@@ -83,13 +83,13 @@ class SparseNaiveSGDRule : public SparseValueSGDRule {
class SparseAdaGradSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return 1; }
size_t g2sum_index() { return 0; }
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale);
virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t Dim() { return 1; }
size_t G2SumIndex() { return 0; }
private:
float learning_rate_;
......@@ -98,13 +98,13 @@ class SparseAdaGradSGDRule : public SparseValueSGDRule {
class StdAdaGradSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return _embedding_dim; }
size_t g2sum_index() { return 0; }
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale);
virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t Dim() { return _embedding_dim; }
size_t G2SumIndex() { return 0; }
private:
float learning_rate_;
......@@ -113,16 +113,16 @@ class StdAdaGradSGDRule : public SparseValueSGDRule {
class SparseAdamSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return _embedding_dim * 2 + 2; }
size_t gsum_index() { return 0; }
size_t g2sum_index() { return gsum_index() + _embedding_dim; }
size_t beta1_pow_index() { return g2sum_index() + _embedding_dim; }
size_t beta2_pow_index() { return beta1_pow_index() + 1; }
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale);
virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t Dim() { return _embedding_dim * 2 + 2; }
size_t GSumIndex() { return 0; }
size_t G2SumIndex() { return GSumIndex() + _embedding_dim; }
size_t Beta1PowIndex() { return G2SumIndex() + _embedding_dim; }
size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
protected:
float learning_rate_;
......
......@@ -103,7 +103,6 @@ int32_t Table::InitializeAccessor() {
return -1;
}
_value_accesor.reset(accessor);
// _value_accesor->SetTableInfo(_table_info);
return 0;
}
......
......@@ -162,7 +162,6 @@ class Table {
TableParameter _config;
float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor;
AccessorInfo _table_info;
AfsClient _afs_client;
};
REGISTER_PSCORE_REGISTERER(Table);
......
......@@ -18,51 +18,19 @@
namespace paddle {
namespace distributed {
int CommMergeAccessor::Initialize() { return 0; }
void CommMergeAccessor::SetTableInfo(AccessorInfo &info) {
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.fea_dim = fea_dim();
}
size_t CommMergeAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case FEA_DIM:
return fea_dim();
default:
return 0;
}
int CommMergeAccessor::Initialize() {
InitAccessorInfo();
return 0;
}
// pull value 维度
size_t CommMergeAccessor::SelectDim() { return _config.embedx_dim(); }
// pull value 各个维度的size
size_t CommMergeAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
// pull value 各维度相加总size
size_t CommMergeAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value 维度
size_t CommMergeAccessor::UpdateDim() { return _config.embedx_dim(); }
// push value 各个维度的size
size_t CommMergeAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
// push value 各维度相加总size
size_t CommMergeAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
// Fills _accessor_info from _config.
// Pull (select) and push (update) values both have embedx_dim elements; the
// byte sizes are the element counts multiplied by sizeof(float). fea_dim is
// copied straight from the config.
void CommMergeAccessor::InitAccessorInfo() {
  const auto dim = _config.embedx_dim();
  auto& info = _accessor_info;

  // Element counts.
  info.select_dim = dim;
  info.update_dim = dim;

  // Byte sizes, derived from the stored counts.
  info.select_size = info.select_dim * sizeof(float);
  info.update_size = info.update_dim * sizeof(float);

  info.fea_dim = _config.fea_dim();
}
// 判断该value 是否进行shrink
bool CommMergeAccessor::Shrink(float * /*value*/) { return false; }
......
......@@ -30,22 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {}
virtual ~CommMergeAccessor() {}
virtual int Initialize();
virtual void SetTableInfo(AccessorInfo &info);
virtual size_t GetTableInfo(InfoKey key);
// value维度
// pull value维度
size_t SelectDim();
// pull value各个维度的size
size_t SelectDimSize(size_t dim);
// pull value各维度相加总size
size_t SelectSize();
// push value维度
size_t UpdateDim();
// push value各个维度的size
size_t UpdateDimSize(size_t dim);
// push value各维度相加总size
size_t UpdateSize();
size_t fea_dim() { return _config.fea_dim(); }
// 初始化AccessorInfo
virtual void InitAccessorInfo();
// 判断该value是否进行shrink
virtual bool Shrink(float * /*value*/);
// 判断该value是否在save阶段dump,
......
......@@ -75,8 +75,8 @@ TEST(downpour_feature_value_accessor_test, test_shrink) {
<< acc->common_feature_value.embedx_sgd_dim << " "
<< acc->common_feature_value.Dim() << "\n";
float* value = new float[acc->Dim()];
for (auto i = 0u; i < acc->Dim(); ++i) {
float* value = new float[acc->GetAccessorInfo().dim];
for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
value[i] = i * 1.0;
}
ASSERT_TRUE(!acc->Shrink(value));
......@@ -94,8 +94,8 @@ TEST(downpour_feature_value_accessor_test, test_save) {
ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0);
float* value = new float[acc->Dim()];
for (auto i = 0u; i < acc->Dim(); ++i) {
float* value = new float[acc->GetAccessorInfo().dim];
for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
value[i] = i * 1.0;
}
......@@ -109,7 +109,7 @@ TEST(downpour_feature_value_accessor_test, test_save) {
ASSERT_TRUE(acc->Save(value, 2));
VLOG(3) << "test_save:";
for (auto i = 0u; i < acc->Dim(); ++i) {
for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
VLOG(3) << value[i];
}
}
......@@ -145,7 +145,7 @@ TEST(downpour_feature_value_accessor_test, test_update) {
ASSERT_EQ(acc->Initialize(), 0);
VLOG(3) << "dim: " << acc->common_feature_value.Dim() << "\n";
VLOG(3) << "update_dim: " << acc->GetTableInfo(UPDATE_DIM) << "\n";
VLOG(3) << "update_dim: " << acc->GetAccessorInfo().update_dim << "\n";
const int field_size = 7 + 8;
const int item_size = 10;
......@@ -162,8 +162,8 @@ TEST(downpour_feature_value_accessor_test, test_update) {
typedef const float* const_float_ptr;
const_float_ptr* grad = new const_float_ptr[item_size];
for (auto i = 0u; i < item_size; ++i) {
float* p = new float[acc->GetTableInfo(UPDATE_DIM)];
for (auto j = 0u; j < acc->GetTableInfo(UPDATE_DIM); ++j) {
float* p = new float[acc->GetAccessorInfo().update_dim];
for (auto j = 0u; j < acc->GetAccessorInfo().update_dim; ++j) {
p[j] = i;
}
grad[i] = p;
......@@ -244,21 +244,21 @@ TEST(downpour_feature_value_accessor_test, test_update) {
v.unseen_days = 0;
v.show += push_v.show;
v.click += push_v.click;
v.delta_score += acc->show_click_score(push_v.show, push_v.click);
v.delta_score += acc->ShowClickScore(push_v.show, push_v.click);
acc->_embed_sgd_rule->update_value(&v.embed_w, &v.embed_g2sum[0],
&push_v.embed_g);
acc->_embedx_sgd_rule->update_value(&v.embedx_w[0], &v.embedx_g2sum[0],
&push_v.embedx_g[0]);
acc->_embed_sgd_rule->UpdateValue(&v.embed_w, &v.embed_g2sum[0],
&push_v.embed_g);
acc->_embedx_sgd_rule->UpdateValue(&v.embedx_w[0], &v.embedx_g2sum[0],
&push_v.embedx_g[0]);
float* ptr = new float[acc->Dim()];
float* ptr = new float[acc->GetAccessorInfo().dim];
v.to_array(ptr, parameter.embedx_dim());
exp_value.push_back(ptr);
}
acc->Update(value, grad, item_size);
for (auto i = 0u; i < item_size; ++i) {
for (auto j = 0u; j < acc->Dim(); ++j) {
for (auto j = 0u; j < acc->GetAccessorInfo().dim; ++j) {
VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " ";
ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]);
}
......@@ -273,7 +273,7 @@ TEST(downpour_feature_value_accessor_test, test_show_click_score) {
float show = 10;
float click = 6;
ASSERT_FLOAT_EQ(acc->show_click_score(show, click), 6.8);
ASSERT_FLOAT_EQ(acc->ShowClickScore(show, click), 6.8);
}
TEST(downpour_feature_value_accessor_test, test_string_related) {
......
......@@ -31,22 +31,22 @@ TEST(sparse_value_naive_sgd_test, init_and_update) {
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
rule.load_config(param, 10);
rule.LoadConfig(param, 10);
// check init_value for zero
const int kItemSize = 10;
float w[kItemSize];
float grad[kItemSize];
rule.init_value(w, w + 9, true);
rule.InitValue(w, w + 9, true);
for (auto i = 0u; i < kItemSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0);
}
// check init_value for random
rule.init_value(w, w + 9, false);
rule.InitValue(w, w + 9, false);
for (auto i = 0u; i < kItemSize; ++i) {
ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound());
ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound());
}
// check update_value for one field
......@@ -59,7 +59,7 @@ TEST(sparse_value_naive_sgd_test, init_and_update) {
float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, -0.500000,
-0.600000, -0.700000, -0.800000, -0.900000, -1.000000};
const float* ptr_grad = grad;
rule.update_value(w, w + 9, ptr_grad);
rule.UpdateValue(w, w + 9, ptr_grad);
for (auto i = 0u; i < kItemSize; ++i) {
VLOG(3) << w[i] << "\n";
......@@ -78,14 +78,14 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
adagrad_param->add_weight_bounds(-10.0);
adagrad_param->add_weight_bounds(10.0);
rule.load_config(param, 10);
rule.LoadConfig(param, 10);
// check init_value for zero
const int kValueSize = 11;
int kEmbSize = 10;
float w[kValueSize];
rule.init_value(w, w + 10, true);
rule.InitValue(w, w + 10, true);
for (auto i = 0u; i < kEmbSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0);
......@@ -93,9 +93,9 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
ASSERT_FLOAT_EQ(w[kEmbSize], 0);
// check init_value for random
rule.init_value(w, w + 10, false);
rule.InitValue(w, w + 10, false);
for (auto i = 0u; i < kEmbSize; ++i) {
ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound());
ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound());
}
ASSERT_FLOAT_EQ(w[kEmbSize], 0);
......@@ -110,7 +110,7 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
}
const float* ptr_grad = grad;
rule.update_value(w, w + 10, ptr_grad);
rule.UpdateValue(w, w + 10, ptr_grad);
float label[] = {-0.100000, -0.200000, -0.300000, -0.400000,
-0.500000, -0.600000, -0.700000, -0.800000,
-0.900000, -1.000000, 38.500000};
......@@ -140,33 +140,33 @@ TEST(downpour_sparse_adam_test, test_init_and_update) {
SparseAdamSGDRule rule;
rule.load_config(param, embed_dim);
rule.LoadConfig(param, embed_dim);
// check init_value for zero
const int rule_dim =
rule.dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam
rule.Dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam
const int value_dim = embed_dim + rule_dim; // total dims of w + rule
float* value = new float[value_dim];
rule.init_value(value, value + embed_dim, true);
for (auto i = 0u; i < rule.beta1_pow_index(); ++i) {
rule.InitValue(value, value + embed_dim, true);
for (auto i = 0u; i < rule.Beta1PowIndex(); ++i) {
ASSERT_FLOAT_EQ(value[i], 0);
}
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta1PowIndex()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta2PowIndex()), 0.999);
// check init_value for random
rule.init_value(value, value + embed_dim, false);
rule.InitValue(value, value + embed_dim, false);
for (auto i = 0u; i < embed_dim; ++i) {
ASSERT_TRUE(value[i] >= rule.min_bound() && value[i] <= rule.max_bound());
ASSERT_TRUE(value[i] >= rule.MinBound() && value[i] <= rule.MaxBound());
}
for (auto i = rule.gsum_index(); i < rule.beta1_pow_index(); ++i) {
for (auto i = rule.GSumIndex(); i < rule.Beta1PowIndex(); ++i) {
ASSERT_FLOAT_EQ(value[i + embed_dim], 0);
}
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta1PowIndex()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta2PowIndex()), 0.999);
// check update_value
rule.init_value(value, value + embed_dim, true);
rule.InitValue(value, value + embed_dim, true);
float* grad = new float[embed_dim];
for (auto i = 0u; i < embed_dim; ++i) {
grad[i] = (i + 1) * 1.0;
......@@ -181,7 +181,7 @@ TEST(downpour_sparse_adam_test, test_init_and_update) {
0.0249996781, 0.0359995365, 0.0489993691, 0.063999176,
0.0809989572, 0.0999987125, 0.809999943, 0.998001039};
rule.update_value(value, value + embed_dim, grad);
rule.UpdateValue(value, value + embed_dim, grad);
for (auto i = 0u; i < value_dim; ++i) { // check update
ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i;
......
......@@ -1668,7 +1668,7 @@ class Fleet(object):
opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items(
):
if v:
if v or k not in opt_info:
opt_info[k] = v
program._fleet_opt = opt_info
......@@ -1745,7 +1745,7 @@ class Fleet(object):
opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items(
):
if v:
if v or k not in opt_info:
opt_info[k] = v
program._fleet_opt = opt_info
# print("fleet base opt info:", id(program), program._fleet_opt)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册