未验证 提交 f4826137 编写于 作者: W wangguanqun 提交者: GitHub

Delete function in accessor and update function name in accessor and sgd (#41292)

* delete function

* fix bug

* update name

* fix bug in strategy
上级 a9d66025
...@@ -525,12 +525,12 @@ std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id, ...@@ -525,12 +525,12 @@ std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums), io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
sizeof(uint32_t)); sizeof(uint32_t));
keys->resize(shard_nums); keys->resize(shard_nums);
values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM)); values->resize(shard_nums * accessor->GetAccessorInfo().update_dim);
io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT
sizeof(uint64_t) * shard_nums); sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward( io_buffer_itr.copy_and_forward(
(void *)(values->data()), // NOLINT (void *)(values->data()), // NOLINT
shard_nums * accessor->GetTableInfo(UPDATE_SIZE)); shard_nums * accessor->GetAccessorInfo().update_size);
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
...@@ -573,7 +573,7 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id, ...@@ -573,7 +573,7 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
auto kvs = ids[shard_idx]; auto kvs = ids[shard_idx];
auto value_ptr = value_ptrs[shard_idx]; auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size(); size_t kv_size = kvs.size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); uint32_t value_size = accessor->GetAccessorInfo().update_size;
// 发送RPC请求 // 发送RPC请求
auto *push_request = closure->request(shard_idx); auto *push_request = closure->request(shard_idx);
push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
...@@ -581,14 +581,13 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id, ...@@ -581,14 +581,13 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
push_request->set_client_id(_client_id); push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data(); auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data()); char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t); push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) { for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); memcpy(push_data_ptr, value_ptr[i], value_size);
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); push_data_ptr += value_size;
} }
PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
...@@ -603,11 +602,9 @@ std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num, ...@@ -603,11 +602,9 @@ std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num,
size_t table_id) { size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense"); auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM); auto fea_dim = accessor->GetAccessorInfo().fea_dim;
auto select_size = accessor->GetTableInfo(SELECT_SIZE);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// callback 将各shard结果,顺序填入region // callback 将各shard结果,顺序填入region
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num, request_call_num, [request_call_num, num_per_shard, regions, region_num,
...@@ -617,7 +614,7 @@ std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num, ...@@ -617,7 +614,7 @@ std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num,
size_t region_data_idx = 0; // 当前填充的region内data偏移 size_t region_data_idx = 0; // 当前填充的region内data偏移
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done); auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
size_t shard_data_size = size_t shard_data_size =
num_per_shard * accessor->GetTableInfo(SELECT_SIZE); num_per_shard * accessor->GetAccessorInfo().select_size;
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
ret = -1; ret = -1;
...@@ -681,12 +678,13 @@ std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions, ...@@ -681,12 +678,13 @@ std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
auto accessor_info = accessor->GetAccessorInfo();
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据 // 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std::vector<std::vector<Region>> regions_partition(request_call_num); std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard = uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor_info.fea_dim, request_call_num);
size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE); size_t shard_data_size = num_per_shard * accessor_info.update_size;
size_t current_region_idx = 0; size_t current_region_idx = 0;
size_t current_region_data_idx = 0; size_t current_region_data_idx = 0;
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
...@@ -793,7 +791,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient( ...@@ -793,7 +791,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
auto value_ptr = value_ptrs[shard_idx]; auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size(); size_t kv_size = kvs.size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); uint32_t value_size = accessor->GetAccessorInfo().update_size;
// 发送RPC请求 // 发送RPC请求
auto *push_request = closure->request(shard_idx); auto *push_request = closure->request(shard_idx);
...@@ -802,15 +800,14 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient( ...@@ -802,15 +800,14 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
push_request->set_client_id(_client_id); push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data(); auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data()); char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t); push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) { for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); memcpy(push_data_ptr, value_ptr[i], value_size);
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); push_data_ptr += value_size;
} }
PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
...@@ -831,7 +828,7 @@ std::future<int32_t> BrpcPsClient::PushDenseRawGradient( ...@@ -831,7 +828,7 @@ std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
uint32_t num_per_shard = uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id); closure->request(i)->set_table_id(table_id);
...@@ -910,7 +907,7 @@ std::future<int32_t> BrpcPsClient::PullSparse(float **select_values, ...@@ -910,7 +907,7 @@ std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE); size_t value_size = accessor->GetAccessorInfo().select_size;
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) { request_call_num, [shard_sorted_kvs, value_size](void *done) {
...@@ -1023,8 +1020,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values, ...@@ -1023,8 +1020,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
} }
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE); size_t value_size = accessor->GetAccessorInfo().select_size;
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) { request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0; int ret = 0;
...@@ -1147,7 +1143,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial( ...@@ -1147,7 +1143,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) { uint32_t num, void *done, int pserver_idx) {
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(UPDATE_SIZE); size_t value_size = accessor->GetAccessorInfo().update_size;
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
...@@ -1307,7 +1303,7 @@ std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id, ...@@ -1307,7 +1303,7 @@ std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
shard_kv_data.kv_num = 0; shard_kv_data.kv_num = 0;
continue; continue;
} }
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); uint32_t value_size = accessor->GetAccessorInfo().update_size;
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first; shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
shard_kv_data.value_list[kv_idx].assign( shard_kv_data.value_list[kv_idx].assign(
...@@ -1453,7 +1449,7 @@ void BrpcPsClient::PushSparseTaskConsume() { ...@@ -1453,7 +1449,7 @@ void BrpcPsClient::PushSparseTaskConsume() {
void sparse_local_merge(ValueAccessor *accessor, float *merge_data, void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
const float *another_data) { const float *another_data) {
size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float); size_t col_num = accessor->GetAccessorInfo().update_dim;
float *merge_data_shell[col_num]; float *merge_data_shell[col_num];
const float *another_data_shell[col_num]; const float *another_data_shell[col_num];
for (int i = 0; i < col_num; ++i) { for (int i = 0; i < col_num; ++i) {
...@@ -1469,7 +1465,7 @@ int BrpcPsClient::PushSparseAsyncShardMerge( ...@@ -1469,7 +1465,7 @@ int BrpcPsClient::PushSparseAsyncShardMerge(
ValueAccessor *accessor) { ValueAccessor *accessor) {
size_t merged_kv_count = 0; size_t merged_kv_count = 0;
uint64_t min_key = UINT64_MAX; uint64_t min_key = UINT64_MAX;
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); uint32_t value_size = accessor->GetAccessorInfo().update_size;
thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list; thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
sorted_kv_list.clear(); sorted_kv_list.clear();
...@@ -1575,9 +1571,8 @@ int BrpcPsClient::PushSparseAsyncShardPush( ...@@ -1575,9 +1571,8 @@ int BrpcPsClient::PushSparseAsyncShardPush(
push_request->add_params(reinterpret_cast<char *>(&merged_kv_count), push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
sizeof(uint32_t)); // NOLINT sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data(); auto *push_data = push_request->mutable_data();
int update_size = accessor->GetTableInfo(UPDATE_SIZE); int update_size = accessor->GetAccessorInfo().update_size;
push_data->resize(merged_kv_count * push_data->resize(merged_kv_count * (sizeof(uint64_t) + update_size));
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data()); char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, merged_key_list.data(), memcpy(push_data_ptr, merged_key_list.data(),
merged_kv_count * sizeof(uint64_t)); merged_kv_count * sizeof(uint64_t));
...@@ -1586,8 +1581,8 @@ int BrpcPsClient::PushSparseAsyncShardPush( ...@@ -1586,8 +1581,8 @@ int BrpcPsClient::PushSparseAsyncShardPush(
const char *task_data_ptr = merged_value_list[i].data(); const char *task_data_ptr = merged_value_list[i].data();
memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT
accessor->GetTableInfo(UPDATE_SIZE)); update_size);
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); push_data_ptr += update_size;
} }
PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
...@@ -1602,8 +1597,8 @@ std::future<int32_t> BrpcPsClient::PushDense(const Region *regions, ...@@ -1602,8 +1597,8 @@ std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto *accessor = GetTableAccessor(table_id); auto *accessor = GetTableAccessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM); int fea_dim = accessor->GetAccessorInfo().fea_dim;
int update_dim = accessor->GetTableInfo(UPDATE_DIM); int update_dim = accessor->GetAccessorInfo().update_dim;
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense"); auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
auto parse_timer = auto parse_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_parse"); std::make_shared<CostTimer>("pserver_client_push_dense_parse");
...@@ -1621,13 +1616,9 @@ std::future<int32_t> BrpcPsClient::PushDense(const Region *regions, ...@@ -1621,13 +1616,9 @@ std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
auto dense_data = std::make_shared<std::vector<float>>(); auto dense_data = std::make_shared<std::vector<float>>();
auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer); auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// 将region数据拷贝到转置矩阵中 // 将region数据拷贝到转置矩阵中
async_task->data()->resize(num_per_shard * request_call_num * async_task->data()->resize(num_per_shard * request_call_num * update_dim);
accessor->GetTableInfo(UPDATE_DIM));
float *data = async_task->data()->data(); float *data = async_task->data()->data();
size_t data_size = async_task->data()->size(); size_t data_size = async_task->data()->size();
uint32_t pos = 0; uint32_t pos = 0;
...@@ -1757,7 +1748,7 @@ void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, ...@@ -1757,7 +1748,7 @@ void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc"); auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer); closure->add_timer(timer);
uint32_t num_per_shard = uint32_t num_per_shard =
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
auto send_timer = auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send"); std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
......
...@@ -205,7 +205,7 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request, ...@@ -205,7 +205,7 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
} }
auto res_data = butil::get_object<std::vector<float>>(); auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) / res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
sizeof(float)); sizeof(float));
TableContext table_context; TableContext table_context;
...@@ -384,7 +384,7 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request, ...@@ -384,7 +384,7 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
CostTimer timer("pserver_server_pull_sparse"); CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM); auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
thread_local std::string req_buffer; thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size); req_buffer.reserve(req_buffer_size);
......
...@@ -99,7 +99,8 @@ int32_t PsLocalClient::Initialize() { ...@@ -99,7 +99,8 @@ int32_t PsLocalClient::Initialize() {
auto* accessor = GetTableAccessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id); auto* table_ptr = GetTable(table_id);
uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1); uint32_t num_per_shard =
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(num_per_shard); region_buffer.resize(num_per_shard);
...@@ -145,8 +146,8 @@ int32_t PsLocalClient::Initialize() { ...@@ -145,8 +146,8 @@ int32_t PsLocalClient::Initialize() {
auto* table_ptr = GetTable(table_id); auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0); region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
...@@ -179,8 +180,8 @@ int32_t PsLocalClient::Initialize() { ...@@ -179,8 +180,8 @@ int32_t PsLocalClient::Initialize() {
auto* table_ptr = GetTable(table_id); auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1)); region_buffer.resize(
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
size_t data_size = region_buffer.size(); size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
......
...@@ -46,27 +46,24 @@ struct DataConverter { ...@@ -46,27 +46,24 @@ struct DataConverter {
}; };
struct AccessorInfo { struct AccessorInfo {
// value维度
size_t dim; size_t dim;
// value各个维度的size
size_t size; size_t size;
size_t select_size; // pull value维度
size_t select_dim; size_t select_dim;
size_t update_size; // pull value各维度相加总size
size_t select_size;
// push value维度
size_t update_dim; size_t update_dim;
// push value各个维度的size
size_t update_size;
// value中mf动态长度部分总size大小, sparse下生效
size_t mf_size; size_t mf_size;
// value总维度,dense下生效
size_t fea_dim; size_t fea_dim;
}; };
enum InfoKey {
DIM = 0,
SIZE = 1,
SELECT_SIZE = 2,
SELECT_DIM = 3,
UPDATE_SIZE = 4,
UPDATE_DIM = 5,
MF_SIZE = 6,
FEA_DIM = 7
};
class ValueAccessor { class ValueAccessor {
public: public:
ValueAccessor() {} ValueAccessor() {}
...@@ -90,8 +87,7 @@ class ValueAccessor { ...@@ -90,8 +87,7 @@ class ValueAccessor {
} }
virtual int Initialize() = 0; virtual int Initialize() = 0;
virtual void SetTableInfo(AccessorInfo& info) = 0; virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }
virtual size_t GetTableInfo(InfoKey key) = 0;
virtual bool NeedExtendMF(float* value) { return false; } virtual bool NeedExtendMF(float* value) { return false; }
virtual bool HasMF(size_t size) { return false; } virtual bool HasMF(size_t size) { return false; }
......
...@@ -220,7 +220,8 @@ int32_t CommonDenseTable::Load(const std::string& path, ...@@ -220,7 +220,8 @@ int32_t CommonDenseTable::Load(const std::string& path,
} }
size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1; size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1 // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
size_t dim_num_per_shard = _table_info.fea_dim / _shard_num + 1; size_t dim_num_per_shard =
_value_accesor->GetAccessorInfo().fea_dim / _shard_num + 1;
size_t start_dim_idx = dim_num_per_shard * _shard_idx; size_t start_dim_idx = dim_num_per_shard * _shard_idx;
size_t start_file_idx = start_dim_idx / dim_num_per_file; size_t start_file_idx = start_dim_idx / dim_num_per_file;
size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file; size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
......
...@@ -23,87 +23,35 @@ namespace distributed { ...@@ -23,87 +23,35 @@ namespace distributed {
int CtrCommonAccessor::Initialize() { int CtrCommonAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name(); name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(), _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim()); _config.embedx_dim());
common_feature_value.embed_sgd_dim = _embed_sgd_rule->dim(); common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
common_feature_value.embedx_dim = _config.embedx_dim(); common_feature_value.embedx_dim = _config.embedx_dim();
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim(); common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
InitAccessorInfo();
return 0; return 0;
} }
void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) { void CtrCommonAccessor::InitAccessorInfo() {
info.dim = Dim(); _accessor_info.dim = common_feature_value.Dim();
info.size = Size(); _accessor_info.size = common_feature_value.Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
size_t CtrCommonAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return Dim();
case SIZE:
return Size();
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case MF_SIZE:
return MFSize();
default:
return 0;
}
return 0;
}
size_t CtrCommonAccessor::Dim() { return common_feature_value.Dim(); }
size_t CtrCommonAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim();
return common_feature_value.DimSize(dim, embedx_dim);
}
size_t CtrCommonAccessor::Size() { return common_feature_value.Size(); }
size_t CtrCommonAccessor::MFSize() {
return (_config.embedx_dim() + common_feature_value.embedx_sgd_dim) *
sizeof(float); // embedx embedx_g2sum
}
// pull value
size_t CtrCommonAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim;
}
size_t CtrCommonAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
size_t CtrCommonAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value
size_t CtrCommonAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim; _accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
(embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
} }
size_t CtrCommonAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
size_t CtrCommonAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool CtrCommonAccessor::Shrink(float* value) { bool CtrCommonAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
...@@ -116,9 +64,9 @@ bool CtrCommonAccessor::Shrink(float* value) { ...@@ -116,9 +64,9 @@ bool CtrCommonAccessor::Shrink(float* value) {
common_feature_value.Click(value) *= _show_click_decay_rate; common_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after // shrink after
auto score = show_click_score(common_feature_value.Show(value), auto score = ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)); common_feature_value.Click(value));
auto unseen_days = common_feature_value.unseen_days(value); auto unseen_days = common_feature_value.UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) { if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true; return true;
} }
...@@ -141,14 +89,13 @@ bool CtrCommonAccessor::Save(float* value, int param) { ...@@ -141,14 +89,13 @@ bool CtrCommonAccessor::Save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
if (show_click_score(common_feature_value.Show(value), if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= common_feature_value.Click(value)) >= base_threshold &&
base_threshold && common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.delta_score(value) >= delta_threshold && common_feature_value.UnseenDays(value) <= delta_keep_days) {
common_feature_value.unseen_days(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
if (param == 2) { if (param == 2) {
common_feature_value.delta_score(value) = 0; common_feature_value.DeltaScore(value) = 0;
} }
return true; return true;
} else { } else {
...@@ -158,7 +105,7 @@ bool CtrCommonAccessor::Save(float* value, int param) { ...@@ -158,7 +105,7 @@ bool CtrCommonAccessor::Save(float* value, int param) {
// already decayed in shrink // already decayed in shrink
case 3: { case 3: {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
// common_feature_value.unseen_days(value)++; // common_feature_value.UnseenDays(value)++;
return true; return true;
} }
// save revert batch_model // save revert batch_model
...@@ -179,17 +126,16 @@ void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) { ...@@ -179,17 +126,16 @@ void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
if (show_click_score(common_feature_value.Show(value), if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= common_feature_value.Click(value)) >= base_threshold &&
base_threshold && common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.delta_score(value) >= delta_threshold && common_feature_value.UnseenDays(value) <= delta_keep_days) {
common_feature_value.unseen_days(value) <= delta_keep_days) { common_feature_value.DeltaScore(value) = 0;
common_feature_value.delta_score(value) = 0;
} }
} }
return; return;
case 3: { case 3: {
common_feature_value.unseen_days(value)++; common_feature_value.UnseenDays(value)++;
} }
return; return;
default: default:
...@@ -201,17 +147,16 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) { ...@@ -201,17 +147,16 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[common_feature_value.unseen_days_index()] = 0; value[common_feature_value.UnseenDaysIndex()] = 0;
value[common_feature_value.delta_score_index()] = 0; value[common_feature_value.DeltaScoreIndex()] = 0;
value[common_feature_value.ShowIndex()] = 0; value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.ClickIndex()] = 0; value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1; value[common_feature_value.SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.Embed_W_Index(), value + common_feature_value.EmbedG2SumIndex());
value + common_feature_value.embed_g2sum_index()); _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
_embedx_sgd_rule->init_value( value + common_feature_value.EmbedxG2SumIndex(),
value + common_feature_value.Embedx_W_Index(), false);
value + common_feature_value.embedx_g2sum_index(), false);
} }
return 0; return 0;
} }
...@@ -225,7 +170,7 @@ bool CtrCommonAccessor::NeedExtendMF(float* value) { ...@@ -225,7 +170,7 @@ bool CtrCommonAccessor::NeedExtendMF(float* value) {
} }
bool CtrCommonAccessor::HasMF(size_t size) { bool CtrCommonAccessor::HasMF(size_t size) {
return size > common_feature_value.embedx_g2sum_index(); return size > common_feature_value.EmbedxG2SumIndex();
} }
// from CommonFeatureValue to CtrCommonPullValue // from CommonFeatureValue to CtrCommonPullValue
...@@ -239,10 +184,10 @@ int32_t CtrCommonAccessor::Select(float** select_values, const float** values, ...@@ -239,10 +184,10 @@ int32_t CtrCommonAccessor::Select(float** select_values, const float** values,
value[common_feature_value.ShowIndex()]; value[common_feature_value.ShowIndex()];
select_value[CtrCommonPullValue::ClickIndex()] = select_value[CtrCommonPullValue::ClickIndex()] =
value[common_feature_value.ClickIndex()]; value[common_feature_value.ClickIndex()];
select_value[CtrCommonPullValue::Embed_W_Index()] = select_value[CtrCommonPullValue::EmbedWIndex()] =
value[common_feature_value.Embed_W_Index()]; value[common_feature_value.EmbedWIndex()];
memcpy(select_value + CtrCommonPullValue::Embedx_W_Index(), memcpy(select_value + CtrCommonPullValue::EmbedxWIndex(),
value + common_feature_value.Embedx_W_Index(), value + common_feature_value.EmbedxWIndex(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -283,18 +228,18 @@ int32_t CtrCommonAccessor::Update(float** update_values, ...@@ -283,18 +228,18 @@ int32_t CtrCommonAccessor::Update(float** update_values,
update_value[common_feature_value.ShowIndex()] += push_show; update_value[common_feature_value.ShowIndex()] += push_show;
update_value[common_feature_value.ClickIndex()] += push_click; update_value[common_feature_value.ClickIndex()] += push_click;
update_value[common_feature_value.SlotIndex()] = slot; update_value[common_feature_value.SlotIndex()] = slot;
update_value[common_feature_value.delta_score_index()] += update_value[common_feature_value.DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
update_value[common_feature_value.unseen_days_index()] = 0; update_value[common_feature_value.UnseenDaysIndex()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->UpdateValue(
update_value + common_feature_value.Embed_W_Index(), update_value + common_feature_value.EmbedWIndex(),
update_value + common_feature_value.embed_g2sum_index(), update_value + common_feature_value.EmbedG2SumIndex(),
push_value + CtrCommonPushValue::Embed_G_Index()); push_value + CtrCommonPushValue::EmbedGIndex());
_embedx_sgd_rule->update_value( _embedx_sgd_rule->UpdateValue(
update_value + common_feature_value.Embedx_W_Index(), update_value + common_feature_value.EmbedxWIndex(),
update_value + common_feature_value.embedx_g2sum_index(), update_value + common_feature_value.EmbedxG2SumIndex(),
push_value + CtrCommonPushValue::Embedx_G_Index()); push_value + CtrCommonPushValue::EmbedxGIndex());
} }
return 0; return 0;
} }
...@@ -308,7 +253,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) { ...@@ -308,7 +253,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
// operation // operation
auto show = CtrCommonPushValue::Show(const_cast<float*>(value)); auto show = CtrCommonPushValue::Show(const_cast<float*>(value));
auto click = CtrCommonPushValue::Click(const_cast<float*>(value)); auto click = CtrCommonPushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
} }
...@@ -322,7 +267,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) { ...@@ -322,7 +267,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
} }
} }
float CtrCommonAccessor::show_click_score(float show, float click) { float CtrCommonAccessor::ShowClickScore(float show, float click) {
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff(); auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff; return (show - click) * nonclk_coeff + click * click_coeff;
...@@ -334,16 +279,16 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) { ...@@ -334,16 +279,16 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5]; << v[5];
for (int i = common_feature_value.embed_g2sum_index(); for (int i = common_feature_value.EmbedG2SumIndex();
i < common_feature_value.Embedx_W_Index(); i++) { i < common_feature_value.EmbedxWIndex(); i++) {
os << " " << v[i]; os << " " << v[i];
} }
auto show = common_feature_value.Show(const_cast<float*>(v)); auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.Click(const_cast<float*>(v)); auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && if (score >= _config.embedx_threshold() &&
param > common_feature_value.Embedx_W_Index()) { param > common_feature_value.EmbedxWIndex()) {
for (auto i = common_feature_value.Embedx_W_Index(); for (auto i = common_feature_value.EmbedxWIndex();
i < common_feature_value.Dim(); ++i) { i < common_feature_value.Dim(); ++i) {
os << " " << v[i]; os << " " << v[i];
} }
...@@ -354,9 +299,8 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) { ...@@ -354,9 +299,8 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) { int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
_embedx_sgd_rule->init_value( _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.Embedx_W_Index(), value + common_feature_value.EmbedxG2SumIndex());
value + common_feature_value.embedx_g2sum_index());
auto ret = paddle::string::str_to_float(str.data(), value); auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret; CHECK(ret >= 6) << "expect more than 6 real:" << ret;
return ret; return ret;
......
...@@ -44,24 +44,24 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -44,24 +44,24 @@ class CtrCommonAccessor : public ValueAccessor {
int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
int Size() { return Dim() * sizeof(float); } int Size() { return Dim() * sizeof(float); }
int SlotIndex() { return 0; } int SlotIndex() { return 0; }
int unseen_days_index() { return SlotIndex() + 1; } int UnseenDaysIndex() { return SlotIndex() + 1; }
int delta_score_index() { return unseen_days_index() + 1; } int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
int ShowIndex() { return delta_score_index() + 1; } int ShowIndex() { return DeltaScoreIndex() + 1; }
int ClickIndex() { return ShowIndex() + 1; } int ClickIndex() { return ShowIndex() + 1; }
int Embed_W_Index() { return ClickIndex() + 1; } int EmbedWIndex() { return ClickIndex() + 1; }
int embed_g2sum_index() { return Embed_W_Index() + 1; } int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; } int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; } int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; }
float& unseen_days(float* val) { return val[unseen_days_index()]; } float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& delta_score(float* val) { return val[delta_score_index()]; } float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
float& Show(float* val) { return val[ShowIndex()]; } float& Show(float* val) { return val[ShowIndex()]; }
float& Click(float* val) { return val[ClickIndex()]; } float& Click(float* val) { return val[ClickIndex()]; }
float& Slot(float* val) { return val[SlotIndex()]; } float& Slot(float* val) { return val[SlotIndex()]; }
float& EmbedW(float* val) { return val[Embed_W_Index()]; } float& EmbedW(float* val) { return val[EmbedWIndex()]; }
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
float& EmbedxW(float* val) { return val[Embedx_W_Index()]; } float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
int embed_sgd_dim; int embed_sgd_dim;
int embedx_dim; int embedx_dim;
...@@ -84,10 +84,8 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -84,10 +84,8 @@ class CtrCommonAccessor : public ValueAccessor {
static int SlotIndex() { return 0; } static int SlotIndex() { return 0; }
static int ShowIndex() { return CtrCommonPushValue::SlotIndex() + 1; } static int ShowIndex() { return CtrCommonPushValue::SlotIndex() + 1; }
static int ClickIndex() { return CtrCommonPushValue::ShowIndex() + 1; } static int ClickIndex() { return CtrCommonPushValue::ShowIndex() + 1; }
static int Embed_G_Index() { return CtrCommonPushValue::ClickIndex() + 1; } static int EmbedGIndex() { return CtrCommonPushValue::ClickIndex() + 1; }
static int Embedx_G_Index() { static int EmbedxGIndex() { return CtrCommonPushValue::EmbedGIndex() + 1; }
return CtrCommonPushValue::Embed_G_Index() + 1;
}
static float& Slot(float* val) { static float& Slot(float* val) {
return val[CtrCommonPushValue::SlotIndex()]; return val[CtrCommonPushValue::SlotIndex()];
} }
...@@ -98,10 +96,10 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -98,10 +96,10 @@ class CtrCommonAccessor : public ValueAccessor {
return val[CtrCommonPushValue::ClickIndex()]; return val[CtrCommonPushValue::ClickIndex()];
} }
static float& EmbedG(float* val) { static float& EmbedG(float* val) {
return val[CtrCommonPushValue::Embed_G_Index()]; return val[CtrCommonPushValue::EmbedGIndex()];
} }
static float* EmbedxG(float* val) { static float* EmbedxG(float* val) {
return val + CtrCommonPushValue::Embedx_G_Index(); return val + CtrCommonPushValue::EmbedxGIndex();
} }
}; };
...@@ -118,8 +116,8 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -118,8 +116,8 @@ class CtrCommonAccessor : public ValueAccessor {
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int ShowIndex() { return 0; } static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; } static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; } static int EmbedWIndex() { return 2; }
static int Embedx_W_Index() { return 3; } static int EmbedxWIndex() { return 3; }
static float& Show(float* val) { static float& Show(float* val) {
return val[CtrCommonPullValue::ShowIndex()]; return val[CtrCommonPullValue::ShowIndex()];
} }
...@@ -127,38 +125,17 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -127,38 +125,17 @@ class CtrCommonAccessor : public ValueAccessor {
return val[CtrCommonPullValue::ClickIndex()]; return val[CtrCommonPullValue::ClickIndex()];
} }
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[CtrCommonPullValue::Embed_W_Index()]; return val[CtrCommonPullValue::EmbedWIndex()];
} }
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return val + CtrCommonPullValue::Embedx_W_Index(); return val + CtrCommonPullValue::EmbedxWIndex();
} }
}; };
CtrCommonAccessor() {} CtrCommonAccessor() {}
virtual int Initialize();
virtual ~CtrCommonAccessor() {} virtual ~CtrCommonAccessor() {}
virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info); // 初始化AccessorInfo
virtual size_t GetTableInfo(InfoKey key); virtual void InitAccessorInfo();
// value维度
size_t Dim();
// value各个维度的size
size_t DimSize(size_t dim);
// value各维度相加总size
size_t Size();
// value中mf动态长度部分总size大小, sparse下生效
size_t MFSize();
// pull value维度
size_t SelectDim();
// pull value各个维度的size
size_t SelectDimSize(size_t dim);
// pull value各维度相加总size
size_t SelectSize();
// push value维度
size_t UpdateDim();
// push value各个维度的size
size_t UpdateDimSize(size_t dim);
// push value各维度相加总size
size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool Shrink(float* value); virtual bool Shrink(float* value);
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
...@@ -202,7 +179,7 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -202,7 +179,7 @@ class CtrCommonAccessor : public ValueAccessor {
} }
private: private:
// float show_click_score(float show, float click); // float ShowClickScore(float show, float click);
// SparseValueSGDRule* _embed_sgd_rule; // SparseValueSGDRule* _embed_sgd_rule;
// SparseValueSGDRule* _embedx_sgd_rule; // SparseValueSGDRule* _embedx_sgd_rule;
...@@ -213,7 +190,7 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -213,7 +190,7 @@ class CtrCommonAccessor : public ValueAccessor {
public: // TODO(zhaocaibei123): it should be private, but we make it public public: // TODO(zhaocaibei123): it should be private, but we make it public
// for unit test // for unit test
CtrCommonFeatureValue common_feature_value; CtrCommonFeatureValue common_feature_value;
float show_click_score(float show, float click); float ShowClickScore(float show, float click);
SparseValueSGDRule* _embed_sgd_rule; SparseValueSGDRule* _embed_sgd_rule;
SparseValueSGDRule* _embedx_sgd_rule; SparseValueSGDRule* _embedx_sgd_rule;
}; };
......
...@@ -23,89 +23,32 @@ namespace distributed { ...@@ -23,89 +23,32 @@ namespace distributed {
int DownpourCtrDoubleAccessor::Initialize() { int DownpourCtrDoubleAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name(); name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(), _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim()); _config.embedx_dim());
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold = _ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold(); _config.ctr_accessor_param().ssd_unseenday_threshold();
InitAccessorInfo();
return 0; return 0;
} }
void DownpourCtrDoubleAccessor::SetTableInfo(AccessorInfo& info) { void DownpourCtrDoubleAccessor::InitAccessorInfo() {
info.dim = Dim();
info.size = Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
size_t DownpourCtrDoubleAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return Dim();
case SIZE:
return Size();
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case MF_SIZE:
return MFSize();
default:
return 0;
}
return 0;
}
size_t DownpourCtrDoubleAccessor::Dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::Dim(embedx_dim);
}
size_t DownpourCtrDoubleAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::DimSize(dim, embedx_dim);
}
size_t DownpourCtrDoubleAccessor::Size() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::Size(embedx_dim); _accessor_info.dim = DownpourCtrDoubleFeatureValue::Dim(embedx_dim);
} _accessor_info.size = DownpourCtrDoubleFeatureValue::Size(embedx_dim);
size_t DownpourCtrDoubleAccessor::MFSize() { _accessor_info.select_dim = 3 + embedx_dim;
return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
} _accessor_info.update_dim = 4 + embedx_dim;
// pull value _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
size_t DownpourCtrDoubleAccessor::SelectDim() { _accessor_info.mf_size = (embedx_dim + 1) * sizeof(float);
auto embedx_dim = _config.embedx_dim();
return 3 + embedx_dim;
}
size_t DownpourCtrDoubleAccessor::SelectDimSize(size_t dim) {
return sizeof(float);
}
size_t DownpourCtrDoubleAccessor::SelectSize() {
return SelectDim() * sizeof(float);
}
// push value
size_t DownpourCtrDoubleAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim;
}
size_t DownpourCtrDoubleAccessor::UpdateDimSize(size_t dim) {
return sizeof(float);
}
size_t DownpourCtrDoubleAccessor::UpdateSize() {
return UpdateDim() * sizeof(float);
} }
bool DownpourCtrDoubleAccessor::Shrink(float* value) { bool DownpourCtrDoubleAccessor::Shrink(float* value) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
...@@ -119,16 +62,16 @@ bool DownpourCtrDoubleAccessor::Shrink(float* value) { ...@@ -119,16 +62,16 @@ bool DownpourCtrDoubleAccessor::Shrink(float* value) {
DownpourCtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate; DownpourCtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate;
DownpourCtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate; DownpourCtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate;
// shrink after // shrink after
auto score = show_click_score(DownpourCtrDoubleFeatureValue::Show(value), auto score = ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)); DownpourCtrDoubleFeatureValue::Click(value));
auto unseen_days = DownpourCtrDoubleFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrDoubleFeatureValue::UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) { if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true; return true;
} }
return false; return false;
} }
bool DownpourCtrDoubleAccessor::save_ssd(float* value) { bool DownpourCtrDoubleAccessor::save_ssd(float* value) {
if (DownpourCtrDoubleFeatureValue::unseen_days(value) > if (DownpourCtrDoubleFeatureValue::UnseenDays(value) >
_ssd_unseenday_threshold) { _ssd_unseenday_threshold) {
return true; return true;
} }
...@@ -138,9 +81,9 @@ bool DownpourCtrDoubleAccessor::save_ssd(float* value) { ...@@ -138,9 +81,9 @@ bool DownpourCtrDoubleAccessor::save_ssd(float* value) {
// float* value, int param, double global_cache_threshold) { // float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value), // if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
// DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold // DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold
// && DownpourCtrDoubleFeatureValue::unseen_days(value) <= // && DownpourCtrDoubleFeatureValue::UnseenDays(value) <=
// delta_keep_days) { // delta_keep_days) {
// return DownpourCtrDoubleFeatureValue::Show(value) > // return DownpourCtrDoubleFeatureValue::Show(value) >
// global_cache_threshold; // global_cache_threshold;
...@@ -166,16 +109,14 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { ...@@ -166,16 +109,14 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value), if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >= DownpourCtrDoubleFeatureValue::Click(value)) >=
base_threshold && base_threshold &&
DownpourCtrDoubleFeatureValue::delta_score(value) >= DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
delta_threshold && DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
DownpourCtrDoubleFeatureValue::unseen_days(value) <=
delta_keep_days) {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
if (param == 2) { if (param == 2) {
DownpourCtrDoubleFeatureValue::delta_score(value) = 0; DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0;
} }
return true; return true;
} else { } else {
...@@ -187,7 +128,7 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { ...@@ -187,7 +128,7 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) {
// DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate;
// DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate;
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
// DownpourCtrDoubleFeatureValue::unseen_days(value)++; // DownpourCtrDoubleFeatureValue::UnseenDays(value)++;
return true; return true;
} }
default: default:
...@@ -204,19 +145,17 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) { ...@@ -204,19 +145,17 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value), if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value),
DownpourCtrDoubleFeatureValue::Click(value)) >= DownpourCtrDoubleFeatureValue::Click(value)) >=
base_threshold && base_threshold &&
DownpourCtrDoubleFeatureValue::delta_score(value) >= DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
delta_threshold && DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
DownpourCtrDoubleFeatureValue::unseen_days(value) <= DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0;
delta_keep_days) {
DownpourCtrDoubleFeatureValue::delta_score(value) = 0;
} }
} }
return; return;
case 3: { case 3: {
DownpourCtrDoubleFeatureValue::unseen_days(value)++; DownpourCtrDoubleFeatureValue::UnseenDays(value)++;
} }
return; return;
default: default:
...@@ -228,17 +167,17 @@ int32_t DownpourCtrDoubleAccessor::Create(float** values, size_t num) { ...@@ -228,17 +167,17 @@ int32_t DownpourCtrDoubleAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[DownpourCtrDoubleFeatureValue::unseen_days_index()] = 0; value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
value[DownpourCtrDoubleFeatureValue::delta_score_index()] = 0; value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()) = 0; *(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()) = 0; *(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()) = 0;
value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1; value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->InitValue(
value + DownpourCtrDoubleFeatureValue::Embed_W_Index(), value + DownpourCtrDoubleFeatureValue::EmbedWIndex(),
value + DownpourCtrDoubleFeatureValue::embed_g2sum_index()); value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex());
_embedx_sgd_rule->init_value( _embedx_sgd_rule->InitValue(
value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
value + DownpourCtrDoubleFeatureValue::embedx_g2sum_index(), false); value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(), false);
} }
return 0; return 0;
} }
...@@ -264,10 +203,10 @@ int32_t DownpourCtrDoubleAccessor::Select(float** select_values, ...@@ -264,10 +203,10 @@ int32_t DownpourCtrDoubleAccessor::Select(float** select_values,
(float)*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()); (float)*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex());
select_value[DownpourCtrDoublePullValue::ClickIndex()] = select_value[DownpourCtrDoublePullValue::ClickIndex()] =
(float)*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()); (float)*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex());
select_value[DownpourCtrDoublePullValue::Embed_W_Index()] = select_value[DownpourCtrDoublePullValue::EmbedWIndex()] =
value[DownpourCtrDoubleFeatureValue::Embed_W_Index()]; value[DownpourCtrDoubleFeatureValue::EmbedWIndex()];
memcpy(select_value + DownpourCtrDoublePullValue::Embedx_W_Index(), memcpy(select_value + DownpourCtrDoublePullValue::EmbedxWIndex(),
value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -316,20 +255,20 @@ int32_t DownpourCtrDoubleAccessor::Update(float** update_values, ...@@ -316,20 +255,20 @@ int32_t DownpourCtrDoubleAccessor::Update(float** update_values,
*(double*)(update_value + DownpourCtrDoubleFeatureValue::ClickIndex()) += *(double*)(update_value + DownpourCtrDoubleFeatureValue::ClickIndex()) +=
(double)push_click; (double)push_click;
update_value[DownpourCtrDoubleFeatureValue::SlotIndex()] = slot; update_value[DownpourCtrDoubleFeatureValue::SlotIndex()] = slot;
update_value[DownpourCtrDoubleFeatureValue::delta_score_index()] += update_value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + //(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff(); // push_click * _config.ctr_accessor_param().click_coeff();
update_value[DownpourCtrDoubleFeatureValue::unseen_days_index()] = 0; update_value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->UpdateValue(
update_value + DownpourCtrDoubleFeatureValue::Embed_W_Index(), update_value + DownpourCtrDoubleFeatureValue::EmbedWIndex(),
update_value + DownpourCtrDoubleFeatureValue::embed_g2sum_index(), update_value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex(),
push_value + DownpourCtrDoublePushValue::Embed_G_Index(), push_show); push_value + DownpourCtrDoublePushValue::EmbedGIndex(), push_show);
_embedx_sgd_rule->update_value( _embedx_sgd_rule->UpdateValue(
update_value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), update_value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
update_value + DownpourCtrDoubleFeatureValue::embedx_g2sum_index(), update_value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(),
push_value + DownpourCtrDoublePushValue::Embedx_G_Index(), push_show); push_value + DownpourCtrDoublePushValue::EmbedxGIndex(), push_show);
} }
return 0; return 0;
} }
...@@ -341,7 +280,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) { ...@@ -341,7 +280,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) {
} else if (stage == 1) { } else if (stage == 1) {
auto show = DownpourCtrDoublePushValue::Show(const_cast<float*>(value)); auto show = DownpourCtrDoublePushValue::Show(const_cast<float*>(value));
auto click = DownpourCtrDoublePushValue::Click(const_cast<float*>(value)); auto click = DownpourCtrDoublePushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
} }
...@@ -354,7 +293,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) { ...@@ -354,7 +293,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) {
return true; return true;
} }
} }
double DownpourCtrDoubleAccessor::show_click_score(double show, double click) { double DownpourCtrDoubleAccessor::ShowClickScore(double show, double click) {
// auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); // auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
// auto click_coeff = _config.ctr_accessor_param().click_coeff(); // auto click_coeff = _config.ctr_accessor_param().click_coeff();
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
...@@ -371,7 +310,7 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v, ...@@ -371,7 +310,7 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v,
<< v[8]; << v[8];
auto show = DownpourCtrDoubleFeatureValue::Show(const_cast<float*>(v)); auto show = DownpourCtrDoubleFeatureValue::Show(const_cast<float*>(v));
auto click = DownpourCtrDoubleFeatureValue::Click(const_cast<float*>(v)); auto click = DownpourCtrDoubleFeatureValue::Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && param_size > 9) { if (score >= _config.embedx_threshold() && param_size > 9) {
os << " " << v[9]; os << " " << v[9];
for (auto i = 0; i < _config.embedx_dim(); ++i) { for (auto i = 0; i < _config.embedx_dim(); ++i) {
...@@ -383,19 +322,19 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v, ...@@ -383,19 +322,19 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v,
int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str, int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str,
float* value) { float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
float data_buff[Dim() + 2]; float data_buff[_accessor_info.dim + 2];
float* data_buff_ptr = data_buff; float* data_buff_ptr = data_buff;
_embedx_sgd_rule->init_value( _embedx_sgd_rule->InitValue(
data_buff_ptr + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxWIndex(),
data_buff_ptr + DownpourCtrDoubleFeatureValue::embedx_g2sum_index()); data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex());
auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr); auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
CHECK(str_len >= 6) << "expect more than 6 real:" << str_len; CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
int show_index = DownpourCtrDoubleFeatureValue::ShowIndex(); int show_index = DownpourCtrDoubleFeatureValue::ShowIndex();
int click_index = DownpourCtrDoubleFeatureValue::ClickIndex(); int click_index = DownpourCtrDoubleFeatureValue::ClickIndex();
int embed_w_index = DownpourCtrDoubleFeatureValue::Embed_W_Index(); int embed_w_index = DownpourCtrDoubleFeatureValue::EmbedWIndex();
// no slot, embedx // no slot, embedx
int value_dim = Dim(); int value_dim = _accessor_info.dim;
int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::embedx_g2sum_index(); int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex();
value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1; value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1;
// other case // other case
if (str_len == (value_dim - 1)) { if (str_len == (value_dim - 1)) {
...@@ -405,9 +344,8 @@ int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str, ...@@ -405,9 +344,8 @@ int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str,
*(double*)(value + show_index) = (double)data_buff_ptr[2]; *(double*)(value + show_index) = (double)data_buff_ptr[2];
*(double*)(value + click_index) = (double)data_buff_ptr[3]; *(double*)(value + click_index) = (double)data_buff_ptr[3];
// copy others // copy others
value[DownpourCtrDoubleFeatureValue::Embed_W_Index()] = data_buff_ptr[4]; value[DownpourCtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4];
value[DownpourCtrDoubleFeatureValue::embed_g2sum_index()] = value[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5];
data_buff_ptr[5];
memcpy(value + embedx_g2sum_index, data_buff_ptr + 6, memcpy(value + embedx_g2sum_index, data_buff_ptr + 6,
(embedx_dim + 1) * sizeof(float)); (embedx_dim + 1) * sizeof(float));
} else { } else {
......
...@@ -43,38 +43,38 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -43,38 +43,38 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// Byte size of one feature value: (Dim(embedx_dim) + 2) floats.
// NOTE(review): the 2 extra float slots presumably cover show and click being
// stored as doubles (ClickIndex and the embed-w index each advance by 2) --
// confirm against Dim().
static int Size(int embedx_dim) { static int Size(int embedx_dim) {
return (Dim(embedx_dim) + 2) * sizeof(float); return (Dim(embedx_dim) + 2) * sizeof(float);
} }
// Float-slot offsets of each field inside a feature value. Every offset is
// derived from the previous one, giving: unseen_days 0, delta_score 1, show 2,
// click 4, embed_w 6, embed_g2sum 7, slot 8, embedx_g2sum 9, embedx_w 10.
// show and click each advance the offset by 2 float slots (see the
// "is double" notes between the functions).
static int unseen_days_index() { return 0; } static int UnseenDaysIndex() { return 0; }
static int delta_score_index() { static int DeltaScoreIndex() {
return DownpourCtrDoubleFeatureValue::unseen_days_index() + 1; return DownpourCtrDoubleFeatureValue::UnseenDaysIndex() + 1;
} }
static int ShowIndex() { static int ShowIndex() {
return DownpourCtrDoubleFeatureValue::delta_score_index() + 1; return DownpourCtrDoubleFeatureValue::DeltaScoreIndex() + 1;
} }
// show is double // show is double
static int ClickIndex() { static int ClickIndex() {
return DownpourCtrDoubleFeatureValue::ShowIndex() + 2; return DownpourCtrDoubleFeatureValue::ShowIndex() + 2;
} }
// click is double // click is double
static int Embed_W_Index() { static int EmbedWIndex() {
return DownpourCtrDoubleFeatureValue::ClickIndex() + 2; return DownpourCtrDoubleFeatureValue::ClickIndex() + 2;
} }
static int embed_g2sum_index() { static int EmbedG2SumIndex() {
return DownpourCtrDoubleFeatureValue::Embed_W_Index() + 1; return DownpourCtrDoubleFeatureValue::EmbedWIndex() + 1;
} }
static int SlotIndex() { static int SlotIndex() {
return DownpourCtrDoubleFeatureValue::embed_g2sum_index() + 1; return DownpourCtrDoubleFeatureValue::EmbedG2SumIndex() + 1;
} }
static int embedx_g2sum_index() { static int EmbedxG2SumIndex() {
return DownpourCtrDoubleFeatureValue::SlotIndex() + 1; return DownpourCtrDoubleFeatureValue::SlotIndex() + 1;
} }
static int Embedx_W_Index() { static int EmbedxWIndex() {
return DownpourCtrDoubleFeatureValue::embedx_g2sum_index() + 1; return DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex() + 1;
} }
// Mutable references to the unseen_days and delta_score float slots of `val`,
// located through the corresponding index functions.
static float& unseen_days(float* val) { static float& UnseenDays(float* val) {
return val[DownpourCtrDoubleFeatureValue::unseen_days_index()]; return val[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()];
} }
static float& delta_score(float* val) { static float& DeltaScore(float* val) {
return val[DownpourCtrDoubleFeatureValue::delta_score_index()]; return val[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()];
} }
static double& Show(float* val) { static double& Show(float* val) {
return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0]; return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0];
...@@ -86,16 +86,16 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -86,16 +86,16 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
return val[DownpourCtrDoubleFeatureValue::SlotIndex()]; return val[DownpourCtrDoubleFeatureValue::SlotIndex()];
} }
// Field accessors over a feature value `val`. EmbedW, the embed g2sum and the
// embedx g2sum are returned as mutable float references.
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrDoubleFeatureValue::Embed_W_Index()]; return val[DownpourCtrDoubleFeatureValue::EmbedWIndex()];
} }
static float& embed_g2sum(float* val) { static float& EmbedG2Sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::embed_g2sum_index()]; return val[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()];
} }
static float& embedx_g2sum(float* val) { static float& EmbedxG2Sum(float* val) {
return val[DownpourCtrDoubleFeatureValue::embedx_g2sum_index()]; return val[DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex()];
} }
// Unlike the accessors above, EmbedxW returns a pointer to the first embedx
// weight rather than a reference to a single float.
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return (val + DownpourCtrDoubleFeatureValue::Embedx_W_Index()); return (val + DownpourCtrDoubleFeatureValue::EmbedxWIndex());
} }
}; };
struct DownpourCtrDoublePushValue { struct DownpourCtrDoublePushValue {
...@@ -116,11 +116,11 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -116,11 +116,11 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// Push-value offsets: click follows show, the embed gradient follows click,
// and the embedx gradients follow the embed gradient, each one float slot
// after the previous field.
static int ClickIndex() { static int ClickIndex() {
return DownpourCtrDoublePushValue::ShowIndex() + 1; return DownpourCtrDoublePushValue::ShowIndex() + 1;
} }
static int Embed_G_Index() { static int EmbedGIndex() {
return DownpourCtrDoublePushValue::ClickIndex() + 1; return DownpourCtrDoublePushValue::ClickIndex() + 1;
} }
static int Embedx_G_Index() { static int EmbedxGIndex() {
return DownpourCtrDoublePushValue::Embed_G_Index() + 1; return DownpourCtrDoublePushValue::EmbedGIndex() + 1;
} }
static float& Slot(float* val) { static float& Slot(float* val) {
return val[DownpourCtrDoublePushValue::SlotIndex()]; return val[DownpourCtrDoublePushValue::SlotIndex()];
...@@ -132,10 +132,10 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -132,10 +132,10 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
return val[DownpourCtrDoublePushValue::ClickIndex()]; return val[DownpourCtrDoublePushValue::ClickIndex()];
} }
// Mutable reference to the embed gradient slot of a push value.
static float& EmbedG(float* val) { static float& EmbedG(float* val) {
return val[DownpourCtrDoublePushValue::Embed_G_Index()]; return val[DownpourCtrDoublePushValue::EmbedGIndex()];
} }
// Pointer to the first embedx gradient slot of a push value.
static float* EmbedxG(float* val) { static float* EmbedxG(float* val) {
return val + DownpourCtrDoublePushValue::Embedx_G_Index(); return val + DownpourCtrDoublePushValue::EmbedxGIndex();
} }
}; };
struct DownpourCtrDoublePullValue { struct DownpourCtrDoublePullValue {
...@@ -150,8 +150,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -150,8 +150,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// Pull-value layout: byte size is Dim(embedx_dim) floats, with fixed offsets
// show 0, click 1, embed_w 2 and the embedx weights starting at 3.
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int ShowIndex() { return 0; } static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; } static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; } static int EmbedWIndex() { return 2; }
static int Embedx_W_Index() { return 3; } static int EmbedxWIndex() { return 3; }
// Mutable reference to the show slot of a pull value (a plain float here).
static float& Show(float* val) { static float& Show(float* val) {
return val[DownpourCtrDoublePullValue::ShowIndex()]; return val[DownpourCtrDoublePullValue::ShowIndex()];
} }
...@@ -159,37 +159,17 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -159,37 +159,17 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
return val[DownpourCtrDoublePullValue::ClickIndex()]; return val[DownpourCtrDoublePullValue::ClickIndex()];
} }
// Mutable reference to the embed weight slot of a pull value.
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrDoublePullValue::Embed_W_Index()]; return val[DownpourCtrDoublePullValue::EmbedWIndex()];
} }
// Pointer to the first embedx weight slot of a pull value.
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return val + DownpourCtrDoublePullValue::Embedx_W_Index(); return val + DownpourCtrDoublePullValue::EmbedxWIndex();
} }
}; };
// Construction, initialization and sizing interface of the accessor.
DownpourCtrDoubleAccessor() {} DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {}
virtual int Initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info); // Initialize AccessorInfo
virtual size_t GetTableInfo(InfoKey key); virtual void InitAccessorInfo();
// number of dimensions in a value
size_t Dim();
// size of each dimension of a value
size_t DimSize(size_t dim);
// total size of a value (sum over all dimensions)
size_t Size();
// total size of the dynamic-length mf part of a value; takes effect for sparse tables
size_t MFSize();
// number of dimensions in a pull value
size_t SelectDim();
// size of each dimension of a pull value
size_t SelectDimSize(size_t dim);
// total size of a pull value (sum over all dimensions)
size_t SelectSize();
// number of dimensions in a push value
size_t UpdateDim();
// size of each dimension of a push value
size_t UpdateDimSize(size_t dim);
// total size of a push value (sum over all dimensions)
size_t UpdateSize();
// decide whether this value should be shrunk // decide whether this value should be shrunk
virtual bool Shrink(float* value); virtual bool Shrink(float* value);
virtual bool NeedExtendMF(float* value); virtual bool NeedExtendMF(float* value);
...@@ -235,7 +215,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -235,7 +215,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embed_w) // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embed_w)
// DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embedx_w) // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embedx_w)
private: private:
double show_click_score(double show, double click); double ShowClickScore(double show, double click);
private: private:
SparseValueSGDRule* _embed_sgd_rule; SparseValueSGDRule* _embed_sgd_rule;
......
...@@ -23,91 +23,32 @@ namespace distributed { ...@@ -23,91 +23,32 @@ namespace distributed {
// Sets up the accessor from _config: creates and configures the embed and
// embedx SGD rules, reads the show/click decay rate and the SSD unseen-day
// threshold, calls set_time_decay_rates(), and (on the side that has the
// call) InitAccessorInfo(). Always returns 0.
int DownpourCtrAccessor::Initialize() { int DownpourCtrAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
// The embed rule is configured with dimension 1.
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name(); name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
// The embedx rule is configured with the configured embedx dimension.
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(), _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim()); _config.embedx_dim());
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold = _ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold(); _config.ctr_accessor_param().ssd_unseenday_threshold();
set_time_decay_rates(); set_time_decay_rates();
InitAccessorInfo();
return 0; return 0;
} }
// SetTableInfo fills the caller-supplied AccessorInfo with the result of each
// sizing getter (value, pull-value, push-value and mf sizes).
// The InitAccessorInfo() signature shares the first row; its statements are
// the _accessor_info assignments that appear further down.
void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) { void DownpourCtrAccessor::InitAccessorInfo() {
info.dim = Dim();
info.size = Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
// Returns the dimension or byte-size figure selected by `key` by delegating
// to the matching getter. Keys without a dedicated case yield 0.
size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) {
  size_t requested = 0;
  switch (key) {
    case DIM:
      requested = Dim();
      break;
    case SIZE:
      requested = Size();
      break;
    case SELECT_DIM:
      requested = SelectDim();
      break;
    case SELECT_SIZE:
      requested = SelectSize();
      break;
    case UPDATE_DIM:
      requested = UpdateDim();
      break;
    case UPDATE_SIZE:
      requested = UpdateSize();
      break;
    case MF_SIZE:
      requested = MFSize();
      break;
    default:
      break;
  }
  return requested;
}
// Number of dimensions of a stored feature value for the configured embedx
// dimension, as reported by DownpourCtrFeatureValue::Dim.
size_t DownpourCtrAccessor::Dim() {
  return DownpourCtrFeatureValue::Dim(_config.embedx_dim());
}
// Byte width of dimension `dim` of a stored feature value, as reported by
// DownpourCtrFeatureValue::DimSize for the configured embedx dimension.
size_t DownpourCtrAccessor::DimSize(size_t dim) {
  return DownpourCtrFeatureValue::DimSize(dim, _config.embedx_dim());
}
// Size() returns the feature-value byte size for the configured embedx_dim.
// The _accessor_info assignments on these rows compute, from embedx_dim:
//   dim / size            -> DownpourCtrFeatureValue::Dim / ::Size
//   select_dim            -> 3 + embedx_dim   (pull value)
//   update_dim            -> 4 + embedx_dim   (push value)
//   select_size / update_size -> the matching dim * sizeof(float)
//   mf_size               -> (embedx_dim + 1) * sizeof(float)
size_t DownpourCtrAccessor::Size() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::Size(embedx_dim); _accessor_info.dim = DownpourCtrFeatureValue::Dim(embedx_dim);
_accessor_info.size = DownpourCtrFeatureValue::Size(embedx_dim);
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size = (embedx_dim + 1) * sizeof(float);
} }
// Byte size of the mf part of a value: the embedx weights plus one
// embedx_g2sum float.
size_t DownpourCtrAccessor::MFSize() {
  const auto mf_floats = _config.embedx_dim() + 1;
  return mf_floats * sizeof(float);
}
// Pull value: 3 + embedx_dim dimensions.
size_t DownpourCtrAccessor::SelectDim() {
  return 3 + _config.embedx_dim();
}

// Every pull-value dimension is one float wide; `dim` is not consulted.
size_t DownpourCtrAccessor::SelectDimSize(size_t dim) {
  return sizeof(float);
}

// Total byte size of a pull value.
size_t DownpourCtrAccessor::SelectSize() {
  const size_t dim_count = SelectDim();
  return dim_count * sizeof(float);
}
// Push value: 4 + embedx_dim dimensions.
size_t DownpourCtrAccessor::UpdateDim() {
  return 4 + _config.embedx_dim();
}

// Every push-value dimension is one float wide; `dim` is not consulted.
size_t DownpourCtrAccessor::UpdateDimSize(size_t dim) {
  return sizeof(float);
}

// Total byte size of a push value.
size_t DownpourCtrAccessor::UpdateSize() {
  const size_t dim_count = UpdateDim();
  return dim_count * sizeof(float);
}
bool DownpourCtrAccessor::Shrink(float* value) { bool DownpourCtrAccessor::Shrink(float* value) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
...@@ -119,7 +60,7 @@ bool DownpourCtrAccessor::Shrink(float* value) { ...@@ -119,7 +60,7 @@ bool DownpourCtrAccessor::Shrink(float* value) {
auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first // time_decay first
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
if (day_diff < 0 || day_diff > delete_after_unseen_days) { if (day_diff < 0 || day_diff > delete_after_unseen_days) {
return true; return true;
...@@ -130,7 +71,7 @@ bool DownpourCtrAccessor::Shrink(float* value) { ...@@ -130,7 +71,7 @@ bool DownpourCtrAccessor::Shrink(float* value) {
DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
// shrink after // shrink after
auto score = show_click_score(show_right, click_right); auto score = ShowClickScore(show_right, click_right);
if (score < delete_threshold) { if (score < delete_threshold) {
return true; return true;
} }
...@@ -145,7 +86,7 @@ bool DownpourCtrAccessor::save_ssd(float* value) { ...@@ -145,7 +86,7 @@ bool DownpourCtrAccessor::save_ssd(float* value) {
if (_day_id == 0) { if (_day_id == 0) {
return true; return true;
} }
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
if (unseen_days == 0) { if (unseen_days == 0) {
return false; return false;
} }
...@@ -164,9 +105,9 @@ bool DownpourCtrAccessor::save_ssd(float* value) { ...@@ -164,9 +105,9 @@ bool DownpourCtrAccessor::save_ssd(float* value) {
// float* value, int param, double global_cache_threshold) { // float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); // auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
// int16_t day_diff = _day_id - unseen_days; // int16_t day_diff = _day_id - unseen_days;
// if (show_click_score(DownpourCtrFeatureValue::Show(value), // if (ShowClickScore(DownpourCtrFeatureValue::Show(value),
// DownpourCtrFeatureValue::Click(value)) >= base_threshold // DownpourCtrFeatureValue::Click(value)) >= base_threshold
// && day_diff <= delta_keep_days) { // && day_diff <= delta_keep_days) {
// return DownpourCtrFeatureValue::Show(value) > global_cache_threshold; // return DownpourCtrFeatureValue::Show(value) > global_cache_threshold;
...@@ -193,7 +134,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) { ...@@ -193,7 +134,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
auto show_right = auto show_right =
...@@ -201,12 +142,12 @@ bool DownpourCtrAccessor::Save(float* value, int param) { ...@@ -201,12 +142,12 @@ bool DownpourCtrAccessor::Save(float* value, int param) {
auto click_right = auto click_right =
DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
if (show_click_score(show_right, click_right) >= base_threshold && if (ShowClickScore(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::delta_score(value) >= delta_threshold && DownpourCtrFeatureValue::DeltaScore(value) >= delta_threshold &&
day_diff <= delta_keep_days) { day_diff <= delta_keep_days) {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
if (param == 2) { if (param == 2) {
DownpourCtrFeatureValue::delta_score(value) = 0; DownpourCtrFeatureValue::DeltaScore(value) = 0;
} }
return true; return true;
} else { } else {
...@@ -218,7 +159,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) { ...@@ -218,7 +159,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) {
// DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate;
// DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate;
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
// DownpourCtrFeatureValue::unseen_days(value)++; // DownpourCtrFeatureValue::UnseenDays(value)++;
return true; return true;
} }
default: default:
...@@ -235,23 +176,23 @@ void DownpourCtrAccessor::UpdateStatAfterSave(float* value, int param) { ...@@ -235,23 +176,23 @@ void DownpourCtrAccessor::UpdateStatAfterSave(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
auto show_right = auto show_right =
DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
auto click_right = auto click_right =
DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff];
if (show_click_score(show_right, click_right) >= base_threshold && if (ShowClickScore(show_right, click_right) >= base_threshold &&
DownpourCtrFeatureValue::delta_score(value) >= delta_threshold && DownpourCtrFeatureValue::DeltaScore(value) >= delta_threshold &&
day_diff <= delta_keep_days) { day_diff <= delta_keep_days) {
DownpourCtrFeatureValue::delta_score(value) = 0; DownpourCtrFeatureValue::DeltaScore(value) = 0;
} }
} }
return; return;
// case 3: // case 3:
// { // {
// DownpourCtrFeatureValue::unseen_days(value)++; // DownpourCtrFeatureValue::UnseenDays(value)++;
// } // }
// return; // return;
default: default:
...@@ -263,17 +204,17 @@ int32_t DownpourCtrAccessor::Create(float** values, size_t num) { ...@@ -263,17 +204,17 @@ int32_t DownpourCtrAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[DownpourCtrFeatureValue::unseen_days_index()] = 0; value[DownpourCtrFeatureValue::UnseenDaysIndex()] = 0;
value[DownpourCtrFeatureValue::delta_score_index()] = 0; value[DownpourCtrFeatureValue::DeltaScoreIndex()] = 0;
value[DownpourCtrFeatureValue::ShowIndex()] = 0; value[DownpourCtrFeatureValue::ShowIndex()] = 0;
value[DownpourCtrFeatureValue::ClickIndex()] = 0; value[DownpourCtrFeatureValue::ClickIndex()] = 0;
value[DownpourCtrFeatureValue::SlotIndex()] = -1; value[DownpourCtrFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->InitValue(
value + DownpourCtrFeatureValue::Embed_W_Index(), value + DownpourCtrFeatureValue::EmbedWIndex(),
value + DownpourCtrFeatureValue::embed_g2sum_index(), true); value + DownpourCtrFeatureValue::EmbedG2SumIndex(), true);
_embedx_sgd_rule->init_value( _embedx_sgd_rule->InitValue(
value + DownpourCtrFeatureValue::Embedx_W_Index(), value + DownpourCtrFeatureValue::EmbedxWIndex(),
value + DownpourCtrFeatureValue::embedx_g2sum_index()); value + DownpourCtrFeatureValue::EmbedxG2SumIndex());
} }
return 0; return 0;
} }
...@@ -289,7 +230,7 @@ bool DownpourCtrAccessor::NeedExtendMF(float* value) { ...@@ -289,7 +230,7 @@ bool DownpourCtrAccessor::NeedExtendMF(float* value) {
} }
// Returns true when `size` is greater than the embedx_g2sum offset of a
// feature value.
// NOTE(review): `size` is compared against a float-slot index, so it looks
// like a count of floats rather than bytes -- confirm against callers.
bool DownpourCtrAccessor::HasMF(size_t size) { bool DownpourCtrAccessor::HasMF(size_t size) {
return size > DownpourCtrFeatureValue::embedx_g2sum_index(); return size > DownpourCtrFeatureValue::EmbedxG2SumIndex();
} }
// from DownpourCtrFeatureValue to DownpourCtrPullValue // from DownpourCtrFeatureValue to DownpourCtrPullValue
...@@ -303,10 +244,10 @@ int32_t DownpourCtrAccessor::Select(float** select_values, const float** values, ...@@ -303,10 +244,10 @@ int32_t DownpourCtrAccessor::Select(float** select_values, const float** values,
value[DownpourCtrFeatureValue::ShowIndex()]; value[DownpourCtrFeatureValue::ShowIndex()];
select_value[DownpourCtrPullValue::ClickIndex()] = select_value[DownpourCtrPullValue::ClickIndex()] =
value[DownpourCtrFeatureValue::ClickIndex()]; value[DownpourCtrFeatureValue::ClickIndex()];
select_value[DownpourCtrPullValue::Embed_W_Index()] = select_value[DownpourCtrPullValue::EmbedWIndex()] =
value[DownpourCtrFeatureValue::Embed_W_Index()]; value[DownpourCtrFeatureValue::EmbedWIndex()];
memcpy(select_value + DownpourCtrPullValue::Embedx_W_Index(), memcpy(select_value + DownpourCtrPullValue::EmbedxWIndex(),
value + DownpourCtrFeatureValue::Embedx_W_Index(), value + DownpourCtrFeatureValue::EmbedxWIndex(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -347,20 +288,20 @@ int32_t DownpourCtrAccessor::Update(float** update_values, ...@@ -347,20 +288,20 @@ int32_t DownpourCtrAccessor::Update(float** update_values,
update_value[DownpourCtrFeatureValue::ShowIndex()] += push_show; update_value[DownpourCtrFeatureValue::ShowIndex()] += push_show;
update_value[DownpourCtrFeatureValue::ClickIndex()] += push_click; update_value[DownpourCtrFeatureValue::ClickIndex()] += push_click;
update_value[DownpourCtrFeatureValue::SlotIndex()] = slot; update_value[DownpourCtrFeatureValue::SlotIndex()] = slot;
update_value[DownpourCtrFeatureValue::delta_score_index()] += update_value[DownpourCtrFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + //(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff(); // push_click * _config.ctr_accessor_param().click_coeff();
update_value[DownpourCtrFeatureValue::unseen_days_index()] = 0; update_value[DownpourCtrFeatureValue::UnseenDaysIndex()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->UpdateValue(
update_value + DownpourCtrFeatureValue::Embed_W_Index(), update_value + DownpourCtrFeatureValue::EmbedWIndex(),
update_value + DownpourCtrFeatureValue::embed_g2sum_index(), update_value + DownpourCtrFeatureValue::EmbedG2SumIndex(),
push_value + DownpourCtrPushValue::Embed_G_Index(), push_show); push_value + DownpourCtrPushValue::EmbedGIndex(), push_show);
_embedx_sgd_rule->update_value( _embedx_sgd_rule->UpdateValue(
update_value + DownpourCtrFeatureValue::Embedx_W_Index(), update_value + DownpourCtrFeatureValue::EmbedxWIndex(),
update_value + DownpourCtrFeatureValue::embedx_g2sum_index(), update_value + DownpourCtrFeatureValue::EmbedxG2SumIndex(),
push_value + DownpourCtrPushValue::Embedx_G_Index(), push_show); push_value + DownpourCtrPushValue::EmbedxGIndex(), push_show);
} }
return 0; return 0;
} }
...@@ -373,7 +314,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) { ...@@ -373,7 +314,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) {
} else if (stage == 1) { } else if (stage == 1) {
auto show = DownpourCtrPushValue::Show(const_cast<float*>(value)); auto show = DownpourCtrPushValue::Show(const_cast<float*>(value));
auto click = DownpourCtrPushValue::Click(const_cast<float*>(value)); auto click = DownpourCtrPushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
} }
...@@ -387,7 +328,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) { ...@@ -387,7 +328,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) {
} }
} }
float DownpourCtrAccessor::show_click_score(float show, float click) { float DownpourCtrAccessor::ShowClickScore(float show, float click) {
// auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); // auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
// auto click_coeff = _config.ctr_accessor_param().click_coeff(); // auto click_coeff = _config.ctr_accessor_param().click_coeff();
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
...@@ -403,7 +344,7 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) { ...@@ -403,7 +344,7 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) {
<< v[5] << " " << v[6]; << v[5] << " " << v[6];
auto show = DownpourCtrFeatureValue::Show(const_cast<float*>(v)); auto show = DownpourCtrFeatureValue::Show(const_cast<float*>(v));
auto click = DownpourCtrFeatureValue::Click(const_cast<float*>(v)); auto click = DownpourCtrFeatureValue::Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && param_size > 7) { if (score >= _config.embedx_threshold() && param_size > 7) {
os << " " << v[7]; os << " " << v[7];
for (auto i = 0; i < _config.embedx_dim(); ++i) { for (auto i = 0; i < _config.embedx_dim(); ++i) {
...@@ -415,18 +356,18 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) { ...@@ -415,18 +356,18 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) {
int DownpourCtrAccessor::ParseFromString(const std::string& str, float* value) { int DownpourCtrAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
float data_buff[Dim()]; float data_buff[_accessor_info.dim];
float* data_buff_ptr = data_buff; float* data_buff_ptr = data_buff;
_embedx_sgd_rule->init_value( _embedx_sgd_rule->InitValue(
data_buff_ptr + DownpourCtrFeatureValue::Embedx_W_Index(), data_buff_ptr + DownpourCtrFeatureValue::EmbedxWIndex(),
data_buff_ptr + DownpourCtrFeatureValue::embedx_g2sum_index()); data_buff_ptr + DownpourCtrFeatureValue::EmbedxG2SumIndex());
auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr); auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
CHECK(str_len >= 6) << "expect more than 6 real:" << str_len; CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
// no slot, embedx // no slot, embedx
int value_dim = Dim(); int value_dim = _accessor_info.dim;
int embedx_g2sum_index = DownpourCtrFeatureValue::embedx_g2sum_index(); int embedx_g2sum_index = DownpourCtrFeatureValue::EmbedxG2SumIndex();
value[DownpourCtrFeatureValue::SlotIndex()] = -1; value[DownpourCtrFeatureValue::SlotIndex()] = -1;
// other case // other case
if (str_len == (value_dim - 1)) { if (str_len == (value_dim - 1)) {
...@@ -459,25 +400,25 @@ void DownpourCtrAccessor::update_time_decay(float* value, ...@@ -459,25 +400,25 @@ void DownpourCtrAccessor::update_time_decay(float* value,
if (_day_id == 0) { if (_day_id == 0) {
return; return;
} }
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
if (unseen_days == 0) { if (unseen_days == 0) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id; DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
return; return;
} }
// for the origin load (unseenday = 0 -15) // for the origin load (unseenday = 0 -15)
if (unseen_days < _config.ctr_accessor_param().delete_after_unseen_days()) { if (unseen_days < _config.ctr_accessor_param().delete_after_unseen_days()) {
// pull // pull
if (is_update_seen_day) { if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id; DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
return; return;
// save 舍弃原始的unseenday,都变为上一天出现,保证show/click不被重复decay // save 舍弃原始的unseenday,都变为上一天出现,保证show/click不被重复decay
} else { } else {
DownpourCtrFeatureValue::unseen_days(value) = _day_id - 1; DownpourCtrFeatureValue::UnseenDays(value) = _day_id - 1;
} }
} }
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
if (day_diff < 0) { if (day_diff < 0) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id; DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
return; return;
} }
if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) { if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) {
...@@ -486,7 +427,7 @@ void DownpourCtrAccessor::update_time_decay(float* value, ...@@ -486,7 +427,7 @@ void DownpourCtrAccessor::update_time_decay(float* value,
DownpourCtrFeatureValue::Show(value) *= _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) *= _time_decay_rates[day_diff];
DownpourCtrFeatureValue::Click(value) *= _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) *= _time_decay_rates[day_diff];
if (is_update_seen_day) { if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id; DownpourCtrFeatureValue::UnseenDays(value) = _day_id;
} }
} }
......
...@@ -45,34 +45,34 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -45,34 +45,34 @@ class DownpourCtrAccessor : public ValueAccessor {
static int Dim(int embedx_dim) { return 8 + embedx_dim; } static int Dim(int embedx_dim) { return 8 + embedx_dim; }
static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int unseen_days_index() { return 0; } static int UnseenDaysIndex() { return 0; }
static int delta_score_index() { static int DeltaScoreIndex() {
return DownpourCtrFeatureValue::unseen_days_index() + 1; return DownpourCtrFeatureValue::UnseenDaysIndex() + 1;
} }
static int ShowIndex() { static int ShowIndex() {
return DownpourCtrFeatureValue::delta_score_index() + 1; return DownpourCtrFeatureValue::DeltaScoreIndex() + 1;
} }
static int ClickIndex() { return DownpourCtrFeatureValue::ShowIndex() + 1; } static int ClickIndex() { return DownpourCtrFeatureValue::ShowIndex() + 1; }
static int Embed_W_Index() { static int EmbedWIndex() {
return DownpourCtrFeatureValue::ClickIndex() + 1; return DownpourCtrFeatureValue::ClickIndex() + 1;
} }
static int embed_g2sum_index() { static int EmbedG2SumIndex() {
return DownpourCtrFeatureValue::Embed_W_Index() + 1; return DownpourCtrFeatureValue::EmbedWIndex() + 1;
} }
static int SlotIndex() { static int SlotIndex() {
return DownpourCtrFeatureValue::embed_g2sum_index() + 1; return DownpourCtrFeatureValue::EmbedG2SumIndex() + 1;
} }
static int embedx_g2sum_index() { static int EmbedxG2SumIndex() {
return DownpourCtrFeatureValue::SlotIndex() + 1; return DownpourCtrFeatureValue::SlotIndex() + 1;
} }
static int Embedx_W_Index() { static int EmbedxWIndex() {
return DownpourCtrFeatureValue::embedx_g2sum_index() + 1; return DownpourCtrFeatureValue::EmbedxG2SumIndex() + 1;
} }
static float& unseen_days(float* val) { static float& UnseenDays(float* val) {
return val[DownpourCtrFeatureValue::unseen_days_index()]; return val[DownpourCtrFeatureValue::UnseenDaysIndex()];
} }
static float& delta_score(float* val) { static float& DeltaScore(float* val) {
return val[DownpourCtrFeatureValue::delta_score_index()]; return val[DownpourCtrFeatureValue::DeltaScoreIndex()];
} }
static float& Show(float* val) { static float& Show(float* val) {
return val[DownpourCtrFeatureValue::ShowIndex()]; return val[DownpourCtrFeatureValue::ShowIndex()];
...@@ -84,16 +84,16 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -84,16 +84,16 @@ class DownpourCtrAccessor : public ValueAccessor {
return val[DownpourCtrFeatureValue::SlotIndex()]; return val[DownpourCtrFeatureValue::SlotIndex()];
} }
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrFeatureValue::Embed_W_Index()]; return val[DownpourCtrFeatureValue::EmbedWIndex()];
} }
static float& embed_g2sum(float* val) { static float& EmbedG2Sum(float* val) {
return val[DownpourCtrFeatureValue::embed_g2sum_index()]; return val[DownpourCtrFeatureValue::EmbedG2SumIndex()];
} }
static float& embedx_g2sum(float* val) { static float& EmbedxG2Sum(float* val) {
return val[DownpourCtrFeatureValue::embedx_g2sum_index()]; return val[DownpourCtrFeatureValue::EmbedxG2SumIndex()];
} }
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return (val + DownpourCtrFeatureValue::Embedx_W_Index()); return (val + DownpourCtrFeatureValue::EmbedxWIndex());
} }
}; };
...@@ -113,11 +113,9 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -113,11 +113,9 @@ class DownpourCtrAccessor : public ValueAccessor {
static int SlotIndex() { return 0; } static int SlotIndex() { return 0; }
static int ShowIndex() { return DownpourCtrPushValue::SlotIndex() + 1; } static int ShowIndex() { return DownpourCtrPushValue::SlotIndex() + 1; }
static int ClickIndex() { return DownpourCtrPushValue::ShowIndex() + 1; } static int ClickIndex() { return DownpourCtrPushValue::ShowIndex() + 1; }
static int Embed_G_Index() { static int EmbedGIndex() { return DownpourCtrPushValue::ClickIndex() + 1; }
return DownpourCtrPushValue::ClickIndex() + 1; static int EmbedxGIndex() {
} return DownpourCtrPushValue::EmbedGIndex() + 1;
static int Embedx_G_Index() {
return DownpourCtrPushValue::Embed_G_Index() + 1;
} }
static float& Slot(float* val) { return val[0]; } static float& Slot(float* val) { return val[0]; }
static float& Show(float* val) { return val[1]; } static float& Show(float* val) { return val[1]; }
...@@ -139,8 +137,8 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -139,8 +137,8 @@ class DownpourCtrAccessor : public ValueAccessor {
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int ShowIndex() { return 0; } static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; } static int ClickIndex() { return 1; }
static int Embed_W_Index() { return 2; } static int EmbedWIndex() { return 2; }
static int Embedx_W_Index() { return 3; } static int EmbedxWIndex() { return 3; }
static float& Show(float* val) { static float& Show(float* val) {
return val[DownpourCtrPullValue::ShowIndex()]; return val[DownpourCtrPullValue::ShowIndex()];
} }
...@@ -148,38 +146,18 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -148,38 +146,18 @@ class DownpourCtrAccessor : public ValueAccessor {
return val[DownpourCtrPullValue::ClickIndex()]; return val[DownpourCtrPullValue::ClickIndex()];
} }
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[DownpourCtrPullValue::Embed_W_Index()]; return val[DownpourCtrPullValue::EmbedWIndex()];
} }
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return val + DownpourCtrPullValue::Embedx_W_Index(); return val + DownpourCtrPullValue::EmbedxWIndex();
} }
}; };
DownpourCtrAccessor() {} DownpourCtrAccessor() {}
virtual ~DownpourCtrAccessor() {} virtual ~DownpourCtrAccessor() {}
virtual int Initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info); // 初始化AccessorInfo
virtual size_t GetTableInfo(InfoKey key); virtual void InitAccessorInfo();
// value维度
size_t Dim();
// value各个维度的size
size_t DimSize(size_t dim);
// value各维度相加总size
size_t Size();
// value中mf动态长度部分总size大小, sparse下生效
size_t MFSize();
// pull value维度
size_t SelectDim();
// pull value各个维度的size
size_t SelectDimSize(size_t dim);
// pull value各维度相加总size
size_t SelectSize();
// push value维度
size_t UpdateDim();
// push value各个维度的size
size_t UpdateDimSize(size_t dim);
// push value各维度相加总size
size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool Shrink(float* value); virtual bool Shrink(float* value);
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
...@@ -219,7 +197,7 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -219,7 +197,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual float GetField(float* value, const std::string& name) override { virtual float GetField(float* value, const std::string& name) override {
CHECK(name == "show"); CHECK(name == "show");
if (name == "show") { if (name == "show") {
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value);
int16_t day_diff = _day_id - unseen_days; int16_t day_diff = _day_id - unseen_days;
auto show_right = auto show_right =
DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff];
...@@ -238,7 +216,7 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -238,7 +216,7 @@ class DownpourCtrAccessor : public ValueAccessor {
bool test_func() { return false; } bool test_func() { return false; }
private: private:
float show_click_score(float show, float click); float ShowClickScore(float show, float click);
void set_time_decay_rates(); void set_time_decay_rates();
private: private:
......
...@@ -89,7 +89,7 @@ int32_t MemorySparseTable::Load(const std::string& path, ...@@ -89,7 +89,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size = size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().size / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
...@@ -174,7 +174,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, ...@@ -174,7 +174,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size = size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().size / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
...@@ -415,10 +415,12 @@ int32_t MemorySparseTable::PullSparse(float* pull_values, ...@@ -415,10 +415,12 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
CostTimer timer("pserver_sparse_select_all"); CostTimer timer("pserver_sparse_select_all");
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); const size_t value_size =
size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t select_value_size = size_t select_value_size =
_value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().select_size / sizeof(float);
// std::atomic<uint32_t> missed_keys{0}; // std::atomic<uint32_t> missed_keys{0};
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
...@@ -482,8 +484,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values, ...@@ -482,8 +484,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
int32_t MemorySparseTable::PullSparsePtr(char** pull_values, int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) { const uint64_t* keys, size_t num) {
CostTimer timer("pscore_sparse_select_all"); CostTimer timer("pscore_sparse_select_all");
size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
...@@ -541,10 +544,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values, ...@@ -541,10 +544,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
task_keys[shard_id].push_back({keys[i], i}); task_keys[shard_id].push_back({keys[i], i});
} }
const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float); const size_t value_col =
size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_col =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t update_value_col = size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().update_size / sizeof(float);
for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
...@@ -619,10 +624,11 @@ int32_t MemorySparseTable::_PushSparse(const uint64_t* keys, ...@@ -619,10 +624,11 @@ int32_t MemorySparseTable::_PushSparse(const uint64_t* keys,
task_keys[shard_id].push_back({keys[i], i}); task_keys[shard_id].push_back({keys[i], i});
} }
size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float); size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); size_t mf_value_col =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t update_value_col = size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float); _value_accesor->GetAccessorInfo().update_size / sizeof(float);
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
......
...@@ -23,87 +23,35 @@ namespace distributed { ...@@ -23,87 +23,35 @@ namespace distributed {
int SparseAccessor::Initialize() { int SparseAccessor::Initialize() {
auto name = _config.embed_sgd_param().name(); auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name(); name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->load_config(_config.embedx_sgd_param(), _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim()); _config.embedx_dim());
sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->dim(); sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
sparse_feature_value.embedx_dim = _config.embedx_dim(); sparse_feature_value.embedx_dim = _config.embedx_dim();
sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim(); sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
InitAccessorInfo();
return 0; return 0;
} }
void SparseAccessor::SetTableInfo(AccessorInfo& info) { void SparseAccessor::InitAccessorInfo() {
info.dim = Dim(); _accessor_info.dim = sparse_feature_value.Dim();
info.size = Size(); _accessor_info.size = sparse_feature_value.Size();
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.mf_size = MFSize();
}
size_t SparseAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return Dim();
case SIZE:
return Size();
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case MF_SIZE:
return MFSize();
default:
return 0;
}
return 0;
}
size_t SparseAccessor::Dim() { return sparse_feature_value.Dim(); }
size_t SparseAccessor::DimSize(size_t dim) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return sparse_feature_value.DimSize(dim, embedx_dim); _accessor_info.select_dim = 1 + embedx_dim;
} _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
;
size_t SparseAccessor::Size() { return sparse_feature_value.Size(); } _accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
size_t SparseAccessor::MFSize() { _accessor_info.mf_size =
return (_config.embedx_dim() + sparse_feature_value.embedx_sgd_dim) * (embedx_dim + sparse_feature_value.embedx_sgd_dim) * sizeof(float);
sizeof(float); // embedx embedx_g2sum
} }
// pull value
size_t SparseAccessor::SelectDim() {
auto embedx_dim = _config.embedx_dim();
return 1 + embedx_dim;
}
size_t SparseAccessor::SelectDimSize(size_t dim) { return sizeof(float); }
size_t SparseAccessor::SelectSize() { return SelectDim() * sizeof(float); }
// push value
size_t SparseAccessor::UpdateDim() {
auto embedx_dim = _config.embedx_dim();
return 4 + embedx_dim;
}
size_t SparseAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
size_t SparseAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
bool SparseAccessor::Shrink(float* value) { bool SparseAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
...@@ -116,9 +64,9 @@ bool SparseAccessor::Shrink(float* value) { ...@@ -116,9 +64,9 @@ bool SparseAccessor::Shrink(float* value) {
sparse_feature_value.Click(value) *= _show_click_decay_rate; sparse_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after // shrink after
auto score = show_click_score(sparse_feature_value.Show(value), auto score = ShowClickScore(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)); sparse_feature_value.Click(value));
auto unseen_days = sparse_feature_value.unseen_days(value); auto unseen_days = sparse_feature_value.UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) { if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true; return true;
} }
...@@ -141,14 +89,13 @@ bool SparseAccessor::Save(float* value, int param) { ...@@ -141,14 +89,13 @@ bool SparseAccessor::Save(float* value, int param) {
case 1: case 1:
// save xbox base // save xbox base
case 2: { case 2: {
if (show_click_score(sparse_feature_value.Show(value), if (ShowClickScore(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)) >= sparse_feature_value.Click(value)) >= base_threshold &&
base_threshold && sparse_feature_value.DeltaScore(value) >= delta_threshold &&
sparse_feature_value.delta_score(value) >= delta_threshold && sparse_feature_value.UnseenDays(value) <= delta_keep_days) {
sparse_feature_value.unseen_days(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
if (param == 2) { if (param == 2) {
sparse_feature_value.delta_score(value) = 0; sparse_feature_value.DeltaScore(value) = 0;
} }
return true; return true;
} else { } else {
...@@ -158,7 +105,7 @@ bool SparseAccessor::Save(float* value, int param) { ...@@ -158,7 +105,7 @@ bool SparseAccessor::Save(float* value, int param) {
// already decayed in shrink // already decayed in shrink
case 3: { case 3: {
// do this after save, because it must not be modified when retry // do this after save, because it must not be modified when retry
// sparse_feature_value.unseen_days(value)++; // sparse_feature_value.UnseenDays(value)++;
return true; return true;
} }
// save revert batch_model // save revert batch_model
...@@ -179,17 +126,16 @@ void SparseAccessor::UpdateStatAfterSave(float* value, int param) { ...@@ -179,17 +126,16 @@ void SparseAccessor::UpdateStatAfterSave(float* value, int param) {
} }
switch (param) { switch (param) {
case 1: { case 1: {
if (show_click_score(sparse_feature_value.Show(value), if (ShowClickScore(sparse_feature_value.Show(value),
sparse_feature_value.Click(value)) >= sparse_feature_value.Click(value)) >= base_threshold &&
base_threshold && sparse_feature_value.DeltaScore(value) >= delta_threshold &&
sparse_feature_value.delta_score(value) >= delta_threshold && sparse_feature_value.UnseenDays(value) <= delta_keep_days) {
sparse_feature_value.unseen_days(value) <= delta_keep_days) { sparse_feature_value.DeltaScore(value) = 0;
sparse_feature_value.delta_score(value) = 0;
} }
} }
return; return;
case 3: { case 3: {
sparse_feature_value.unseen_days(value)++; sparse_feature_value.UnseenDays(value)++;
} }
return; return;
default: default:
...@@ -201,17 +147,16 @@ int32_t SparseAccessor::Create(float** values, size_t num) { ...@@ -201,17 +147,16 @@ int32_t SparseAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item]; float* value = values[value_item];
value[sparse_feature_value.unseen_days_index()] = 0; value[sparse_feature_value.UnseenDaysIndex()] = 0;
value[sparse_feature_value.delta_score_index()] = 0; value[sparse_feature_value.DeltaScoreIndex()] = 0;
value[sparse_feature_value.ShowIndex()] = 0; value[sparse_feature_value.ShowIndex()] = 0;
value[sparse_feature_value.ClickIndex()] = 0; value[sparse_feature_value.ClickIndex()] = 0;
value[sparse_feature_value.SlotIndex()] = -1; value[sparse_feature_value.SlotIndex()] = -1;
_embed_sgd_rule->init_value( _embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(),
value + sparse_feature_value.Embed_W_Index(), value + sparse_feature_value.EmbedG2SumIndex());
value + sparse_feature_value.embed_g2sum_index()); _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
_embedx_sgd_rule->init_value( value + sparse_feature_value.EmbedxG2SumIndex(),
value + sparse_feature_value.Embedx_W_Index(), false);
value + sparse_feature_value.embedx_g2sum_index(), false);
} }
return 0; return 0;
} }
...@@ -225,7 +170,7 @@ bool SparseAccessor::NeedExtendMF(float* value) { ...@@ -225,7 +170,7 @@ bool SparseAccessor::NeedExtendMF(float* value) {
} }
bool SparseAccessor::HasMF(size_t size) { bool SparseAccessor::HasMF(size_t size) {
return size > sparse_feature_value.embedx_g2sum_index(); return size > sparse_feature_value.EmbedxG2SumIndex();
} }
// from SparseFeatureValue to SparsePullValue // from SparseFeatureValue to SparsePullValue
...@@ -235,10 +180,10 @@ int32_t SparseAccessor::Select(float** select_values, const float** values, ...@@ -235,10 +180,10 @@ int32_t SparseAccessor::Select(float** select_values, const float** values,
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item]; float* select_value = select_values[value_item];
const float* value = values[value_item]; const float* value = values[value_item];
select_value[SparsePullValue::Embed_W_Index()] = select_value[SparsePullValue::EmbedWIndex()] =
value[sparse_feature_value.Embed_W_Index()]; value[sparse_feature_value.EmbedWIndex()];
memcpy(select_value + SparsePullValue::Embedx_W_Index(), memcpy(select_value + SparsePullValue::EmbedxWIndex(),
value + sparse_feature_value.Embedx_W_Index(), value + sparse_feature_value.EmbedxWIndex(),
embedx_dim * sizeof(float)); embedx_dim * sizeof(float));
} }
return 0; return 0;
...@@ -278,18 +223,18 @@ int32_t SparseAccessor::Update(float** update_values, const float** push_values, ...@@ -278,18 +223,18 @@ int32_t SparseAccessor::Update(float** update_values, const float** push_values,
update_value[sparse_feature_value.ShowIndex()] += push_show; update_value[sparse_feature_value.ShowIndex()] += push_show;
update_value[sparse_feature_value.ClickIndex()] += push_click; update_value[sparse_feature_value.ClickIndex()] += push_click;
update_value[sparse_feature_value.SlotIndex()] = slot; update_value[sparse_feature_value.SlotIndex()] = slot;
update_value[sparse_feature_value.delta_score_index()] += update_value[sparse_feature_value.DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
update_value[sparse_feature_value.unseen_days_index()] = 0; update_value[sparse_feature_value.UnseenDaysIndex()] = 0;
_embed_sgd_rule->update_value( _embed_sgd_rule->UpdateValue(
update_value + sparse_feature_value.Embed_W_Index(), update_value + sparse_feature_value.EmbedWIndex(),
update_value + sparse_feature_value.embed_g2sum_index(), update_value + sparse_feature_value.EmbedG2SumIndex(),
push_value + SparsePushValue::Embed_G_Index()); push_value + SparsePushValue::EmbedGIndex());
_embedx_sgd_rule->update_value( _embedx_sgd_rule->UpdateValue(
update_value + sparse_feature_value.Embedx_W_Index(), update_value + sparse_feature_value.EmbedxWIndex(),
update_value + sparse_feature_value.embedx_g2sum_index(), update_value + sparse_feature_value.EmbedxG2SumIndex(),
push_value + SparsePushValue::Embedx_G_Index()); push_value + SparsePushValue::EmbedxGIndex());
} }
return 0; return 0;
} }
...@@ -303,7 +248,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) { ...@@ -303,7 +248,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) {
// operation // operation
auto show = SparsePushValue::Show(const_cast<float*>(value)); auto show = SparsePushValue::Show(const_cast<float*>(value));
auto click = SparsePushValue::Click(const_cast<float*>(value)); auto click = SparsePushValue::Click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
} }
...@@ -317,7 +262,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) { ...@@ -317,7 +262,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) {
} }
} }
float SparseAccessor::show_click_score(float show, float click) { float SparseAccessor::ShowClickScore(float show, float click) {
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff(); auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff; return (show - click) * nonclk_coeff + click * click_coeff;
...@@ -329,16 +274,16 @@ std::string SparseAccessor::ParseToString(const float* v, int param) { ...@@ -329,16 +274,16 @@ std::string SparseAccessor::ParseToString(const float* v, int param) {
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5]; << v[5];
for (int i = sparse_feature_value.embed_g2sum_index(); for (int i = sparse_feature_value.EmbedG2SumIndex();
i < sparse_feature_value.Embedx_W_Index(); i++) { i < sparse_feature_value.EmbedxWIndex(); i++) {
os << " " << v[i]; os << " " << v[i];
} }
auto show = sparse_feature_value.Show(const_cast<float*>(v)); auto show = sparse_feature_value.Show(const_cast<float*>(v));
auto click = sparse_feature_value.Click(const_cast<float*>(v)); auto click = sparse_feature_value.Click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() && if (score >= _config.embedx_threshold() &&
param > sparse_feature_value.Embedx_W_Index()) { param > sparse_feature_value.EmbedxWIndex()) {
for (auto i = sparse_feature_value.Embedx_W_Index(); for (auto i = sparse_feature_value.EmbedxWIndex();
i < sparse_feature_value.Dim(); ++i) { i < sparse_feature_value.Dim(); ++i) {
os << " " << v[i]; os << " " << v[i];
} }
...@@ -349,9 +294,8 @@ std::string SparseAccessor::ParseToString(const float* v, int param) { ...@@ -349,9 +294,8 @@ std::string SparseAccessor::ParseToString(const float* v, int param) {
int SparseAccessor::ParseFromString(const std::string& str, float* value) { int SparseAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
_embedx_sgd_rule->init_value( _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
value + sparse_feature_value.Embedx_W_Index(), value + sparse_feature_value.EmbedxG2SumIndex());
value + sparse_feature_value.embedx_g2sum_index());
auto ret = paddle::string::str_to_float(str.data(), value); auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect more than 6 real:" << ret; CHECK(ret >= 6) << "expect more than 6 real:" << ret;
return ret; return ret;
......
...@@ -44,24 +44,24 @@ class SparseAccessor : public ValueAccessor { ...@@ -44,24 +44,24 @@ class SparseAccessor : public ValueAccessor {
int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
int Size() { return Dim() * sizeof(float); } int Size() { return Dim() * sizeof(float); }
int SlotIndex() { return 0; } int SlotIndex() { return 0; }
int unseen_days_index() { return SlotIndex() + 1; } int UnseenDaysIndex() { return SlotIndex() + 1; }
int delta_score_index() { return unseen_days_index() + 1; } int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
int ShowIndex() { return delta_score_index() + 1; } int ShowIndex() { return DeltaScoreIndex() + 1; }
int ClickIndex() { return ShowIndex() + 1; } int ClickIndex() { return ShowIndex() + 1; }
int Embed_W_Index() { return ClickIndex() + 1; } int EmbedWIndex() { return ClickIndex() + 1; }
int embed_g2sum_index() { return Embed_W_Index() + 1; } int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; } int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; } int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; }
float& unseen_days(float* val) { return val[unseen_days_index()]; } float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& delta_score(float* val) { return val[delta_score_index()]; } float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
float& Show(float* val) { return val[ShowIndex()]; } float& Show(float* val) { return val[ShowIndex()]; }
float& Click(float* val) { return val[ClickIndex()]; } float& Click(float* val) { return val[ClickIndex()]; }
float& Slot(float* val) { return val[SlotIndex()]; } float& Slot(float* val) { return val[SlotIndex()]; }
float& EmbedW(float* val) { return val[Embed_W_Index()]; } float& EmbedW(float* val) { return val[EmbedWIndex()]; }
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
float& EmbedxW(float* val) { return val[Embedx_W_Index()]; } float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
int embed_sgd_dim; int embed_sgd_dim;
int embedx_dim; int embedx_dim;
...@@ -84,18 +84,18 @@ class SparseAccessor : public ValueAccessor { ...@@ -84,18 +84,18 @@ class SparseAccessor : public ValueAccessor {
static int SlotIndex() { return 0; } static int SlotIndex() { return 0; }
static int ShowIndex() { return SparsePushValue::SlotIndex() + 1; } static int ShowIndex() { return SparsePushValue::SlotIndex() + 1; }
static int ClickIndex() { return SparsePushValue::ShowIndex() + 1; } static int ClickIndex() { return SparsePushValue::ShowIndex() + 1; }
static int Embed_G_Index() { return SparsePushValue::ClickIndex() + 1; } static int EmbedGIndex() { return SparsePushValue::ClickIndex() + 1; }
static int Embedx_G_Index() { return SparsePushValue::Embed_G_Index() + 1; } static int EmbedxGIndex() { return SparsePushValue::EmbedGIndex() + 1; }
static float& Slot(float* val) { return val[SparsePushValue::SlotIndex()]; } static float& Slot(float* val) { return val[SparsePushValue::SlotIndex()]; }
static float& Show(float* val) { return val[SparsePushValue::ShowIndex()]; } static float& Show(float* val) { return val[SparsePushValue::ShowIndex()]; }
static float& Click(float* val) { static float& Click(float* val) {
return val[SparsePushValue::ClickIndex()]; return val[SparsePushValue::ClickIndex()];
} }
static float& EmbedG(float* val) { static float& EmbedG(float* val) {
return val[SparsePushValue::Embed_G_Index()]; return val[SparsePushValue::EmbedGIndex()];
} }
static float* EmbedxG(float* val) { static float* EmbedxG(float* val) {
return val + SparsePushValue::Embedx_G_Index(); return val + SparsePushValue::EmbedxGIndex();
} }
}; };
...@@ -108,41 +108,21 @@ class SparseAccessor : public ValueAccessor { ...@@ -108,41 +108,21 @@ class SparseAccessor : public ValueAccessor {
static int Dim(int embedx_dim) { return 1 + embedx_dim; } static int Dim(int embedx_dim) { return 1 + embedx_dim; }
static int DimSize(size_t dim) { return sizeof(float); } static int DimSize(size_t dim) { return sizeof(float); }
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int Embed_W_Index() { return 0; } static int EmbedWIndex() { return 0; }
static int Embedx_W_Index() { return 1; } static int EmbedxWIndex() { return 1; }
static float& EmbedW(float* val) { static float& EmbedW(float* val) {
return val[SparsePullValue::Embed_W_Index()]; return val[SparsePullValue::EmbedWIndex()];
} }
static float* EmbedxW(float* val) { static float* EmbedxW(float* val) {
return val + SparsePullValue::Embedx_W_Index(); return val + SparsePullValue::EmbedxWIndex();
} }
}; };
SparseAccessor() {} SparseAccessor() {}
virtual int Initialize();
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
virtual ~SparseAccessor() {} virtual ~SparseAccessor() {}
// value维度 virtual int Initialize();
size_t Dim(); // 初始化AccessorInfo
// value各个维度的size virtual void InitAccessorInfo();
size_t DimSize(size_t dim);
// value各维度相加总size
size_t Size();
// value中mf动态长度部分总size大小, sparse下生效
size_t MFSize();
// pull value维度
size_t SelectDim();
// pull value各个维度的size
size_t SelectDimSize(size_t dim);
// pull value各维度相加总size
size_t SelectSize();
// push value维度
size_t UpdateDim();
// push value各个维度的size
size_t UpdateDimSize(size_t dim);
// push value各维度相加总size
size_t UpdateSize();
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool Shrink(float* value); virtual bool Shrink(float* value);
// 判断该value是否保存到ssd // 判断该value是否保存到ssd
...@@ -186,7 +166,7 @@ class SparseAccessor : public ValueAccessor { ...@@ -186,7 +166,7 @@ class SparseAccessor : public ValueAccessor {
} }
private: private:
// float show_click_score(float show, float click); // float ShowClickScore(float show, float click);
// SparseValueSGDRule* _embed_sgd_rule; // SparseValueSGDRule* _embed_sgd_rule;
// SparseValueSGDRule* _embedx_sgd_rule; // SparseValueSGDRule* _embedx_sgd_rule;
...@@ -197,7 +177,7 @@ class SparseAccessor : public ValueAccessor { ...@@ -197,7 +177,7 @@ class SparseAccessor : public ValueAccessor {
public: // TODO(zhaocaibei123): it should be private, but we make it public public: // TODO(zhaocaibei123): it should be private, but we make it public
// for unit test // for unit test
SparseFeatureValue sparse_feature_value; SparseFeatureValue sparse_feature_value;
float show_click_score(float show, float click); float ShowClickScore(float show, float click);
SparseValueSGDRule* _embed_sgd_rule; SparseValueSGDRule* _embed_sgd_rule;
SparseValueSGDRule* _embedx_sgd_rule; SparseValueSGDRule* _embedx_sgd_rule;
}; };
......
...@@ -21,8 +21,8 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient"); ...@@ -21,8 +21,8 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param, void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) { size_t emb_dim) {
_embedding_dim = emb_dim; _embedding_dim = emb_dim;
auto naive_param = param.naive(); auto naive_param = param.naive();
learning_rate_ = naive_param.learning_rate(); learning_rate_ = naive_param.learning_rate();
...@@ -39,17 +39,16 @@ void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param, ...@@ -39,17 +39,16 @@ void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
} }
} }
void SparseNaiveSGDRule::update_value_work(float* w, float* sgd, void SparseNaiveSGDRule::UpdateValueWork(float* w, float* sgd,
const float* push_value, const float* push_value, float scale) {
float scale) {
for (size_t i = 0; i < _embedding_dim; ++i) { for (size_t i = 0; i < _embedding_dim; ++i) {
w[i] -= learning_rate_ * push_value[i]; w[i] -= learning_rate_ * push_value[i];
bound_value(w[i]); BoundValue(w[i]);
} }
} }
void SparseNaiveSGDRule::init_value_work(float* value, float* sgd, void SparseNaiveSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) { bool zero_init) {
if (zero_init) { if (zero_init) {
for (size_t i = 0; i < _embedding_dim; ++i) { for (size_t i = 0; i < _embedding_dim; ++i) {
value[i] = 0; value[i] = 0;
...@@ -60,12 +59,12 @@ void SparseNaiveSGDRule::init_value_work(float* value, float* sgd, ...@@ -60,12 +59,12 @@ void SparseNaiveSGDRule::init_value_work(float* value, float* sgd,
(local_uniform_real_distribution<float>()(local_random_engine()) * 2 - (local_uniform_real_distribution<float>()(local_random_engine()) * 2 -
1) * 1) *
_initial_range; _initial_range;
bound_value(value[i]); BoundValue(value[i]);
} }
} }
} }
void SparseAdaGradSGDRule::load_config( void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
const SparseCommonSGDRuleParameter& param, size_t emb_dim) { size_t emb_dim) {
_embedding_dim = emb_dim; _embedding_dim = emb_dim;
auto adagrad_param = param.adagrad(); auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate(); learning_rate_ = adagrad_param.learning_rate();
...@@ -84,42 +83,42 @@ void SparseAdaGradSGDRule::load_config( ...@@ -84,42 +83,42 @@ void SparseAdaGradSGDRule::load_config(
} }
} }
void SparseAdaGradSGDRule::update_value_work(float* w, float* sgd, void SparseAdaGradSGDRule::UpdateValueWork(float* w, float* sgd,
const float* grad, float scale) { const float* grad, float scale) {
float& g2sum = sgd[g2sum_index()]; float& g2sum = sgd[G2SumIndex()];
double add_g2sum = 0; double add_g2sum = 0;
for (int i = 0; i < _embedding_dim; i++) { for (int i = 0; i < _embedding_dim; i++) {
double scaled_grad = grad[i] / scale; double scaled_grad = grad[i] / scale;
w[i] -= learning_rate_ * scaled_grad * w[i] -= learning_rate_ * scaled_grad *
sqrt(_initial_g2sum / (_initial_g2sum + g2sum)); sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
bound_value(w[i]); BoundValue(w[i]);
add_g2sum += scaled_grad * scaled_grad; add_g2sum += scaled_grad * scaled_grad;
} }
g2sum += add_g2sum / _embedding_dim; g2sum += add_g2sum / _embedding_dim;
} }
void SparseAdaGradSGDRule::init_value_work(float* value, float* sgd, void SparseAdaGradSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) { bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) { for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) { if (zero_init) {
value[i] = 0.0; value[i] = 0.0;
bound_value(value[i]); BoundValue(value[i]);
} else { } else {
value[i] = value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) * (local_uniform_real_distribution<double>()(local_random_engine()) *
2 - 2 -
1) * 1) *
_initial_range; _initial_range;
bound_value(value[i]); BoundValue(value[i]);
} }
} }
sgd[g2sum_index()] = 0; sgd[G2SumIndex()] = 0;
} }
void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param, void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) { size_t emb_dim) {
_embedding_dim = emb_dim; _embedding_dim = emb_dim;
auto adagrad_param = param.adagrad(); auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate(); learning_rate_ = adagrad_param.learning_rate();
...@@ -138,38 +137,38 @@ void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param, ...@@ -138,38 +137,38 @@ void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
} }
} }
void StdAdaGradSGDRule::update_value_work(float* w, float* sgd, void StdAdaGradSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
const float* grad, float scale) { float scale) {
for (int i = 0; i < _embedding_dim; i++) { for (int i = 0; i < _embedding_dim; i++) {
float& g2sum = sgd[g2sum_index() + i]; float& g2sum = sgd[G2SumIndex() + i];
double scaled_grad = grad[i] / scale; double scaled_grad = grad[i] / scale;
w[i] -= learning_rate_ * scaled_grad * w[i] -= learning_rate_ * scaled_grad *
sqrt(_initial_g2sum / (_initial_g2sum + g2sum)); sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
bound_value(w[i]); BoundValue(w[i]);
g2sum += scaled_grad * scaled_grad; g2sum += scaled_grad * scaled_grad;
} }
} }
void StdAdaGradSGDRule::init_value_work(float* value, float* sgd, void StdAdaGradSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) { bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) { for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) { if (zero_init) {
value[i] = 0.0; value[i] = 0.0;
bound_value(value[i]); BoundValue(value[i]);
} else { } else {
value[i] = value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) * (local_uniform_real_distribution<double>()(local_random_engine()) *
2 - 2 -
1) * 1) *
_initial_range; _initial_range;
bound_value(value[i]); BoundValue(value[i]);
} }
sgd[g2sum_index() + i] = 0; sgd[G2SumIndex() + i] = 0;
} }
} }
void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param, void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) { size_t emb_dim) {
_embedding_dim = emb_dim; _embedding_dim = emb_dim;
auto adam_param = param.adam(); auto adam_param = param.adam();
learning_rate_ = adam_param.learning_rate(); learning_rate_ = adam_param.learning_rate();
...@@ -189,12 +188,12 @@ void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param, ...@@ -189,12 +188,12 @@ void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
} }
} }
void SparseAdamSGDRule::update_value_work(float* w, float* sgd, void SparseAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
const float* grad, float scale) { float scale) {
float* gsum = sgd + gsum_index(); float* gsum = sgd + GSumIndex();
float* g2sum = sgd + g2sum_index(); float* g2sum = sgd + G2SumIndex();
float* beta1_pow = sgd + beta1_pow_index(); float* beta1_pow = sgd + Beta1PowIndex();
float* beta2_pow = sgd + beta2_pow_index(); float* beta2_pow = sgd + Beta2PowIndex();
const float* g = grad; const float* g = grad;
float lr = learning_rate_; float lr = learning_rate_;
...@@ -209,35 +208,35 @@ void SparseAdamSGDRule::update_value_work(float* w, float* sgd, ...@@ -209,35 +208,35 @@ void SparseAdamSGDRule::update_value_work(float* w, float* sgd,
g2sum[i] = g2sum[i] =
_beta2_decay_rate * g2sum[i] + (1 - _beta2_decay_rate) * g[i] * g[i]; _beta2_decay_rate * g2sum[i] + (1 - _beta2_decay_rate) * g[i] * g[i];
w[i] = w[i] - lr * (gsum[i] / (sqrt(g2sum[i]) + _ada_epsilon)); w[i] = w[i] - lr * (gsum[i] / (sqrt(g2sum[i]) + _ada_epsilon));
bound_value(w[i]); BoundValue(w[i]);
} }
// update beta_pow_decay // update beta_pow_decay
(*beta1_pow) *= _beta1_decay_rate; (*beta1_pow) *= _beta1_decay_rate;
(*beta2_pow) *= _beta2_decay_rate; (*beta2_pow) *= _beta2_decay_rate;
} }
void SparseAdamSGDRule::init_value_work(float* value, float* sgd, void SparseAdamSGDRule::InitValueWork(float* value, float* sgd,
bool zero_init) { bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) { for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) { if (zero_init) {
value[i] = 0.0; value[i] = 0.0;
bound_value(value[i]); BoundValue(value[i]);
} else { } else {
value[i] = value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) * (local_uniform_real_distribution<double>()(local_random_engine()) *
2 - 2 -
1) * 1) *
_initial_range; _initial_range;
bound_value(value[i]); BoundValue(value[i]);
} }
} }
// init rule gsum and g2sum // init rule gsum and g2sum
for (int i = gsum_index(); i < beta1_pow_index(); i++) { for (int i = GSumIndex(); i < Beta1PowIndex(); i++) {
sgd[i] = 0.0; sgd[i] = 0.0;
} }
// init beta1_pow and beta2_pow // init beta1_pow and beta2_pow
*(sgd + beta1_pow_index()) = _beta1_decay_rate; *(sgd + Beta1PowIndex()) = _beta1_decay_rate;
*(sgd + beta2_pow_index()) = _beta2_decay_rate; *(sgd + Beta2PowIndex()) = _beta2_decay_rate;
} }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -28,33 +28,33 @@ class SparseValueSGDRule { ...@@ -28,33 +28,33 @@ class SparseValueSGDRule {
public: public:
SparseValueSGDRule() {} SparseValueSGDRule() {}
virtual ~SparseValueSGDRule() {} virtual ~SparseValueSGDRule() {}
virtual void load_config(const SparseCommonSGDRuleParameter& param, virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) { size_t emb_dim) {
_embedding_dim = emb_dim; _embedding_dim = emb_dim;
_name = param.name(); _name = param.name();
} }
virtual void update_value_work(float* w, float* sgd, const float* push_value, virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale) = 0; float scale) = 0;
virtual void init_value_work(float* value, float* sgd, bool zero_init) = 0; virtual void InitValueWork(float* value, float* sgd, bool zero_init) = 0;
virtual size_t dim() = 0; virtual size_t Dim() = 0;
const std::string& get_name() const { return _name; } const std::string& GetName() const { return _name; }
void init_value(float* value, float* sgd, bool zero_init = true) { void InitValue(float* value, float* sgd, bool zero_init = true) {
init_value_work(value, sgd, zero_init); InitValueWork(value, sgd, zero_init);
} }
void update_value(float* w, float* sgd, const float* push_value, void UpdateValue(float* w, float* sgd, const float* push_value,
float scale = 1) { float scale = 1) {
update_value_work(w, sgd, push_value, scale); UpdateValueWork(w, sgd, push_value, scale);
} }
template <class T> template <class T>
void bound_value(T& w) { // NOLINT void BoundValue(T& w) { // NOLINT
if (!(w >= _min_bound)) { if (!(w >= _min_bound)) {
w = (T)_min_bound; w = (T)_min_bound;
} else if (!(w <= _max_bound)) { } else if (!(w <= _max_bound)) {
w = (T)_max_bound; w = (T)_max_bound;
} }
} }
float& min_bound() { return _min_bound; } float& MinBound() { return _min_bound; }
float& max_bound() { return _max_bound; } float& MaxBound() { return _max_bound; }
protected: protected:
float _min_bound; float _min_bound;
...@@ -70,12 +70,12 @@ REGISTER_PSCORE_REGISTERER(SparseValueSGDRule); ...@@ -70,12 +70,12 @@ REGISTER_PSCORE_REGISTERER(SparseValueSGDRule);
class SparseNaiveSGDRule : public SparseValueSGDRule { class SparseNaiveSGDRule : public SparseValueSGDRule {
public: public:
virtual void load_config(const SparseCommonSGDRuleParameter& param, virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim); size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value, virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale); float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init); virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return 0; } virtual size_t Dim() { return 0; }
private: private:
float learning_rate_; float learning_rate_;
...@@ -83,13 +83,13 @@ class SparseNaiveSGDRule : public SparseValueSGDRule { ...@@ -83,13 +83,13 @@ class SparseNaiveSGDRule : public SparseValueSGDRule {
class SparseAdaGradSGDRule : public SparseValueSGDRule { class SparseAdaGradSGDRule : public SparseValueSGDRule {
public: public:
virtual void load_config(const SparseCommonSGDRuleParameter& param, virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim); size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value, virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale); float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init); virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return 1; } virtual size_t Dim() { return 1; }
size_t g2sum_index() { return 0; } size_t G2SumIndex() { return 0; }
private: private:
float learning_rate_; float learning_rate_;
...@@ -98,13 +98,13 @@ class SparseAdaGradSGDRule : public SparseValueSGDRule { ...@@ -98,13 +98,13 @@ class SparseAdaGradSGDRule : public SparseValueSGDRule {
class StdAdaGradSGDRule : public SparseValueSGDRule { class StdAdaGradSGDRule : public SparseValueSGDRule {
public: public:
virtual void load_config(const SparseCommonSGDRuleParameter& param, virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim); size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value, virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale); float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init); virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return _embedding_dim; } virtual size_t Dim() { return _embedding_dim; }
size_t g2sum_index() { return 0; } size_t G2SumIndex() { return 0; }
private: private:
float learning_rate_; float learning_rate_;
...@@ -113,16 +113,16 @@ class StdAdaGradSGDRule : public SparseValueSGDRule { ...@@ -113,16 +113,16 @@ class StdAdaGradSGDRule : public SparseValueSGDRule {
class SparseAdamSGDRule : public SparseValueSGDRule { class SparseAdamSGDRule : public SparseValueSGDRule {
public: public:
virtual void load_config(const SparseCommonSGDRuleParameter& param, virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim); size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value, virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
float scale); float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init); virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return _embedding_dim * 2 + 2; } virtual size_t Dim() { return _embedding_dim * 2 + 2; }
size_t gsum_index() { return 0; } size_t GSumIndex() { return 0; }
size_t g2sum_index() { return gsum_index() + _embedding_dim; } size_t G2SumIndex() { return GSumIndex() + _embedding_dim; }
size_t beta1_pow_index() { return g2sum_index() + _embedding_dim; } size_t Beta1PowIndex() { return G2SumIndex() + _embedding_dim; }
size_t beta2_pow_index() { return beta1_pow_index() + 1; } size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
protected: protected:
float learning_rate_; float learning_rate_;
......
...@@ -103,7 +103,6 @@ int32_t Table::InitializeAccessor() { ...@@ -103,7 +103,6 @@ int32_t Table::InitializeAccessor() {
return -1; return -1;
} }
_value_accesor.reset(accessor); _value_accesor.reset(accessor);
// _value_accesor->SetTableInfo(_table_info);
return 0; return 0;
} }
......
...@@ -162,7 +162,6 @@ class Table { ...@@ -162,7 +162,6 @@ class Table {
TableParameter _config; TableParameter _config;
float *_global_lr = nullptr; float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor; std::shared_ptr<ValueAccessor> _value_accesor;
AccessorInfo _table_info;
AfsClient _afs_client; AfsClient _afs_client;
}; };
REGISTER_PSCORE_REGISTERER(Table); REGISTER_PSCORE_REGISTERER(Table);
......
...@@ -18,51 +18,19 @@ ...@@ -18,51 +18,19 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int CommMergeAccessor::Initialize() { return 0; } int CommMergeAccessor::Initialize() {
InitAccessorInfo();
void CommMergeAccessor::SetTableInfo(AccessorInfo &info) {
info.select_dim = SelectDim();
info.select_size = SelectSize();
info.update_dim = UpdateDim();
info.update_size = UpdateSize();
info.fea_dim = fea_dim();
}
size_t CommMergeAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case SELECT_DIM:
return SelectDim();
case SELECT_SIZE:
return SelectSize();
case UPDATE_DIM:
return UpdateDim();
case UPDATE_SIZE:
return UpdateSize();
case FEA_DIM:
return fea_dim();
default:
return 0;
}
return 0; return 0;
} }
// pull value 维度 void CommMergeAccessor::InitAccessorInfo() {
size_t CommMergeAccessor::SelectDim() { return _config.embedx_dim(); } auto embedx_dim = _config.embedx_dim();
_accessor_info.select_dim = embedx_dim;
// pull value 各个维度的size _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
size_t CommMergeAccessor::SelectDimSize(size_t dim) { return sizeof(float); } _accessor_info.update_dim = embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
// pull value 各维度相加总size _accessor_info.fea_dim = _config.fea_dim();
size_t CommMergeAccessor::SelectSize() { return SelectDim() * sizeof(float); } }
// push value 维度
size_t CommMergeAccessor::UpdateDim() { return _config.embedx_dim(); }
// push value 各个维度的size
size_t CommMergeAccessor::UpdateDimSize(size_t dim) { return sizeof(float); }
// push value 各维度相加总size
size_t CommMergeAccessor::UpdateSize() { return UpdateDim() * sizeof(float); }
// 判断该value 是否进行shrink // 判断该value 是否进行shrink
bool CommMergeAccessor::Shrink(float * /*value*/) { return false; } bool CommMergeAccessor::Shrink(float * /*value*/) { return false; }
......
...@@ -30,22 +30,8 @@ class CommMergeAccessor : public ValueAccessor { ...@@ -30,22 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {} CommMergeAccessor() {}
virtual ~CommMergeAccessor() {} virtual ~CommMergeAccessor() {}
virtual int Initialize(); virtual int Initialize();
virtual void SetTableInfo(AccessorInfo &info); // 初始化AccessorInfo
virtual size_t GetTableInfo(InfoKey key); virtual void InitAccessorInfo();
// value维度
// pull value维度
size_t SelectDim();
// pull value各个维度的size
size_t SelectDimSize(size_t dim);
// pull value各维度相加总size
size_t SelectSize();
// push value维度
size_t UpdateDim();
// push value各个维度的size
size_t UpdateDimSize(size_t dim);
// push value各维度相加总size
size_t UpdateSize();
size_t fea_dim() { return _config.fea_dim(); }
// 判断该value是否进行shrink // 判断该value是否进行shrink
virtual bool Shrink(float * /*value*/); virtual bool Shrink(float * /*value*/);
// 判断该value是否在save阶段dump, // 判断该value是否在save阶段dump,
......
...@@ -75,8 +75,8 @@ TEST(downpour_feature_value_accessor_test, test_shrink) { ...@@ -75,8 +75,8 @@ TEST(downpour_feature_value_accessor_test, test_shrink) {
<< acc->common_feature_value.embedx_sgd_dim << " " << acc->common_feature_value.embedx_sgd_dim << " "
<< acc->common_feature_value.Dim() << "\n"; << acc->common_feature_value.Dim() << "\n";
float* value = new float[acc->Dim()]; float* value = new float[acc->GetAccessorInfo().dim];
for (auto i = 0u; i < acc->Dim(); ++i) { for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
value[i] = i * 1.0; value[i] = i * 1.0;
} }
ASSERT_TRUE(!acc->Shrink(value)); ASSERT_TRUE(!acc->Shrink(value));
...@@ -94,8 +94,8 @@ TEST(downpour_feature_value_accessor_test, test_save) { ...@@ -94,8 +94,8 @@ TEST(downpour_feature_value_accessor_test, test_save) {
ASSERT_EQ(acc->Configure(parameter), 0); ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
float* value = new float[acc->Dim()]; float* value = new float[acc->GetAccessorInfo().dim];
for (auto i = 0u; i < acc->Dim(); ++i) { for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
value[i] = i * 1.0; value[i] = i * 1.0;
} }
...@@ -109,7 +109,7 @@ TEST(downpour_feature_value_accessor_test, test_save) { ...@@ -109,7 +109,7 @@ TEST(downpour_feature_value_accessor_test, test_save) {
ASSERT_TRUE(acc->Save(value, 2)); ASSERT_TRUE(acc->Save(value, 2));
VLOG(3) << "test_save:"; VLOG(3) << "test_save:";
for (auto i = 0u; i < acc->Dim(); ++i) { for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
VLOG(3) << value[i]; VLOG(3) << value[i];
} }
} }
...@@ -145,7 +145,7 @@ TEST(downpour_feature_value_accessor_test, test_update) { ...@@ -145,7 +145,7 @@ TEST(downpour_feature_value_accessor_test, test_update) {
ASSERT_EQ(acc->Initialize(), 0); ASSERT_EQ(acc->Initialize(), 0);
VLOG(3) << "dim: " << acc->common_feature_value.Dim() << "\n"; VLOG(3) << "dim: " << acc->common_feature_value.Dim() << "\n";
VLOG(3) << "update_dim: " << acc->GetTableInfo(UPDATE_DIM) << "\n"; VLOG(3) << "update_dim: " << acc->GetAccessorInfo().update_dim << "\n";
const int field_size = 7 + 8; const int field_size = 7 + 8;
const int item_size = 10; const int item_size = 10;
...@@ -162,8 +162,8 @@ TEST(downpour_feature_value_accessor_test, test_update) { ...@@ -162,8 +162,8 @@ TEST(downpour_feature_value_accessor_test, test_update) {
typedef const float* const_float_ptr; typedef const float* const_float_ptr;
const_float_ptr* grad = new const_float_ptr[item_size]; const_float_ptr* grad = new const_float_ptr[item_size];
for (auto i = 0u; i < item_size; ++i) { for (auto i = 0u; i < item_size; ++i) {
float* p = new float[acc->GetTableInfo(UPDATE_DIM)]; float* p = new float[acc->GetAccessorInfo().update_dim];
for (auto j = 0u; j < acc->GetTableInfo(UPDATE_DIM); ++j) { for (auto j = 0u; j < acc->GetAccessorInfo().update_dim; ++j) {
p[j] = i; p[j] = i;
} }
grad[i] = p; grad[i] = p;
...@@ -244,21 +244,21 @@ TEST(downpour_feature_value_accessor_test, test_update) { ...@@ -244,21 +244,21 @@ TEST(downpour_feature_value_accessor_test, test_update) {
v.unseen_days = 0; v.unseen_days = 0;
v.show += push_v.show; v.show += push_v.show;
v.click += push_v.click; v.click += push_v.click;
v.delta_score += acc->show_click_score(push_v.show, push_v.click); v.delta_score += acc->ShowClickScore(push_v.show, push_v.click);
acc->_embed_sgd_rule->update_value(&v.embed_w, &v.embed_g2sum[0], acc->_embed_sgd_rule->UpdateValue(&v.embed_w, &v.embed_g2sum[0],
&push_v.embed_g); &push_v.embed_g);
acc->_embedx_sgd_rule->update_value(&v.embedx_w[0], &v.embedx_g2sum[0], acc->_embedx_sgd_rule->UpdateValue(&v.embedx_w[0], &v.embedx_g2sum[0],
&push_v.embedx_g[0]); &push_v.embedx_g[0]);
float* ptr = new float[acc->Dim()]; float* ptr = new float[acc->GetAccessorInfo().dim];
v.to_array(ptr, parameter.embedx_dim()); v.to_array(ptr, parameter.embedx_dim());
exp_value.push_back(ptr); exp_value.push_back(ptr);
} }
acc->Update(value, grad, item_size); acc->Update(value, grad, item_size);
for (auto i = 0u; i < item_size; ++i) { for (auto i = 0u; i < item_size; ++i) {
for (auto j = 0u; j < acc->Dim(); ++j) { for (auto j = 0u; j < acc->GetAccessorInfo().dim; ++j) {
VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " "; VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " ";
ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]); ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]);
} }
...@@ -273,7 +273,7 @@ TEST(downpour_feature_value_accessor_test, test_show_click_score) { ...@@ -273,7 +273,7 @@ TEST(downpour_feature_value_accessor_test, test_show_click_score) {
float show = 10; float show = 10;
float click = 6; float click = 6;
ASSERT_FLOAT_EQ(acc->show_click_score(show, click), 6.8); ASSERT_FLOAT_EQ(acc->ShowClickScore(show, click), 6.8);
} }
TEST(downpour_feature_value_accessor_test, test_string_related) { TEST(downpour_feature_value_accessor_test, test_string_related) {
......
...@@ -31,22 +31,22 @@ TEST(sparse_value_naive_sgd_test, init_and_update) { ...@@ -31,22 +31,22 @@ TEST(sparse_value_naive_sgd_test, init_and_update) {
naive_param->add_weight_bounds(-10.0); naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0); naive_param->add_weight_bounds(10.0);
rule.load_config(param, 10); rule.LoadConfig(param, 10);
// check init_value for zero // check init_value for zero
const int kItemSize = 10; const int kItemSize = 10;
float w[kItemSize]; float w[kItemSize];
float grad[kItemSize]; float grad[kItemSize];
rule.init_value(w, w + 9, true); rule.InitValue(w, w + 9, true);
for (auto i = 0u; i < kItemSize; ++i) { for (auto i = 0u; i < kItemSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0); ASSERT_FLOAT_EQ(w[i], 0);
} }
// check init_value for random // check init_value for random
rule.init_value(w, w + 9, false); rule.InitValue(w, w + 9, false);
for (auto i = 0u; i < kItemSize; ++i) { for (auto i = 0u; i < kItemSize; ++i) {
ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound()); ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound());
} }
// check update_value for one field // check update_value for one field
...@@ -59,7 +59,7 @@ TEST(sparse_value_naive_sgd_test, init_and_update) { ...@@ -59,7 +59,7 @@ TEST(sparse_value_naive_sgd_test, init_and_update) {
float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, -0.500000, float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, -0.500000,
-0.600000, -0.700000, -0.800000, -0.900000, -1.000000}; -0.600000, -0.700000, -0.800000, -0.900000, -1.000000};
const float* ptr_grad = grad; const float* ptr_grad = grad;
rule.update_value(w, w + 9, ptr_grad); rule.UpdateValue(w, w + 9, ptr_grad);
for (auto i = 0u; i < kItemSize; ++i) { for (auto i = 0u; i < kItemSize; ++i) {
VLOG(3) << w[i] << "\n"; VLOG(3) << w[i] << "\n";
...@@ -78,14 +78,14 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { ...@@ -78,14 +78,14 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
adagrad_param->add_weight_bounds(-10.0); adagrad_param->add_weight_bounds(-10.0);
adagrad_param->add_weight_bounds(10.0); adagrad_param->add_weight_bounds(10.0);
rule.load_config(param, 10); rule.LoadConfig(param, 10);
// check init_value for zero // check init_value for zero
const int kValueSize = 11; const int kValueSize = 11;
int kEmbSize = 10; int kEmbSize = 10;
float w[kValueSize]; float w[kValueSize];
rule.init_value(w, w + 10, true); rule.InitValue(w, w + 10, true);
for (auto i = 0u; i < kEmbSize; ++i) { for (auto i = 0u; i < kEmbSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0); ASSERT_FLOAT_EQ(w[i], 0);
...@@ -93,9 +93,9 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { ...@@ -93,9 +93,9 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
ASSERT_FLOAT_EQ(w[kEmbSize], 0); ASSERT_FLOAT_EQ(w[kEmbSize], 0);
// check init_value for random // check init_value for random
rule.init_value(w, w + 10, false); rule.InitValue(w, w + 10, false);
for (auto i = 0u; i < kEmbSize; ++i) { for (auto i = 0u; i < kEmbSize; ++i) {
ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound()); ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound());
} }
ASSERT_FLOAT_EQ(w[kEmbSize], 0); ASSERT_FLOAT_EQ(w[kEmbSize], 0);
...@@ -110,7 +110,7 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { ...@@ -110,7 +110,7 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
} }
const float* ptr_grad = grad; const float* ptr_grad = grad;
rule.update_value(w, w + 10, ptr_grad); rule.UpdateValue(w, w + 10, ptr_grad);
float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, float label[] = {-0.100000, -0.200000, -0.300000, -0.400000,
-0.500000, -0.600000, -0.700000, -0.800000, -0.500000, -0.600000, -0.700000, -0.800000,
-0.900000, -1.000000, 38.500000}; -0.900000, -1.000000, 38.500000};
...@@ -140,33 +140,33 @@ TEST(downpour_sparse_adam_test, test_init_and_update) { ...@@ -140,33 +140,33 @@ TEST(downpour_sparse_adam_test, test_init_and_update) {
SparseAdamSGDRule rule; SparseAdamSGDRule rule;
rule.load_config(param, embed_dim); rule.LoadConfig(param, embed_dim);
// check init_value for zero // check init_value for zero
const int rule_dim = const int rule_dim =
rule.dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam rule.Dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam
const int value_dim = embed_dim + rule_dim; // total dims of w + rule const int value_dim = embed_dim + rule_dim; // total dims of w + rule
float* value = new float[value_dim]; float* value = new float[value_dim];
rule.init_value(value, value + embed_dim, true); rule.InitValue(value, value + embed_dim, true);
for (auto i = 0u; i < rule.beta1_pow_index(); ++i) { for (auto i = 0u; i < rule.Beta1PowIndex(); ++i) {
ASSERT_FLOAT_EQ(value[i], 0); ASSERT_FLOAT_EQ(value[i], 0);
} }
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9); ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta1PowIndex()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999); ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta2PowIndex()), 0.999);
// check init_value for random // check init_value for random
rule.init_value(value, value + embed_dim, false); rule.InitValue(value, value + embed_dim, false);
for (auto i = 0u; i < embed_dim; ++i) { for (auto i = 0u; i < embed_dim; ++i) {
ASSERT_TRUE(value[i] >= rule.min_bound() && value[i] <= rule.max_bound()); ASSERT_TRUE(value[i] >= rule.MinBound() && value[i] <= rule.MaxBound());
} }
for (auto i = rule.gsum_index(); i < rule.beta1_pow_index(); ++i) { for (auto i = rule.GSumIndex(); i < rule.Beta1PowIndex(); ++i) {
ASSERT_FLOAT_EQ(value[i + embed_dim], 0); ASSERT_FLOAT_EQ(value[i + embed_dim], 0);
} }
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9); ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta1PowIndex()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999); ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta2PowIndex()), 0.999);
// check update_value // check update_value
rule.init_value(value, value + embed_dim, true); rule.InitValue(value, value + embed_dim, true);
float* grad = new float[embed_dim]; float* grad = new float[embed_dim];
for (auto i = 0u; i < embed_dim; ++i) { for (auto i = 0u; i < embed_dim; ++i) {
grad[i] = (i + 1) * 1.0; grad[i] = (i + 1) * 1.0;
...@@ -181,7 +181,7 @@ TEST(downpour_sparse_adam_test, test_init_and_update) { ...@@ -181,7 +181,7 @@ TEST(downpour_sparse_adam_test, test_init_and_update) {
0.0249996781, 0.0359995365, 0.0489993691, 0.063999176, 0.0249996781, 0.0359995365, 0.0489993691, 0.063999176,
0.0809989572, 0.0999987125, 0.809999943, 0.998001039}; 0.0809989572, 0.0999987125, 0.809999943, 0.998001039};
rule.update_value(value, value + embed_dim, grad); rule.UpdateValue(value, value + embed_dim, grad);
for (auto i = 0u; i < value_dim; ++i) { // check update for (auto i = 0u; i < value_dim; ++i) { // check update
ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i; ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i;
......
...@@ -1668,7 +1668,7 @@ class Fleet(object): ...@@ -1668,7 +1668,7 @@ class Fleet(object):
opt_info["mpi_rank"] = self.worker_index() opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items( for k, v in self._user_defined_strategy.trainer_desc_configs.items(
): ):
if v: if v or k not in opt_info:
opt_info[k] = v opt_info[k] = v
program._fleet_opt = opt_info program._fleet_opt = opt_info
...@@ -1745,7 +1745,7 @@ class Fleet(object): ...@@ -1745,7 +1745,7 @@ class Fleet(object):
opt_info["mpi_rank"] = self.worker_index() opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items( for k, v in self._user_defined_strategy.trainer_desc_configs.items(
): ):
if v: if v or k not in opt_info:
opt_info[k] = v opt_info[k] = v
program._fleet_opt = opt_info program._fleet_opt = opt_info
# print("fleet base opt info:", id(program), program._fleet_opt) # print("fleet base opt info:", id(program), program._fleet_opt)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册