未验证 提交 cea1ba88 编写于 作者: Z zhaocaibei123 提交者: GitHub

add ctr accessor (#36601)

上级 19b02d95
...@@ -119,13 +119,11 @@ message TableParameter { ...@@ -119,13 +119,11 @@ message TableParameter {
message TableAccessorParameter { message TableAccessorParameter {
optional string accessor_class = 1; optional string accessor_class = 1;
// optional SparseSGDRuleParameter sparse_sgd_param = 2;
optional uint32 fea_dim = 4 [ default = 11 ]; optional uint32 fea_dim = 4 [ default = 11 ];
optional uint32 embedx_dim = 5 [ default = 8 ]; optional uint32 embedx_dim = 5 [ default = 8 ];
optional uint32 embedx_threshold = 6 [ default = 10 ]; optional uint32 embedx_threshold = 6 [ default = 10 ];
optional CtrAccessorParameter ctr_accessor_param = 7; optional CtrAccessorParameter ctr_accessor_param = 7;
repeated TableAccessorSaveParameter table_accessor_save_param = 8; repeated TableAccessorSaveParameter table_accessor_save_param = 8;
// optional SparseCommonSGDRuleParameter sparse_commonsgd_param = 9;
optional SparseCommonSGDRuleParameter embed_sgd_param = 10; optional SparseCommonSGDRuleParameter embed_sgd_param = 10;
optional SparseCommonSGDRuleParameter embedx_sgd_param = 11; optional SparseCommonSGDRuleParameter embedx_sgd_param = 11;
} }
...@@ -182,13 +180,6 @@ message TableAccessorSaveParameter { ...@@ -182,13 +180,6 @@ message TableAccessorSaveParameter {
optional string deconverter = 3; optional string deconverter = 3;
} }
// message SparseSGDRuleParameter {
// optional double learning_rate = 1 [default = 0.05];
// optional double initial_g2sum = 2 [default = 3.0];
// optional double initial_range = 3 [default = 0.0001];
// repeated float weight_bounds = 4;
//}
message SparseCommonSGDRuleParameter { message SparseCommonSGDRuleParameter {
optional string name = 1; optional string name = 1;
optional SparseNaiveSGDRuleParameter naive = 2; optional SparseNaiveSGDRuleParameter naive = 2;
......
...@@ -36,7 +36,8 @@ cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto exec ...@@ -36,7 +36,8 @@ cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto exec
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto) cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost ctr_accessor)
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost sparse_sgd_rule)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/ctr_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Creates the embed / embedx SGD rules named in the accessor config, loads
// their configuration, and caches the resulting layout dimensions and the
// show/click decay rate. Always returns 0.
int CtrCommonAccessor::initialize() {
  const auto& embed_param = _config.embed_sgd_param();
  const auto& embedx_param = _config.embedx_sgd_param();

  // Rule for the embed weight; configured with dimension 1.
  const auto embed_rule_name = embed_param.name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embed_rule_name);
  _embed_sgd_rule->load_config(embed_param, 1);

  // Rule for the embedx weights; configured with the embedx dimension.
  const auto embedx_rule_name = embedx_param.name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embedx_rule_name);
  _embedx_sgd_rule->load_config(embedx_param, _config.embedx_dim());

  // Record the sizes that determine the stored value layout.
  common_feature_value.embed_sgd_dim = _embed_sgd_rule->dim();
  common_feature_value.embedx_dim = _config.embedx_dim();
  common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim();

  _show_click_decay_rate =
      _config.ctr_accessor_param().show_click_decay_rate();
  return 0;
}
// Number of float fields in one stored feature value.
size_t CtrCommonAccessor::dim() {
  const auto field_count = common_feature_value.dim();
  return field_count;
}
// Size in bytes of field `dim` of a stored feature value, as reported by the
// feature-value layout helper.
size_t CtrCommonAccessor::dim_size(size_t dim) {
  return common_feature_value.dim_size(dim, _config.embedx_dim());
}
// Total size in bytes of one stored feature value.
size_t CtrCommonAccessor::size() {
  const auto total_bytes = common_feature_value.size();
  return total_bytes;
}
// Size in bytes of the embedx part of a value: the embedx weights plus the
// embedx SGD-rule state (embedx, embedx_g2sum).
size_t CtrCommonAccessor::mf_size() {
  const auto mf_field_count =
      _config.embedx_dim() + common_feature_value.embedx_sgd_dim;
  return mf_field_count * sizeof(float);
}
// pull value
// Number of fields in a pull value: one field plus the embedx dimension.
size_t CtrCommonAccessor::select_dim() {
  return 1 + _config.embedx_dim();
}
// Size in bytes of one pull-value field; the same for every `dim`.
size_t CtrCommonAccessor::select_dim_size(size_t dim) {
  return sizeof(float);
}
// Total size in bytes of a pull value.
size_t CtrCommonAccessor::select_size() {
  const size_t field_count = select_dim();
  return field_count * sizeof(float);
}
// push value
// Number of fields in a push value: four fields plus the embedx dimension.
size_t CtrCommonAccessor::update_dim() {
  return 4 + _config.embedx_dim();
}
// Size in bytes of one push-value field; the same for every `dim`.
size_t CtrCommonAccessor::update_dim_size(size_t dim) {
  return sizeof(float);
}
// Total size in bytes of a push value.
size_t CtrCommonAccessor::update_size() {
  const size_t field_count = update_dim();
  return field_count * sizeof(float);
}
// Applies the show/click time decay to `value` in place, then reports whether
// the value should be removed.
//
// Returns true when the decayed show/click score is below delete_threshold or
// the value has been unseen for more than delete_after_unseen_days.
// Note: the decay is applied on every call, including calls that return false.
bool CtrCommonAccessor::shrink(float* value) {
  const auto& ctr_param = _config.ctr_accessor_param();
  const auto delete_after_unseen_days = ctr_param.delete_after_unseen_days();
  const auto delete_threshold = ctr_param.delete_threshold();

  // Time decay first ...
  common_feature_value.show(value) *= _show_click_decay_rate;
  common_feature_value.click(value) *= _show_click_decay_rate;

  // ... then decide on the decayed statistics.
  const auto score = show_click_score(common_feature_value.show(value),
                                      common_feature_value.click(value));
  const auto unseen_days = common_feature_value.unseen_days(value);
  return score < delete_threshold || unseen_days > delete_after_unseen_days;
}
// Decides whether `value` is dumped for save stage `param`.
//   param == 1 (xbox delta) / param == 2 (xbox base): the value is saved only
//     when its show/click score, delta score and unseen days pass the
//     configured thresholds; for stage 2 the delta threshold is treated as 0
//     and delta_score is reset when the value is saved.
//   any other stage (0 save all, 3, 5 revert batch_model, ...): always saved.
bool CtrCommonAccessor::save(float* value, int param) {
  const bool is_xbox_stage = (param == 1 || param == 2);
  if (!is_xbox_stage) {
    return true;
  }

  const auto& ctr_param = _config.ctr_accessor_param();
  const auto base_threshold = ctr_param.base_threshold();
  const auto delta_keep_days = ctr_param.delta_keep_days();
  auto delta_threshold = ctr_param.delta_threshold();
  if (param == 2) {
    delta_threshold = 0;
  }

  const bool keep =
      show_click_score(common_feature_value.show(value),
                       common_feature_value.click(value)) >= base_threshold &&
      common_feature_value.delta_score(value) >= delta_threshold &&
      common_feature_value.unseen_days(value) <= delta_keep_days;

  // NOTE(review): the original comment here said the reset should happen
  // "after save, because it must not be modified when retry", yet the base
  // stage resets delta_score right here — confirm this is intended.
  if (keep && param == 2) {
    common_feature_value.delta_score(value) = 0;
  }
  return keep;
}
// Updates per-value statistics once a save for stage `param` has finished.
//   param == 1: reset delta_score for values that pass the same thresholds
//               save() uses for stage 1.
//   param == 3: increment unseen_days.
//   otherwise : no-op.
// The thresholds are only read for stage 1, so the former "param == 2 =>
// delta_threshold = 0" adjustment could never take effect and was removed.
void CtrCommonAccessor::update_stat_after_save(float* value, int param) {
  switch (param) {
    case 1: {
      const auto& ctr_param = _config.ctr_accessor_param();
      const auto base_threshold = ctr_param.base_threshold();
      const auto delta_threshold = ctr_param.delta_threshold();
      const auto delta_keep_days = ctr_param.delta_keep_days();
      if (show_click_score(common_feature_value.show(value),
                           common_feature_value.click(value)) >=
              base_threshold &&
          common_feature_value.delta_score(value) >= delta_threshold &&
          common_feature_value.unseen_days(value) <= delta_keep_days) {
        common_feature_value.delta_score(value) = 0;
      }
      return;
    }
    case 3: {
      common_feature_value.unseen_days(value)++;
      return;
    }
    default:
      return;
  }
}
// Initializes `num` caller-allocated feature values.
//
// For each value the statistics fields are zeroed, slot is set to -1, and the
// embed / embedx regions are initialized by their SGD rules (the embedx call
// passes `false` as the rule's third argument). Always returns 0.
int32_t CtrCommonAccessor::create(float** values, size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
    value[common_feature_value.unseen_days_index()] = 0;
    value[common_feature_value.delta_score_index()] = 0;
    value[common_feature_value.show_index()] = 0;
    value[common_feature_value.click_index()] = 0;
    value[common_feature_value.slot_index()] = -1;
    _embed_sgd_rule->init_value(
        value + common_feature_value.embed_w_index(),
        value + common_feature_value.embed_g2sum_index());
    _embedx_sgd_rule->init_value(
        value + common_feature_value.embedx_w_index(),
        value + common_feature_value.embedx_g2sum_index(), false);
  }
  return 0;
}
bool CtrCommonAccessor::need_extend_mf(float* value) {
float show = value[common_feature_value.show_index()];
float click = value[common_feature_value.click_index()];
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff();
return score >= _config.embedx_threshold();
}
// Returns true when `size` is greater than the index at which the
// embedx_g2sum field starts.
bool CtrCommonAccessor::has_mf(size_t size) {
  const auto embedx_g2sum_begin = common_feature_value.embedx_g2sum_index();
  return size > embedx_g2sum_begin;
}
// from CommonFeatureValue to CtrCommonPullValue
// Copies the pull-visible fields (embed_w followed by the embedx weights) of
// each stored value into the matching pull value. Always returns 0.
int32_t CtrCommonAccessor::select(float** select_values, const float** values,
                                  size_t num) {
  const size_t embedx_bytes = _config.embedx_dim() * sizeof(float);
  for (size_t item = 0; item < num; ++item) {
    const float* stored = values[item];
    float* pulled = select_values[item];
    pulled[CtrCommonPullValue::embed_w_index()] =
        stored[common_feature_value.embed_w_index()];
    memcpy(pulled + CtrCommonPullValue::embedx_w_index(),
           stored + common_feature_value.embedx_w_index(), embedx_bytes);
  }
  return 0;
}
// from CtrCommonPushValue to CtrCommonPushValue
// first dim: item
// second dim: field num
// Accumulates `other_update_values` into `update_values`, item by item.
// Every push field is summed except the slot field, which is left untouched.
// Always returns 0.
int32_t CtrCommonAccessor::merge(float** update_values,
                                 const float** other_update_values,
                                 size_t num) {
  const size_t field_count = CtrCommonPushValue::dim(_config.embedx_dim());
  const size_t slot_field = CtrCommonPushValue::slot_index();
  for (size_t item = 0; item < num; ++item) {
    float* dst = update_values[item];
    const float* src = other_update_values[item];
    for (size_t field = 0; field < field_count; ++field) {
      if (field == slot_field) {
        continue;
      }
      dst[field] += src[field];
    }
  }
  return 0;
}
// from CtrCommonPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
// Applies `num` push values to the corresponding stored values.
//
// For each item: show/click are accumulated, slot is overwritten, delta_score
// grows by the weighted show/click contribution of the push, unseen_days is
// reset to 0, and the embed / embedx weights are updated by their SGD rules.
// Always returns 0.
int32_t CtrCommonAccessor::update(float** update_values,
                                  const float** push_values, size_t num) {
  // Loop-invariant configuration, read once.
  const auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
  const auto click_coeff = _config.ctr_accessor_param().click_coeff();

  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
    const float push_show = push_value[CtrCommonPushValue::show_index()];
    const float push_click = push_value[CtrCommonPushValue::click_index()];
    const float slot = push_value[CtrCommonPushValue::slot_index()];

    update_value[common_feature_value.show_index()] += push_show;
    update_value[common_feature_value.click_index()] += push_click;
    update_value[common_feature_value.slot_index()] = slot;
    update_value[common_feature_value.delta_score_index()] +=
        (push_show - push_click) * nonclk_coeff + push_click * click_coeff;
    update_value[common_feature_value.unseen_days_index()] = 0;

    _embed_sgd_rule->update_value(
        update_value + common_feature_value.embed_w_index(),
        update_value + common_feature_value.embed_g2sum_index(),
        push_value + CtrCommonPushValue::embed_g_index());
    _embedx_sgd_rule->update_value(
        update_value + common_feature_value.embedx_w_index(),
        update_value + common_feature_value.embedx_g2sum_index(),
        push_value + CtrCommonPushValue::embedx_g_index());
  }
  return 0;
}
// Decides whether a missing key should get a value created for it.
//   stage == 0 (pull) and any stage other than 1: always true.
//   stage == 1 (push): based on the show/click score of the push value —
//     never for score <= 0, always for score >= 1, otherwise when a random
//     draw is below the score.
bool CtrCommonAccessor::create_value(int stage, const float* value) {
  if (stage != 1) {
    return true;
  }

  const auto show = CtrCommonPushValue::show_const(value);
  const auto click = CtrCommonPushValue::click_const(value);
  const auto score = show_click_score(show, click);
  if (score <= 0) {
    return false;
  }
  if (score >= 1) {
    return true;
  }
  const auto draw =
      local_uniform_real_distribution<float>()(local_random_engine());
  return draw < score;
}
// Weighted score of a show/click pair: non-click impressions (show - click)
// are weighted by nonclk_coeff and clicks by click_coeff.
float CtrCommonAccessor::show_click_score(float show, float click) {
  const auto& ctr_param = _config.ctr_accessor_param();
  const auto nonclk_weight = ctr_param.nonclk_coeff();
  const auto click_weight = ctr_param.click_coeff();
  return (show - click) * nonclk_weight + click * click_weight;
}
// Serializes a stored value as space-separated floats.
//
// The first six fields and the embed SGD state are always written; the embedx
// part (everything from embedx_w to the end of the value) is written only
// when the show/click score has reached the embedx threshold.
// `param` is not used.
std::string CtrCommonAccessor::parse_to_string(const float* v, int param) {
  // Per-thread stream that is reset on every call.
  thread_local std::ostringstream os;
  os.clear();
  os.str("");

  // The six fixed leading fields.
  os << v[0];
  for (int fixed = 1; fixed < 6; ++fixed) {
    os << " " << v[fixed];
  }

  // Embed SGD state.
  const int embedx_begin = common_feature_value.embedx_w_index();
  for (int idx = common_feature_value.embed_g2sum_index(); idx < embedx_begin;
       ++idx) {
    os << " " << v[idx];
  }

  const auto score = show_click_score(common_feature_value.show_const(v),
                                      common_feature_value.click_const(v));
  if (score >= _config.embedx_threshold()) {
    const int value_end = common_feature_value.dim();
    for (int idx = embedx_begin; idx < value_end; ++idx) {
      os << " " << v[idx];
    }
  }
  return os.str();
}
// Parses space-separated floats from `str` into `value`.
//
// The embedx region is first initialized by the embedx SGD rule, so it holds
// the rule's initial values when `str` does not carry an embedx part; parsed
// floats then overwrite `value` from index 0.
// Returns the count reported by str_to_float and CHECK-fails when it is
// below 6.
// NOTE(review): `value` is presumably sized by the caller to hold every float
// in `str`; nothing here bounds the write — confirm against callers.
int CtrCommonAccessor::parse_from_string(const std::string& str, float* value) {
  _embedx_sgd_rule->init_value(
      value + common_feature_value.embedx_w_index(),
      value + common_feature_value.embedx_g2sum_index());
  auto ret = paddle::string::str_to_float(str.data(), value);
  CHECK(ret >= 6) << "expect more than 6 real:" << ret;
  return ret;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
// Value accessor for CTR sparse features. Describes the float layout of the
// stored value, the push value and the pull value, and implements the
// create / select / merge / update / shrink / save operations on them.
class CtrCommonAccessor : public ValueAccessor {
 public:
  // Index helpers for the float array stored per feature. The embed/embedx
  // SGD-state sizes are filled in by CtrCommonAccessor::initialize().
  struct CtrCommonFeatureValue {
    /*
       float slot;
       float unseen_days;
       float delta_score;
       float show;
       float click;
       float embed_w;
       std::vector<float> embed_g2sum;
       std::vector<float> embedx_w;
       std::vector<float> embedx_g2sum;
       */
    int dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
    int dim_size(size_t dim, int embedx_dim) { return sizeof(float); }
    int size() { return dim() * sizeof(float); }
    int slot_index() { return 0; }
    int unseen_days_index() { return slot_index() + 1; }
    int delta_score_index() { return unseen_days_index() + 1; }
    int show_index() { return delta_score_index() + 1; }
    int click_index() { return show_index() + 1; }
    int embed_w_index() { return click_index() + 1; }
    int embed_g2sum_index() { return embed_w_index() + 1; }
    int embedx_w_index() { return embed_g2sum_index() + embed_sgd_dim; }
    int embedx_g2sum_index() { return embedx_w_index() + embedx_dim; }

    float& unseen_days(float* val) { return val[unseen_days_index()]; }
    float& delta_score(float* val) { return val[delta_score_index()]; }
    float& show(float* val) { return val[show_index()]; }
    float& click(float* val) { return val[click_index()]; }
    float& slot(float* val) { return val[slot_index()]; }
    float& embed_w(float* val) { return val[embed_w_index()]; }
    float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
    float& embedx_w(float* val) { return val[embedx_w_index()]; }
    float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
    float show_const(const float* val) {
      float s = val[show_index()];
      return s;
    }
    float click_const(const float* val) {
      float c = val[click_index()];
      return c;
    }

    int embed_sgd_dim;
    int embedx_dim;
    int embedx_sgd_dim;
  };

  // Index helpers for the float array pushed to the table.
  struct CtrCommonPushValue {
    /*
       float slot;
       float show;
       float click;
       float embed_g;
       std::vector<float> embedx_g;
       */
    static int dim(int embedx_dim) { return 4 + embedx_dim; }
    static int dim_size(int dim, int embedx_dim) { return sizeof(float); }
    static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }
    static int slot_index() { return 0; }
    static int show_index() { return CtrCommonPushValue::slot_index() + 1; }
    static int click_index() { return CtrCommonPushValue::show_index() + 1; }
    static int embed_g_index() { return CtrCommonPushValue::click_index() + 1; }
    static int embedx_g_index() {
      return CtrCommonPushValue::embed_g_index() + 1;
    }
    static float& slot(float* val) {
      return val[CtrCommonPushValue::slot_index()];
    }
    static float& show(float* val) {
      return val[CtrCommonPushValue::show_index()];
    }
    static float& click(float* val) {
      return val[CtrCommonPushValue::click_index()];
    }
    static float show_const(const float* val) {
      float s = val[show_index()];
      return s;
    }
    static float click_const(const float* val) {
      float c = val[click_index()];
      return c;
    }
    static float& embed_g(float* val) {
      return val[CtrCommonPushValue::embed_g_index()];
    }
    static float* embedx_g(float* val) {
      return val + CtrCommonPushValue::embedx_g_index();
    }
  };

  // Index helpers for the float array returned by a pull.
  struct CtrCommonPullValue {
    /*
       float embed_w;
       std::vector<float> embedx_w;
       */
    static int dim(int embedx_dim) { return 1 + embedx_dim; }
    static int dim_size(size_t dim) { return sizeof(float); }
    static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }
    static int embed_w_index() { return 0; }
    static int embedx_w_index() { return 1; }
    static float& embed_w(float* val) {
      return val[CtrCommonPullValue::embed_w_index()];
    }
    static float* embedx_w(float* val) {
      return val + CtrCommonPullValue::embedx_w_index();
    }
  };

  CtrCommonAccessor() {}
  virtual int initialize();
  virtual ~CtrCommonAccessor() {}

  // Number of fields in a stored value.
  virtual size_t dim();
  // Size in bytes of each field of a stored value.
  virtual size_t dim_size(size_t dim);
  // Total size in bytes of a stored value.
  virtual size_t size();
  // Total size in bytes of the dynamically sized mf part of a value; only
  // takes effect for sparse tables.
  virtual size_t mf_size();
  // Number of fields in a pull value.
  virtual size_t select_dim();
  // Size in bytes of each field of a pull value.
  virtual size_t select_dim_size(size_t dim);
  // Total size in bytes of a pull value.
  virtual size_t select_size();
  // Number of fields in a push value.
  virtual size_t update_dim();
  // Size in bytes of each field of a push value.
  virtual size_t update_dim_size(size_t dim);
  // Total size in bytes of a push value.
  virtual size_t update_size();
  // Decides whether this value should be shrunk.
  virtual bool shrink(float* value);
  // Decides whether this value is saved to SSD.
  // virtual bool save_ssd(float* value);
  virtual bool need_extend_mf(float* value);
  virtual bool has_mf(size_t size);
  // Decides whether this value is dumped during the save stage.
  // param identifies the save stage, e.g. downpour's xbox and batch_model:
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 2, save xbox base feature
  bool save(float* value, int param) override;
  // update delta_score and unseen_days after save
  void update_stat_after_save(float* value, int param) override;
  // Generates random values for keys that do not exist yet.
  // The memory behind value must already be allocated by the caller.
  virtual int32_t create(float** value, size_t num);
  // Selects fields from values into select_values.
  virtual int32_t select(float** select_values, const float** values,
                         size_t num);
  // Aggregates update_values together.
  virtual int32_t merge(float** update_values,
                        const float** other_update_values, size_t num);
  // Aggregates update_values together, using it.next to decide whether to
  // move on to the next key.
  // virtual int32_t merge(float** update_values, iterator it);
  // Applies update_values to values.
  virtual int32_t update(float** values, const float** update_values,
                         size_t num);

  std::string parse_to_string(const float* value, int param) override;
  int32_t parse_from_string(const std::string& str, float* v) override;
  virtual bool create_value(int type, const float* value);

  // This interface is currently only used to read show; any other field name
  // yields 0.
  float get_field(float* value, const std::string& name) override {
    // CHECK(name == "show");
    if (name == "show") {
      return common_feature_value.show(value);
    }
    return 0.0;
  }

 private:
  // Initialized to defined values so the object is in a deterministic state
  // before initialize() is called.
  float _show_click_decay_rate = 0.0f;
  int32_t _ssd_unseenday_threshold = 0;

 public:  // TODO(zhaocaibei123): it should be private, but we make it public
          // for unit test
  CtrCommonFeatureValue common_feature_value;
  float show_click_score(float show, float click);
  // Created in initialize(); null until then.
  SparseValueSGDRule* _embed_sgd_rule = nullptr;
  SparseValueSGDRule* _embedx_sgd_rule = nullptr;
};
} // namespace distributed
} // namespace paddle
...@@ -26,3 +26,6 @@ cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost ...@@ -26,3 +26,6 @@ cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost
set_source_files_properties(sparse_sgd_rule_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_sgd_rule_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS ${COMMON_DEPS} boost table) cc_test(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS ${COMMON_DEPS} boost table)
set_source_files_properties(ctr_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost table)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/table/ctr_accessor.h"
#include <cmath>
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
// Builds the accessor configuration shared by the tests below:
// CtrCommonAccessor with embedx_dim 8, StdAdaGradSGDRule for the embed weight
// and SparseNaiveSGDRule for the embedx weights.
TableAccessorParameter gen_param() {
  TableAccessorParameter param;
  param.set_accessor_class("CtrCommonAccessor");
  param.set_fea_dim(11);
  param.set_embedx_dim(8);
  param.mutable_ctr_accessor_param()->set_nonclk_coeff(0.2);
  param.mutable_ctr_accessor_param()->set_click_coeff(1);
  param.mutable_ctr_accessor_param()->set_base_threshold(0.5);
  param.mutable_ctr_accessor_param()->set_delta_threshold(0.2);
  param.mutable_ctr_accessor_param()->set_delta_keep_days(16);
  param.mutable_ctr_accessor_param()->set_show_click_decay_rate(0.99);
  /*
  param.mutable_embed_sgd_param()->set_name("naive");
  auto* naive_param = param.mutable_embed_sgd_param()->mutable_naive();
  naive_param->set_learning_rate(0.1);
  naive_param->set_initial_range(0.3);
  naive_param->add_weight_bounds(-10.0);
  naive_param->add_weight_bounds(10.0);
  */
  param.mutable_embed_sgd_param()->set_name("StdAdaGradSGDRule");
  auto* adagrad_param = param.mutable_embed_sgd_param()->mutable_adagrad();
  adagrad_param->set_learning_rate(0.1);
  adagrad_param->set_initial_range(0.3);
  adagrad_param->set_initial_g2sum(0.0);
  adagrad_param->add_weight_bounds(-10.0);
  adagrad_param->add_weight_bounds(10.0);
  param.mutable_embedx_sgd_param()->set_name("SparseNaiveSGDRule");
  auto* naive_param = param.mutable_embedx_sgd_param()->mutable_naive();
  naive_param->set_learning_rate(0.1);
  naive_param->set_initial_range(0.3);
  naive_param->add_weight_bounds(-10.0);
  naive_param->add_weight_bounds(10.0);
  // Return the local by name: std::move here would only block copy elision.
  return param;
}
// shrink() keeps a value with a healthy score and few unseen days, and drops
// it once unseen_days is far beyond the configured limit.
TEST(downpour_feature_value_accessor_test, test_shrink) {
  TableAccessorParameter parameter = gen_param();
  // Automatic storage: nothing leaks, even when an ASSERT returns early.
  CtrCommonAccessor acc;
  ASSERT_EQ(acc.configure(parameter), 0);
  ASSERT_EQ(acc.initialize(), 0);
  VLOG(3) << "size of struct: " << acc.common_feature_value.embed_sgd_dim
          << " " << acc.common_feature_value.embedx_dim << " "
          << acc.common_feature_value.embedx_sgd_dim << " "
          << acc.common_feature_value.dim() << "\n";

  std::vector<float> value(acc.dim());
  for (size_t i = 0; i < value.size(); ++i) {
    value[i] = i * 1.0;
  }
  ASSERT_TRUE(!acc.shrink(value.data()));

  // set unseen_days too long
  value[1] = 1000;
  // Also shrink delta_score; shrink() itself only looks at the show/click
  // score and unseen_days.
  value[2] = 0.001;
  ASSERT_TRUE(acc.shrink(value.data()));
}
// save() accepts the synthetic value for the "all", "delta" and "base" stages.
TEST(downpour_feature_value_accessor_test, test_save) {
  TableAccessorParameter parameter = gen_param();
  // Automatic storage: nothing leaks, even when an ASSERT returns early.
  CtrCommonAccessor acc;
  ASSERT_EQ(acc.configure(parameter), 0);
  ASSERT_EQ(acc.initialize(), 0);

  std::vector<float> value(acc.dim());
  for (size_t i = 0; i < value.size(); ++i) {
    value[i] = i * 1.0;
  }
  // save all feature
  ASSERT_TRUE(acc.save(value.data(), 0));
  // save delta feature
  ASSERT_TRUE(acc.save(value.data(), 1));
  // save base feature with time decay
  ASSERT_TRUE(acc.save(value.data(), 2));

  VLOG(3) << "test_save:";
  for (size_t i = 0; i < value.size(); ++i) {
    VLOG(3) << value[i];
  }
}
// create() succeeds on caller-allocated buffers; the resulting values are
// only logged.
TEST(downpour_feature_value_accessor_test, test_create) {
  TableAccessorParameter parameter = gen_param();
  // Automatic storage: nothing leaks, even when an ASSERT returns early.
  CtrCommonAccessor acc;
  ASSERT_EQ(acc.configure(parameter), 0);
  ASSERT_EQ(acc.initialize(), 0);

  // Size the buffers from the accessor instead of a hard-coded 7 + 8, and
  // zero them so the log below never reads indeterminate floats.
  const size_t field_size = acc.dim();
  const size_t item_size = 10;
  std::vector<std::vector<float>> storage(
      item_size, std::vector<float>(field_size, 0.0f));
  std::vector<float*> value;
  for (auto& row : storage) {
    value.push_back(row.data());
  }

  ASSERT_EQ(acc.create(value.data(), item_size), 0);
  for (size_t i = 0; i < item_size; ++i) {
    for (size_t j = 0; j < field_size; ++j) {
      VLOG(3) << value[i][j] << " ";
      // ASSERT_FLOAT_EQ(value[i][j], 0);
    }
    VLOG(3) << "\n";
  }
}
// update() must match a reference computation done field by field with the
// same SGD rules.
TEST(downpour_feature_value_accessor_test, test_update) {
  TableAccessorParameter parameter = gen_param();
  // Automatic storage: nothing leaks, even when an ASSERT returns early.
  CtrCommonAccessor acc;
  ASSERT_EQ(acc.configure(parameter), 0);
  ASSERT_EQ(acc.initialize(), 0);
  VLOG(3) << "dim: " << acc.common_feature_value.dim() << "\n";
  VLOG(3) << "update_dim: " << acc.update_dim() << "\n";

  const size_t field_size = acc.dim();
  const size_t item_size = 10;
  const size_t embed_sgd_dim = acc.common_feature_value.embed_sgd_dim;
  const size_t embedx_dim = acc.common_feature_value.embedx_dim;
  const size_t embedx_sgd_dim = acc.common_feature_value.embedx_sgd_dim;

  // Stored values, all zero.
  std::vector<std::vector<float>> value_store(
      item_size, std::vector<float>(field_size, 0.0f));
  std::vector<float*> value;
  for (auto& row : value_store) {
    value.push_back(row.data());
  }

  // Push values: every field of item i holds i.
  std::vector<std::vector<float>> grad_store;
  for (size_t i = 0; i < item_size; ++i) {
    grad_store.emplace_back(acc.update_dim(), static_cast<float>(i));
  }
  std::vector<const float*> grad;
  for (const auto& row : grad_store) {
    grad.push_back(row.data());
  }

  // Structured mirror of the stored value layout.
  struct DownpourSparseValueTest {
    float slot;
    float unseen_days;
    float delta_score;
    float show;
    float click;
    float embed_w;
    std::vector<float> embed_g2sum;
    std::vector<float> embedx_w;
    std::vector<float> embedx_g2sum;

    // Flattens the fields back into the stored float layout.
    std::vector<float> to_array() const {
      std::vector<float> out = {slot, unseen_days, delta_score,
                                show, click,       embed_w};
      out.insert(out.end(), embed_g2sum.begin(), embed_g2sum.end());
      out.insert(out.end(), embedx_w.begin(), embedx_w.end());
      out.insert(out.end(), embedx_g2sum.begin(), embedx_g2sum.end());
      return out;
    }
  };
  // Structured mirror of the push value layout.
  struct DownpourSparsePushValueTest {
    float slot;
    float show;
    float click;
    float embed_g;
    std::vector<float> embedx_g;
  };

  std::vector<std::vector<float>> exp_value;
  for (size_t i = 0; i < item_size; ++i) {
    const float* cur = value[i];
    DownpourSparseValueTest v;
    v.slot = cur[0];
    v.unseen_days = cur[1];
    v.delta_score = cur[2];
    v.show = cur[3];
    v.click = cur[4];
    v.embed_w = cur[5];
    size_t idx = 6;
    v.embed_g2sum.assign(cur + idx, cur + idx + embed_sgd_dim);
    idx += embed_sgd_dim;
    v.embedx_w.assign(cur + idx, cur + idx + embedx_dim);
    idx += embedx_dim;
    v.embedx_g2sum.assign(cur + idx, cur + idx + embedx_sgd_dim);

    DownpourSparsePushValueTest push_v;
    push_v.slot = grad[i][0];
    push_v.show = grad[i][1];
    push_v.click = grad[i][2];
    push_v.embed_g = grad[i][3];
    push_v.embedx_g.assign(grad[i] + 4,
                           grad[i] + 4 + parameter.embedx_dim());

    v.slot = push_v.slot;
    v.unseen_days = 0;
    v.show += push_v.show;
    v.click += push_v.click;
    v.delta_score += acc.show_click_score(push_v.show, push_v.click);

    // data() instead of &vec[0]: an SGD-state vector is empty when the
    // rule's dimension is 0, and operator[] on an empty vector is undefined.
    acc._embed_sgd_rule->update_value(&v.embed_w, v.embed_g2sum.data(),
                                      &push_v.embed_g);
    acc._embedx_sgd_rule->update_value(
        v.embedx_w.data(), v.embedx_g2sum.data(), push_v.embedx_g.data());

    exp_value.push_back(v.to_array());
    ASSERT_EQ(exp_value.back().size(), field_size);
  }

  acc.update(value.data(), grad.data(), item_size);

  for (size_t i = 0; i < item_size; ++i) {
    for (size_t j = 0; j < field_size; ++j) {
      VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " ";
      ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]);
    }
  }
}
// With nonclk_coeff 0.2 and click_coeff 1: (10 - 6) * 0.2 + 6 * 1 = 6.8.
TEST(downpour_feature_value_accessor_test, test_show_click_score) {
  TableAccessorParameter parameter = gen_param();
  // Automatic storage: nothing leaks, even when an ASSERT returns early.
  CtrCommonAccessor acc;
  ASSERT_EQ(acc.configure(parameter), 0);
  ASSERT_EQ(acc.initialize(), 0);

  float show = 10;
  float click = 6;
  ASSERT_FLOAT_EQ(acc.show_click_score(show, click), 6.8);
}
// Round-trips through parse_to_string / parse_from_string and checks that
// fields not present in the parsed string end up as 0.
TEST(downpour_feature_value_accessor_test, test_string_related) {
  TableAccessorParameter parameter = gen_param();
  // Automatic storage: nothing leaks, even when an ASSERT returns early.
  CtrCommonAccessor acc;
  ASSERT_EQ(acc.configure(parameter), 0);
  ASSERT_EQ(acc.initialize(), 0);

  const int field_size = 15;
  std::vector<float> value(field_size);
  for (int i = 0; i < field_size; ++i) {
    value[i] = i;
  }

  auto str = acc.parse_to_string(value.data(), 0);
  VLOG(3) << str << std::endl;

  str = "0 1 2 3 4 5 6";
  ASSERT_NE(acc.parse_from_string(str, value.data()), 0);
  // make sure init_zero=true
  for (auto i = 7; i < 15; ++i) {
    ASSERT_FLOAT_EQ(value[i], 0);
  }
}
} // namespace distributed
} // namespace paddle
...@@ -24,26 +24,6 @@ ...@@ -24,26 +24,6 @@
namespace paddle { namespace paddle {
namespace string { namespace string {
// Returns the number of leading whitespace characters in the NUL-terminated
// string `s`.
inline size_t count_spaces(const char* s) {
  size_t count = 0;
  // isspace requires a value representable as unsigned char (or EOF); a plain
  // char holding a byte >= 0x80 may be negative, so convert first.
  while (*s != 0 && isspace(static_cast<unsigned char>(*s++))) {
    count++;
  }
  return count;
}
// Returns the number of leading non-whitespace characters in the
// NUL-terminated string `s`.
inline size_t count_nonspaces(const char* s) {
  size_t count = 0;
  // isspace requires a value representable as unsigned char (or EOF); a plain
  // char holding a byte >= 0x80 may be negative, so convert first.
  while (*s != 0 && !isspace(static_cast<unsigned char>(*s++))) {
    count++;
  }
  return count;
}
// remove leading and tailing spaces // remove leading and tailing spaces
std::string trim_spaces(const std::string& str) { std::string trim_spaces(const std::string& str) {
const char* p = str.c_str(); const char* p = str.c_str();
...@@ -74,20 +54,6 @@ std::string erase_spaces(const std::string& str) { ...@@ -74,20 +54,6 @@ std::string erase_spaces(const std::string& str) {
return result; return result;
} }
// Parses whitespace-separated floats from the NUL-terminated string `str`
// into `v` and returns how many were stored.
//
// Parsing stops at the end of the string or at the first token that is not a
// number; such a token is neither stored nor counted. `v` must have room for
// every float in `str`.
inline int str_to_float(const char* str, float* v) {
  const char* head = str;
  char* cursor = nullptr;
  int index = 0;
  for (;;) {
    // Skip leading whitespace; the terminating NUL is not whitespace.
    while (isspace(static_cast<unsigned char>(*head))) {
      ++head;
    }
    if (*head == 0) {
      break;
    }
    const float parsed = std::strtof(head, &cursor);
    if (head == cursor) {
      // Nothing was consumed: not a number. Stop without writing to v.
      break;
    }
    v[index++] = parsed;
    head = cursor;
  }
  return index;
}
bool ends_with(std::string const& input, std::string const& test) { bool ends_with(std::string const& input, std::string const& test) {
if (test.size() > input.size()) return false; if (test.size() > input.size()) return false;
return std::equal(test.rbegin(), test.rend(), input.rbegin()); return std::equal(test.rbegin(), test.rend(), input.rbegin());
......
...@@ -26,9 +26,25 @@ ...@@ -26,9 +26,25 @@
namespace paddle { namespace paddle {
namespace string { namespace string {
inline size_t count_spaces(const char* s); inline size_t count_spaces(const char* s) {
size_t count = 0;
inline size_t count_nonspaces(const char* s); while (*s != 0 && isspace(*s++)) {
count++;
}
return count;
}
inline size_t count_nonspaces(const char* s) {
size_t count = 0;
while (*s != 0 && !isspace(*s++)) {
count++;
}
return count;
}
template <class... ARGS> template <class... ARGS>
void format_string_append(std::string& str, const char* fmt, // NOLINT void format_string_append(std::string& str, const char* fmt, // NOLINT
...@@ -67,7 +83,19 @@ std::string trim_spaces(const std::string& str); ...@@ -67,7 +83,19 @@ std::string trim_spaces(const std::string& str);
// erase all spaces in str // erase all spaces in str
std::string erase_spaces(const std::string& str); std::string erase_spaces(const std::string& str);
int str_to_float(const char* str, float* v); inline int str_to_float(const char* str, float* v) {
const char* head = str;
char* cursor = NULL;
int index = 0;
while (*(head += count_spaces(head)) != 0) {
v[index++] = std::strtof(head, &cursor);
if (head == cursor) {
break;
}
head = cursor;
}
return index;
}
// checks whether the test string is a suffix of the input string. // checks whether the test string is a suffix of the input string.
bool ends_with(std::string const& input, std::string const& test); bool ends_with(std::string const& input, std::string const& test);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册