未验证 提交 6d4d774d 编写于 作者: Y yaoxuefeng 提交者: GitHub

add downpour_ctr_accessor (#39341)

add downpour_ctr_accessor
上级 6097aefb
...@@ -43,10 +43,12 @@ set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPI ...@@ -43,10 +43,12 @@ set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPI
# Per-source compile flags and library targets for the sparse-table components.
# Each command appears exactly once (the side-by-side diff view had repeated the
# unchanged commands on the same line).
set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(downpour_ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_double_accessor SRCS ctr_double_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(downpour_ctr_accessor SRCS downpour_ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -153,6 +153,8 @@ class ValueAccessor { ...@@ -153,6 +153,8 @@ class ValueAccessor {
} }
// Returns the value of the field called `name` for `value`. This default
// implementation ignores both arguments and always yields 0.0.
virtual float get_field(float* value, const std::string& name) { return 0.0; }
// Generates an override named get_<field>_index() that forwards to the static
// <class>::<field>_index() function.
#define DEFINE_GET_INDEX(class, field) \
virtual int get_##field##_index() override { return class ::field##_index(); }
protected:
  // NOTE(review): the code that reads or writes _value_size is outside this excerpt.
  size_t _value_size;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Creates the two SGD rules named in the config, loads their settings (the
// embed rule with 1, the embedx rule with embedx_dim()), caches the decay and
// SSD thresholds, and builds the per-day decay table. Always returns 0.
int DownpourCtrAccessor::initialize() {
  auto embed_rule_name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embed_rule_name);
  _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);

  auto embedx_rule_name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embedx_rule_name);
  _embedx_sgd_rule->load_config(_config.embedx_sgd_param(),
                                _config.embedx_dim());

  const auto& ctr_param = _config.ctr_accessor_param();
  _show_click_decay_rate = ctr_param.show_click_decay_rate();
  _ssd_unseenday_threshold = ctr_param.ssd_unseenday_threshold();

  // Must run after _show_click_decay_rate is set: the table is built from it.
  set_time_decay_rates();
  return 0;
}
// Number of float fields in one stored feature value for the configured
// embedx dimension.
size_t DownpourCtrAccessor::dim() {
  return DownpourCtrFeatureValue::dim(_config.embedx_dim());
}
// Size of field `dim` of a stored feature value, as reported by
// DownpourCtrFeatureValue::dim_size for the configured embedx dimension.
size_t DownpourCtrAccessor::dim_size(size_t dim) {
  return DownpourCtrFeatureValue::dim_size(dim, _config.embedx_dim());
}
// Total size of one stored feature value for the configured embedx dimension.
size_t DownpourCtrAccessor::size() {
  return DownpourCtrFeatureValue::size(_config.embedx_dim());
}
// Byte size of the mf part of a value: the embedx weights plus one extra float
// for embedx_g2sum.
size_t DownpourCtrAccessor::mf_size() {
  const auto mf_field_count = _config.embedx_dim() + 1;
  return mf_field_count * sizeof(float);
}
// pull value
// Number of floats in a pull value: three leading fields followed by the
// embedx weights.
size_t DownpourCtrAccessor::select_dim() {
  return 3 + _config.embedx_dim();
}
// Byte size of field `dim` of a pull value; every field is one float, so the
// argument does not affect the result.
size_t DownpourCtrAccessor::select_dim_size(size_t dim) {
  const size_t field_bytes = sizeof(float);
  return field_bytes;
}
// Total byte size of a pull value: select_dim() floats.
size_t DownpourCtrAccessor::select_size() {
  const size_t field_count = select_dim();
  return field_count * sizeof(float);
}
// push value
// Number of floats in a push value: four leading fields followed by the
// embedx gradients.
size_t DownpourCtrAccessor::update_dim() {
  return 4 + _config.embedx_dim();
}
// Byte size of field `dim` of a push value; every field is one float, so the
// argument does not affect the result.
size_t DownpourCtrAccessor::update_dim_size(size_t dim) {
  const size_t field_bytes = sizeof(float);
  return field_bytes;
}
// Total byte size of a push value: update_dim() floats.
size_t DownpourCtrAccessor::update_size() {
  const size_t field_count = update_dim();
  return field_count * sizeof(float);
}
// Decides whether `value` should be removed from the table.
// Returns true when the day gap (_day_id minus the stored unseen_days field)
// is negative or larger than delete_after_unseen_days, or when the
// time-decayed show/click score is below delete_threshold; false otherwise.
bool DownpourCtrAccessor::shrink(float* value) {
  const auto& ctr_param = _config.ctr_accessor_param();
  const auto delete_after_unseen_days = ctr_param.delete_after_unseen_days();
  const auto delete_threshold = ctr_param.delete_threshold();

  // Time decay first. The range check also keeps the table lookup below within
  // the delete_after_unseen_days + 1 entries built by set_time_decay_rates().
  auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
  int16_t day_diff = _day_id - unseen_days;
  if (day_diff < 0 || day_diff > delete_after_unseen_days) {
    return true;
  }

  auto show_right =
      DownpourCtrFeatureValue::show(value) * _time_decay_rates[day_diff];
  auto click_right =
      DownpourCtrFeatureValue::click(value) * _time_decay_rates[day_diff];

  // Then shrink by score.
  auto score = show_click_score(show_right, click_right);
  return score < delete_threshold;
}
// Stores the day id that the decay, save and shrink logic measures against.
void DownpourCtrAccessor::set_day_id(int day_id) {
  _day_id = day_id;
}
// Returns the day id last stored by set_day_id().
int DownpourCtrAccessor::get_day_id() {
  return _day_id;
}
// Decides whether `value` should be saved to SSD.
// Returns true when _day_id is 0, false when the stored unseen_days field is
// 0, and otherwise true only if the day gap exceeds _ssd_unseenday_threshold.
bool DownpourCtrAccessor::save_ssd(float* value) {
  if (_day_id == 0) {
    return true;
  }

  auto last_seen = DownpourCtrFeatureValue::unseen_days(value);
  if (last_seen == 0) {
    return false;
  }

  // For the origin load (e.g. unseen_days = 0-15): a stored value below
  // delta_keep_days is converted by subtracting it from _day_id.
  if (last_seen < _config.ctr_accessor_param().delta_keep_days()) {
    last_seen = _day_id - last_seen;
  }

  const int16_t day_gap = _day_id - last_seen;
  return day_gap > _ssd_unseenday_threshold;
}
// bool DownpourCtrAccessor::save_cache(
// float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
// int16_t day_diff = _day_id - unseen_days;
// if (show_click_score(DownpourCtrFeatureValue::show(value),
// DownpourCtrFeatureValue::click(value)) >= base_threshold
// && day_diff <= delta_keep_days) {
// return DownpourCtrFeatureValue::show(value) > global_cache_threshold;
// }
// return false;
// }
// Decides whether `value` is dumped during a save pass.
//   param == 1 (xbox delta) / param == 2 (xbox base): dump only when the
//     time-decayed show/click score reaches base_threshold, delta_score
//     reaches delta_threshold (forced to 0 for param 2) and the day gap is at
//     most delta_keep_days.
//   any other param (including 0 and 3): always dump.
// Side effect: for param == 2 a dumped value has its delta_score reset to 0.
bool DownpourCtrAccessor::save(float* value, int param) {
  if (param != 1 && param != 2) {
    return true;
  }

  const auto& ctr_param = _config.ctr_accessor_param();
  auto base_threshold = ctr_param.base_threshold();
  auto delta_threshold = ctr_param.delta_threshold();
  auto delta_keep_days = ctr_param.delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }

  auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
  int16_t day_diff = _day_id - unseen_days;

  // Clamp the decay-table index. Unlike shrink(), nothing here rejects a
  // negative or too-large day gap before the lookup, so an unclamped index
  // would read outside _time_decay_rates.
  double decay = 1.0;
  if (!_time_decay_rates.empty()) {
    const int last = static_cast<int>(_time_decay_rates.size()) - 1;
    const int idx = day_diff < 0 ? 0 : (day_diff > last ? last : day_diff);
    decay = _time_decay_rates[idx];
  }
  auto show_right = DownpourCtrFeatureValue::show(value) * decay;
  auto click_right = DownpourCtrFeatureValue::click(value) * decay;

  if (show_click_score(show_right, click_right) >= base_threshold &&
      DownpourCtrFeatureValue::delta_score(value) >= delta_threshold &&
      day_diff <= delta_keep_days) {
    // The base save clears delta_score here. The delta save (param 1) leaves
    // it untouched; update_stat_after_save() clears it for that case.
    if (param == 2) {
      DownpourCtrFeatureValue::delta_score(value) = 0;
    }
    return true;
  }
  return false;
}
// Post-save bookkeeping. Only param == 1 (delta save) does anything: when
// `value` meets the same three conditions save() uses for the delta dump, its
// delta_score is reset to 0. Every other param is a no-op.
void DownpourCtrAccessor::update_stat_after_save(float* value, int param) {
  if (param != 1) {
    return;
  }

  const auto& ctr_param = _config.ctr_accessor_param();
  auto base_threshold = ctr_param.base_threshold();
  auto delta_threshold = ctr_param.delta_threshold();
  auto delta_keep_days = ctr_param.delta_keep_days();

  auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
  int16_t day_diff = _day_id - unseen_days;

  // Clamp the decay-table index so a negative or too-large day gap cannot
  // read outside _time_decay_rates.
  double decay = 1.0;
  if (!_time_decay_rates.empty()) {
    const int last = static_cast<int>(_time_decay_rates.size()) - 1;
    const int idx = day_diff < 0 ? 0 : (day_diff > last ? last : day_diff);
    decay = _time_decay_rates[idx];
  }
  auto show_right = DownpourCtrFeatureValue::show(value) * decay;
  auto click_right = DownpourCtrFeatureValue::click(value) * decay;

  if (show_click_score(show_right, click_right) >= base_threshold &&
      DownpourCtrFeatureValue::delta_score(value) >= delta_threshold &&
      day_diff <= delta_keep_days) {
    DownpourCtrFeatureValue::delta_score(value) = 0;
  }
}
// Initializes `num` caller-allocated feature values in place: the statistic
// fields are zeroed, slot is set to -1, and the embed / embedx weights and
// g2sum fields are initialized by their SGD rules. Always returns 0.
int32_t DownpourCtrAccessor::create(float** values, size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
    value[DownpourCtrFeatureValue::unseen_days_index()] = 0;
    value[DownpourCtrFeatureValue::delta_score_index()] = 0;
    value[DownpourCtrFeatureValue::show_index()] = 0;
    value[DownpourCtrFeatureValue::click_index()] = 0;
    value[DownpourCtrFeatureValue::slot_index()] = -1;
    _embed_sgd_rule->init_value(
        value + DownpourCtrFeatureValue::embed_w_index(),
        value + DownpourCtrFeatureValue::embed_g2sum_index(), true);
    _embedx_sgd_rule->init_value(
        value + DownpourCtrFeatureValue::embedx_w_index(),
        value + DownpourCtrFeatureValue::embedx_g2sum_index());
  }
  return 0;
}
// Returns true when the raw (undecayed) show/click score of `value` reaches
// the configured embedx_threshold.
bool DownpourCtrAccessor::need_extend_mf(float* value) {
  const float show = value[DownpourCtrFeatureValue::show_index()];
  const float click = value[DownpourCtrFeatureValue::click_index()];
  const auto& ctr_param = _config.ctr_accessor_param();
  const float score = (show - click) * ctr_param.nonclk_coeff() +
                      click * ctr_param.click_coeff();
  return score >= _config.embedx_threshold();
}
// Returns true when `size` is greater than the index of the embedx_g2sum
// field. NOTE(review): `size` is presumably a float count — confirm at callers.
bool DownpourCtrAccessor::has_mf(size_t size) {
  const size_t mf_start =
      static_cast<size_t>(DownpourCtrFeatureValue::embedx_g2sum_index());
  return size > mf_start;
}
// from DownpourCtrFeatureValue to DownpourCtrPullValue
// Converts stored feature values into pull values: copies show, click and
// embed_w, then the embedx_dim embedx weights, for each of the `num` items.
// Always returns 0.
int32_t DownpourCtrAccessor::select(float** select_values, const float** values,
                                    size_t num) {
  const size_t embedx_bytes = _config.embedx_dim() * sizeof(float);
  for (size_t item = 0; item < num; ++item) {
    const float* stored = values[item];
    float* pulled = select_values[item];
    pulled[DownpourCtrPullValue::show_index()] =
        stored[DownpourCtrFeatureValue::show_index()];
    pulled[DownpourCtrPullValue::click_index()] =
        stored[DownpourCtrFeatureValue::click_index()];
    pulled[DownpourCtrPullValue::embed_w_index()] =
        stored[DownpourCtrFeatureValue::embed_w_index()];
    memcpy(pulled + DownpourCtrPullValue::embedx_w_index(),
           stored + DownpourCtrFeatureValue::embedx_w_index(), embedx_bytes);
  }
  return 0;
}
// from DownpourCtrPushValue to DownpourCtrPushValue
// first dim: item
// second dim: field num
// Accumulates push values: for each of the `num` items, every field of
// other_update_values except the slot field is added into update_values.
// Always returns 0.
int32_t DownpourCtrAccessor::merge(float** update_values,
                                   const float** other_update_values,
                                   size_t num) {
  const size_t field_count = DownpourCtrPushValue::dim(_config.embedx_dim());
  const size_t slot_pos = DownpourCtrPushValue::slot_index();
  for (size_t item = 0; item < num; ++item) {
    float* dst = update_values[item];
    const float* src = other_update_values[item];
    for (size_t field = 0; field < field_count; ++field) {
      if (field == slot_pos) {
        continue;  // the slot id is kept, not summed
      }
      dst[field] += src[field];
    }
  }
  return 0;
}
// from DownpourCtrPushValue to DownpourCtrFeatureValue
// first dim: item
// second dim: field num
// Applies push values to stored feature values. For each of the `num` items:
// show and click are accumulated, slot is overwritten, delta_score grows by
// the show/click score of the push, unseen_days is reset to 0, and both SGD
// rules update their weights from the pushed gradients. Always returns 0.
int32_t DownpourCtrAccessor::update(float** update_values,
                                    const float** push_values, size_t num) {
  // Loop-invariant coefficients, read once.
  const auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
  const auto click_coeff = _config.ctr_accessor_param().click_coeff();

  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
    float push_show = push_value[DownpourCtrPushValue::show_index()];
    float push_click = push_value[DownpourCtrPushValue::click_index()];
    float slot = push_value[DownpourCtrPushValue::slot_index()];

    update_value[DownpourCtrFeatureValue::show_index()] += push_show;
    update_value[DownpourCtrFeatureValue::click_index()] += push_click;
    update_value[DownpourCtrFeatureValue::slot_index()] = slot;
    update_value[DownpourCtrFeatureValue::delta_score_index()] +=
        (push_show - push_click) * nonclk_coeff + push_click * click_coeff;
    update_value[DownpourCtrFeatureValue::unseen_days_index()] = 0;

    _embed_sgd_rule->update_value(
        update_value + DownpourCtrFeatureValue::embed_w_index(),
        update_value + DownpourCtrFeatureValue::embed_g2sum_index(),
        push_value + DownpourCtrPushValue::embed_g_index(), push_show);
    _embedx_sgd_rule->update_value(
        update_value + DownpourCtrFeatureValue::embedx_w_index(),
        update_value + DownpourCtrFeatureValue::embedx_g2sum_index(),
        push_value + DownpourCtrPushValue::embedx_g_index(), push_show);
  }
  return 0;
}
// Decides whether a missing key should be created.
// stage 0 (pull) and any stage other than 1 always create. stage 1 (push)
// uses the show/click score of the push value: never for score <= 0, always
// for score >= 1, otherwise when a random draw is below the score.
bool DownpourCtrAccessor::create_value(int stage, const float* value) {
  if (stage != 1) {
    return true;
  }

  const float show = value[DownpourCtrPushValue::show_index()];
  const float click = value[DownpourCtrPushValue::click_index()];
  const float score = show_click_score(show, click);
  if (score <= 0) {
    return false;
  }
  if (score >= 1) {
    return true;
  }
  // NOTE(review): presumably a uniform draw in [0, 1), which would make the
  // score the creation probability — confirm against the helper's definition.
  return local_uniform_real_distribution<float>()(local_random_engine()) <
         score;
}
// Weighted score of a show/click pair:
// (show - click) * nonclk_coeff + click * click_coeff.
float DownpourCtrAccessor::show_click_score(float show, float click) {
  const auto& ctr_param = _config.ctr_accessor_param();
  return (show - click) * ctr_param.nonclk_coeff() +
         click * ctr_param.click_coeff();
}
// Serializes a feature value as space-separated floats. The first seven
// fields are always written. embedx_g2sum and the embedx weights follow only
// when the show/click score reaches embedx_threshold and param_size > 7.
std::string DownpourCtrAccessor::parse_to_string(const float* v,
                                                 int param_size) {
  // Per-thread stream, reset on every call.
  thread_local std::ostringstream os;
  os.clear();
  os.str("");

  const int g2sum_pos = DownpourCtrFeatureValue::embedx_g2sum_index();
  os << v[0];
  for (int field = 1; field < g2sum_pos; ++field) {
    os << " " << v[field];
  }

  const float show = v[DownpourCtrFeatureValue::show_index()];
  const float click = v[DownpourCtrFeatureValue::click_index()];
  const float score = show_click_score(show, click);
  if (score >= _config.embedx_threshold() && param_size > 7) {
    os << " " << v[g2sum_pos];
    const float* embedx_w = v + DownpourCtrFeatureValue::embedx_w_index();
    for (auto i = 0; i < _config.embedx_dim(); ++i) {
      os << " " << embedx_w[i];
    }
  }
  return os.str();
}
// Parses a space-separated float record from `str` into `value` and returns
// the number of fields now present in `value`.
// Two layouts are handled:
//   * dim() - 1 floats: the record has no slot field. The fields before slot
//     are copied, slot is set to -1, and embedx_g2sum plus the embedx weights
//     are copied one position later.
//   * anything else: the parsed floats are copied verbatim.
// The count is increased by one when the slot was filled in here (dim() - 1
// floats, or exactly 6 floats). Aborts via CHECK on fewer than 6 floats.
int DownpourCtrAccessor::parse_from_string(const std::string& str,
                                           float* value) {
  const int embedx_dim = _config.embedx_dim();
  const int value_dim = dim();
  const int embedx_g2sum_index = DownpourCtrFeatureValue::embedx_g2sum_index();

  // Per-thread scratch buffer reused across calls. It replaces a
  // variable-length array, which is a compiler extension, not standard C++.
  // NOTE(review): str_to_float is given no capacity, so a record with more
  // than dim() floats would presumably still overrun this buffer — confirm.
  thread_local std::vector<float> data_buff;
  data_buff.resize(value_dim);
  float* data_buff_ptr = data_buff.data();

  _embedx_sgd_rule->init_value(
      data_buff_ptr + DownpourCtrFeatureValue::embedx_w_index(),
      data_buff_ptr + DownpourCtrFeatureValue::embedx_g2sum_index());
  auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
  CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;

  // Default for records with no slot; the verbatim copy below overwrites it
  // when the record is long enough to carry one.
  value[DownpourCtrFeatureValue::slot_index()] = -1;
  if (str_len == (value_dim - 1)) {
    // No-slot layout: skip over the slot position in `value`.
    memcpy(value, data_buff_ptr, (embedx_g2sum_index - 1) * sizeof(float));
    memcpy(value + embedx_g2sum_index, data_buff_ptr + embedx_g2sum_index - 1,
           (embedx_dim + 1) * sizeof(float));
  } else {
    memcpy(value, data_buff_ptr, str_len * sizeof(float));
  }
  if (str_len == (value_dim - 1) || str_len == 6) {
    str_len += 1;
  }
  return str_len;
}
// Sizes _time_decay_rates from delete_after_unseen_days and fills it so that
// entry i holds _show_click_decay_rate raised to the power i.
void DownpourCtrAccessor::set_time_decay_rates() {
  auto max_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  _time_decay_rates.assign(max_unseen_days + 1, 0.0);
  int day = 0;
  while (day <= max_unseen_days) {
    _time_decay_rates[day] = pow(_show_click_decay_rate, day);
    ++day;
  }
}
// Applies time decay to the show/click fields of `value` and refreshes its
// stored unseen_days field, both driven by _day_id. The unseen_days field
// holds the day id on which the feature was last seen.
// `is_update_seen_day` selects whether the stored day is moved to _day_id.
// Does nothing when _day_id is 0.
void DownpourCtrAccessor::update_time_decay(float* value,
bool is_update_seen_day) {
// Decay show/click and update unseen_day according to day_id; unseen_day
// is the day id of the last appearance.
if (_day_id == 0) {
return;
}
auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
// A stored 0 is simply stamped with the current day; no decay is applied.
if (unseen_days == 0) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
return;
}
// for the origin load (unseenday = 0 -15)
if (unseen_days < _config.ctr_accessor_param().delete_after_unseen_days()) {
// pull
if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
return;
// save: discard the original unseen_day and treat the feature as last seen
// on the previous day, so show/click are not decayed repeatedly.
} else {
DownpourCtrFeatureValue::unseen_days(value) = _day_id - 1;
}
}
// NOTE(review): day_diff uses the local copy read at the top of the function,
// not the field rewritten in the save branch above — confirm this is intended.
int16_t day_diff = _day_id - unseen_days;
// A stored day later than _day_id is reset to _day_id without decay.
if (day_diff < 0) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
return;
}
// Gaps at or beyond delete_after_unseen_days are left untouched; this also
// keeps the table lookup below in range.
if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) {
return;
}
DownpourCtrFeatureValue::show(value) *= _time_decay_rates[day_diff];
DownpourCtrFeatureValue::click(value) *= _time_decay_rates[day_diff];
if (is_update_seen_day) {
DownpourCtrFeatureValue::unseen_days(value) = _day_id;
}
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
/**
* @brief Accessor for unit
**/
class DownpourCtrAccessor : public ValueAccessor {
public:
// Static helpers describing the float layout of one stored feature value:
//   [0] unseen_days   [1] delta_score   [2] show          [3] click
//   [4] embed_w       [5] embed_g2sum   [6] slot          [7] embedx_g2sum
//   [8 ...] embedx_w (embedx_dim floats)
struct DownpourCtrFeatureValue {
  // Field count, per-field size and total byte size for a given embedx_dim.
  static int dim(int embedx_dim) { return 8 + embedx_dim; }
  static int dim_size(size_t dim, int embedx_dim) { return sizeof(float); }
  static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }

  // Field positions; each one follows the field before it.
  static int unseen_days_index() { return 0; }
  static int delta_score_index() { return unseen_days_index() + 1; }
  static int show_index() { return delta_score_index() + 1; }
  static int click_index() { return show_index() + 1; }
  static int embed_w_index() { return click_index() + 1; }
  static int embed_g2sum_index() { return embed_w_index() + 1; }
  static int slot_index() { return embed_g2sum_index() + 1; }
  static int embedx_g2sum_index() { return slot_index() + 1; }
  static int embedx_w_index() { return embedx_g2sum_index() + 1; }

  // Mutable references into a value buffer laid out as above.
  static float& unseen_days(float* val) { return val[unseen_days_index()]; }
  static float& delta_score(float* val) { return val[delta_score_index()]; }
  static float& show(float* val) { return val[show_index()]; }
  static float& click(float* val) { return val[click_index()]; }
  static float& slot(float* val) { return val[slot_index()]; }
  static float& embed_w(float* val) { return val[embed_w_index()]; }
  static float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
  static float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
  // Pointer to the first embedx weight.
  static float* embedx_w(float* val) { return val + embedx_w_index(); }
};
// Static helpers describing the float layout of one push (update) value:
//   [0] slot   [1] show   [2] click   [3] embed_g
//   [4 ...] embedx_g (embedx_dim floats)
struct DownpourCtrPushValue {
  // Field count, per-field size and total byte size for a given embedx_dim.
  static int dim(int embedx_dim) { return 4 + embedx_dim; }
  static int dim_size(int dim, int embedx_dim) { return sizeof(float); }
  static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }

  // Field positions; each one follows the field before it.
  static int slot_index() { return 0; }
  static int show_index() { return slot_index() + 1; }
  static int click_index() { return show_index() + 1; }
  static int embed_g_index() { return click_index() + 1; }
  static int embedx_g_index() { return embed_g_index() + 1; }

  // Mutable references into a push buffer laid out as above.
  static float& slot(float* val) { return val[slot_index()]; }
  static float& show(float* val) { return val[show_index()]; }
  static float& click(float* val) { return val[click_index()]; }
  static float& embed_g(float* val) { return val[embed_g_index()]; }
  // Pointer to the first embedx gradient.
  static float* embedx_g(float* val) { return val + embedx_g_index(); }
};
// Static helpers describing the float layout of one pull (select) value:
//   [0] show   [1] click   [2] embed_w
//   [3 ...] embedx_w (embedx_dim floats)
struct DownpourCtrPullValue {
  // Field count, per-field size and total byte size for a given embedx_dim.
  static int dim(int embedx_dim) { return 3 + embedx_dim; }
  static int dim_size(size_t dim) { return sizeof(float); }
  static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }

  // Field positions; each one follows the field before it.
  static int show_index() { return 0; }
  static int click_index() { return show_index() + 1; }
  static int embed_w_index() { return click_index() + 1; }
  static int embedx_w_index() { return embed_w_index() + 1; }

  // Mutable references into a pull buffer laid out as above.
  static float& show(float* val) { return val[show_index()]; }
  static float& click(float* val) { return val[click_index()]; }
  static float& embed_w(float* val) { return val[embed_w_index()]; }
  // Pointer to the first embedx weight.
  static float* embedx_w(float* val) { return val + embedx_w_index(); }
};
// Construction, sizing queries and per-value hooks of the accessor. The
// constructor body is empty; initialize() is declared separately below.
DownpourCtrAccessor() {}
virtual ~DownpourCtrAccessor() {}
virtual int initialize();
// Number of dimensions (fields) in a value.
virtual size_t dim();
// Size of each dimension of a value.
virtual size_t dim_size(size_t dim);
// Total size of a value, summed over all dimensions.
virtual size_t size();
// Total size of the dynamic-length mf part of a value; applies to sparse tables.
virtual size_t mf_size();
// Number of dimensions in a pull value.
virtual size_t select_dim();
// Size of each dimension of a pull value.
virtual size_t select_dim_size(size_t dim);
// Total size of a pull value, summed over all dimensions.
virtual size_t select_size();
// Number of dimensions in a push value.
virtual size_t update_dim();
// Size of each dimension of a push value.
virtual size_t update_dim_size(size_t dim);
// Total size of a push value, summed over all dimensions.
virtual size_t update_size();
// Decides whether this value should be shrunk (removed).
virtual bool shrink(float* value);
// Decides whether this value should be saved to SSD.
virtual bool save_ssd(float* value);
virtual bool need_extend_mf(float* value);
virtual bool has_mf(size_t size);
// Decides whether this value is dumped during the save stage.
// param identifies the save stage, e.g. downpour's xbox and batch_model.
// param = 0, save all feature
// param = 1, save delta feature
// param = 3, save all feature with time decay
virtual bool save(float* value, int param) override;
// update delta_score and unseen_days after save
virtual void update_stat_after_save(float* value, int param) override;
// virtual bool save_cache(float* value, int param, double
// global_cache_threshold) override;
// When keys do not exist, generates initial values for them.
// The memory for each value must already be allocated by the caller.
virtual int32_t create(float** value, size_t num);
// Selects from values into select_values.
virtual int32_t select(float** select_values, const float** values,
size_t num);
// Aggregates update_values together.
virtual int32_t merge(float** update_values,
const float** other_update_values, size_t num);
// Aggregates update_values together; it.next decides whether to move on to
// the next key.
// virtual int32_t merge(float** update_values, iterator it);
// Applies update_values to values.
virtual int32_t update(float** values, const float** update_values,
size_t num);
virtual std::string parse_to_string(const float* value, int param) override;
virtual int32_t parse_from_string(const std::string& str, float* v) override;
virtual bool create_value(int type, const float* value);
// Returns the time-decayed show count of `value`. This interface is currently
// only used to fetch "show"; any other name trips the CHECK.
virtual float get_field(float* value, const std::string& name) override {
  CHECK(name == "show");
  if (name != "show") {
    return 0.0;
  }
  auto unseen_days = DownpourCtrFeatureValue::unseen_days(value);
  int16_t day_diff = _day_id - unseen_days;
  // Clamp the decay-table index so a negative or too-large day gap cannot
  // read outside _time_decay_rates.
  double decay = 1.0;
  if (!_time_decay_rates.empty()) {
    const int last = static_cast<int>(_time_decay_rates.size()) - 1;
    const int idx = day_diff < 0 ? 0 : (day_diff > last ? last : day_diff);
    decay = _time_decay_rates[idx];
  }
  return static_cast<float>(DownpourCtrFeatureValue::show(value) * decay);
}
// DEFINE_GET_INDEX(DownpourCtrFeatureValue, show)
// DEFINE_GET_INDEX(DownpourCtrFeatureValue, click)
// DEFINE_GET_INDEX(DownpourCtrFeatureValue, embed_w)
// DEFINE_GET_INDEX(DownpourCtrFeatureValue, embedx_w)
// Decays show/click of `value` and refreshes its stored day according to the
// current day id.
virtual void update_time_decay(float* value, bool is_update_seen_day);
// Day id accessors.
virtual void set_day_id(int day_id);
virtual int get_day_id();
bool test_func() { return false; }
private:
// (show - click) * nonclk_coeff + click * click_coeff, from the config.
float show_click_score(float show, float click);
// Builds _time_decay_rates from the configured decay rate.
void set_time_decay_rates();
// All members get a defined default: the constructor body is empty and
// _day_id == 0 is tested as a sentinel before set_day_id() may have run.
// NOTE(review): the rule pointers are never deleted in this class — confirm
// who owns the objects they point to.
SparseValueSGDRule* _embed_sgd_rule = nullptr;
SparseValueSGDRule* _embedx_sgd_rule = nullptr;
float _show_click_decay_rate = 0.0f;
int32_t _ssd_unseenday_threshold = 0;
// Entry i is the decay factor applied for a gap of i days.
std::vector<double> _time_decay_rates;
int _day_id = 0;
};
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册