未验证 提交 148582fe 编写于 作者: D danleifeng 提交者: GitHub

【GPUPS】add ctr_dymf_accessor for pscore (#42827)

上级 7a171e3c
......@@ -35,12 +35,13 @@ set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRI
set_source_files_properties(ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ctr_dymf_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc ctr_dymf_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(sparse_table SRCS memory_sparse_table.cc ssd_sparse_table.cc memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table rocksdb)
cc_library(table SRCS table.cc DEPS sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
int CtrDymfAccessor::Initialize() {
auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim());
common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
common_feature_value.embedx_dim = _config.embedx_dim();
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold();
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
}
VLOG(0) << " INTO CtrDymfAccessor::Initialize()";
InitAccessorInfo();
return 0;
}
void CtrDymfAccessor::InitAccessorInfo() {
_accessor_info.dim = common_feature_value.Dim();
_accessor_info.size = common_feature_value.Size();
auto embedx_dim = _config.embedx_dim();
VLOG(0) << "InitAccessorInfo embedx_dim:" << embedx_dim;
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
(embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
}
bool CtrDymfAccessor::Shrink(float* value) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delete_after_unseen_days =
_config.ctr_accessor_param().delete_after_unseen_days();
auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first
common_feature_value.Show(value) *= _show_click_decay_rate;
common_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after
auto score = ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value));
auto unseen_days = common_feature_value.UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true;
}
return false;
}
bool CtrDymfAccessor::SaveCache(float* value, int param,
double global_cache_threshold) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
return common_feature_value.Show(value) > global_cache_threshold;
}
return false;
}
bool CtrDymfAccessor::SaveSSD(float* value) {
if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) {
return true;
}
return false;
}
bool CtrDymfAccessor::Save(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (param == 2) {
delta_threshold = 0;
}
switch (param) {
// save all
case 0: {
return true;
}
// save xbox delta
case 1:
// save xbox base
case 2: {
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry
if (param == 2) {
common_feature_value.DeltaScore(value) = 0;
}
return true;
} else {
return false;
}
}
// already decayed in shrink
case 3: {
// do this after save, because it must not be modified when retry
// common_feature_value.UnseenDays(value)++;
return true;
}
// save revert batch_model
case 5: {
return true;
}
default:
return true;
}
}
void CtrDymfAccessor::UpdateStatAfterSave(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (param == 2) {
delta_threshold = 0;
}
switch (param) {
case 1: {
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
common_feature_value.DeltaScore(value) = 0;
}
}
return;
case 3: {
common_feature_value.UnseenDays(value)++;
}
return;
default:
return;
}
}
int32_t CtrDymfAccessor::Create(float** values, size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
value[common_feature_value.UnseenDaysIndex()] = 0;
value[common_feature_value.DeltaScoreIndex()] = 0;
value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1;
value[common_feature_value.MfDimIndex()] = -1;
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex());
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(),
false);
}
return 0;
}
bool CtrDymfAccessor::NeedExtendMF(float* value) {
float show = value[common_feature_value.ShowIndex()];
float click = value[common_feature_value.ClickIndex()];
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff();
return score >= _config.embedx_threshold();
}
bool CtrDymfAccessor::HasMF(size_t size) {
return size > common_feature_value.EmbedxG2SumIndex();
}
// from CommonFeatureValue to CtrDymfPullValue
int32_t CtrDymfAccessor::Select(float** select_values, const float** values,
size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item];
const float* value = values[value_item];
select_value[CtrDymfPullValue::ShowIndex()] =
value[common_feature_value.ShowIndex()];
select_value[CtrDymfPullValue::ClickIndex()] =
value[common_feature_value.ClickIndex()];
select_value[CtrDymfPullValue::EmbedWIndex()] =
value[common_feature_value.EmbedWIndex()];
memcpy(select_value + CtrDymfPullValue::EmbedxWIndex(),
value + common_feature_value.EmbedxWIndex(),
embedx_dim * sizeof(float));
}
return 0;
}
// from CtrDymfPushValue to CtrDymfPushValue
// first dim: item
// second dim: field num
int32_t CtrDymfAccessor::Merge(float** update_values,
const float** other_update_values, size_t num) {
// currently merge in cpu is not supported
return 0;
}
// from CtrDymfPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
int32_t CtrDymfAccessor::Update(float** update_values,
const float** push_values, size_t num) {
// currently update in cpu is not supported
return 0;
}
bool CtrDymfAccessor::CreateValue(int stage, const float* value) {
// stage == 0, pull
// stage == 1, push
if (stage == 0) {
return true;
} else if (stage == 1) {
// operation
auto show = CtrDymfPushValue::Show(const_cast<float*>(value));
auto click = CtrDymfPushValue::Click(const_cast<float*>(value));
auto score = ShowClickScore(show, click);
if (score <= 0) {
return false;
}
if (score >= 1) {
return true;
}
return local_uniform_real_distribution<float>()(local_random_engine()) <
score;
} else {
return true;
}
}
float CtrDymfAccessor::ShowClickScore(float show, float click) {
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff;
}
std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
/*
float unseen_days;
float delta_score;
float show;
float click;
float embed_w;
std::vector<float> embed_g2sum; // float embed_g2sum
float slot;
float mf_dim;
std::<vector>float embedx_g2sum; // float embedx_g2sum
std::vector<float> embedx_w;
*/
thread_local std::ostringstream os;
os.clear();
os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4];
// << v[5] << " " << v[6];
for (int i = common_feature_value.EmbedG2SumIndex();
i < common_feature_value.EmbedxWIndex(); i++) {
os << " " << v[i];
}
os << " " << common_feature_value.Slot(const_cast<float*>(v)) << " "
<< common_feature_value.MfDim(const_cast<float*>(v));
auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() &&
param > common_feature_value.EmbedxG2SumIndex()) {
VLOG(0) << "common_feature_value.EmbedxG2SumIndex():"
<< common_feature_value.EmbedxG2SumIndex();
for (auto i = common_feature_value.EmbedxG2SumIndex();
i < common_feature_value.Dim(); ++i) {
os << " " << v[i];
}
}
return os.str();
}
int CtrDymfAccessor::ParseFromString(const std::string& str, float* value) {
auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 7) << "expect more than 7 real:" << ret;
return ret;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
// DownpourUnitAccessor
class CtrDymfAccessor : public ValueAccessor {
public:
struct CtrDymfFeatureValue {
/*
float unseen_days;
float delta_score;
float show;
float click;
float embed_w;
// float embed_g2sum;
std::vector<float> embed_g2sum;
float slot;
float mf_dim
std::<vector>float embedx_g2sum;
// float embedx_g2sum;
std::vector<float> embedx_w;
*/
int Dim() { return 7 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
int Size() { return Dim() * sizeof(float); }
int UnseenDaysIndex() { return 0; }
int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
int ShowIndex() { return DeltaScoreIndex() + 1; }
int ClickIndex() { return ShowIndex() + 1; }
int EmbedWIndex() { return ClickIndex() + 1; }
int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
int SlotIndex() { return EmbedG2SumIndex() + 1; }
int MfDimIndex() { return SlotIndex() + 1; }
int EmbedxG2SumIndex() { return MfDimIndex() + 1; }
int EmbedxWIndex() { return EmbedxG2SumIndex() + 1; }
float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
float& Show(float* val) { return val[ShowIndex()]; }
float& Click(float* val) { return val[ClickIndex()]; }
float& Slot(float* val) { return val[SlotIndex()]; }
float& MfDim(float* val) { return val[MfDimIndex()]; }
float& EmbedW(float* val) { return val[EmbedWIndex()]; }
float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
int embed_sgd_dim;
int embedx_dim;
int embedx_sgd_dim;
};
struct CtrDymfPushValue {
/*
float slot;
float show;
float click;
float mf_dim;
float embed_g;
std::vector<float> embedx_g;
*/
static int Dim(int embedx_dim) { return 5 + embedx_dim; }
static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int SlotIndex() { return 0; }
static int ShowIndex() { return CtrDymfPushValue::SlotIndex() + 1; }
static int ClickIndex() { return CtrDymfPushValue::ShowIndex() + 1; }
static int MfDimIndex() { return CtrDymfPushValue::ClickIndex() + 1; }
static int EmbedGIndex() { return CtrDymfPushValue::MfDimIndex() + 1; }
static int EmbedxGIndex() { return CtrDymfPushValue::EmbedGIndex() + 1; }
static float& Slot(float* val) {
return val[CtrDymfPushValue::SlotIndex()];
}
static float& Show(float* val) {
return val[CtrDymfPushValue::ShowIndex()];
}
static float& Click(float* val) {
return val[CtrDymfPushValue::ClickIndex()];
}
static float& MfDim(float* val) {
return val[CtrDymfPushValue::MfDimIndex()];
}
static float& EmbedG(float* val) {
return val[CtrDymfPushValue::EmbedGIndex()];
}
static float* EmbedxG(float* val) {
return val + CtrDymfPushValue::EmbedxGIndex();
}
};
struct CtrDymfPullValue {
/*
float show;
float click;
float mf_dim;
float embed_w;
std::vector<float> embedx_w;
*/
static int Dim(int embedx_dim) { return 4 + embedx_dim; }
static int DimSize(size_t dim) { return sizeof(float); }
static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
static int ShowIndex() { return 0; }
static int ClickIndex() { return 1; }
static int MfDimIndex() { return 2; }
static int EmbedWIndex() { return 3; }
static int EmbedxWIndex() { return 4; }
static float& Show(float* val) {
return val[CtrDymfPullValue::ShowIndex()];
}
static float& Click(float* val) {
return val[CtrDymfPullValue::ClickIndex()];
}
static float& MfDim(float* val) {
return val[CtrDymfPullValue::MfDimIndex()];
}
static float& EmbedW(float* val) {
return val[CtrDymfPullValue::EmbedWIndex()];
}
static float* EmbedxW(float* val) {
return val + CtrDymfPullValue::EmbedxWIndex();
}
};
CtrDymfAccessor() {}
virtual ~CtrDymfAccessor() {}
virtual int Initialize();
// 初始化AccessorInfo
virtual void InitAccessorInfo();
// 判断该value是否进行shrink
virtual bool Shrink(float* value);
// 判断该value是否保存到ssd
// virtual bool save_ssd(float* value);
virtual bool NeedExtendMF(float* value);
virtual bool HasMF(size_t size);
// 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model
// param = 0, save all feature
// param = 1, save delta feature
// param = 2, save xbox base feature
bool Save(float* value, int param) override;
bool SaveCache(float* value, int param,
double global_cache_threshold) override;
bool SaveSSD(float* value) override;
// update delta_score and unseen_days after save
void UpdateStatAfterSave(float* value, int param) override;
// keys不存在时,为values生成随机值
// 要求value的内存由外部调用者分配完毕
virtual int32_t Create(float** value, size_t num);
// 从values中选取到select_values中
virtual int32_t Select(float** select_values, const float** values,
size_t num);
// 将update_values聚合到一起
virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中
virtual int32_t Update(float** values, const float** update_values,
size_t num);
std::string ParseToString(const float* value, int param) override;
int32_t ParseFromString(const std::string& str, float* v) override;
virtual bool CreateValue(int type, const float* value);
// 这个接口目前只用来取show
float GetField(float* value, const std::string& name) override {
// CHECK(name == "show");
if (name == "show") {
return common_feature_value.Show(value);
}
return 0.0;
}
private:
// float ShowClickScore(float show, float click);
// SparseValueSGDRule* _embed_sgd_rule;
// SparseValueSGDRule* _embedx_sgd_rule;
// CtrDymfFeatureValue common_feature_value;
float _show_click_decay_rate;
int32_t _ssd_unseenday_threshold;
bool _show_scale = false;
public: // TODO(zhaocaibei123): it should be private, but we make it public
// for unit test
CtrDymfFeatureValue common_feature_value;
float ShowClickScore(float show, float click);
SparseValueSGDRule* _embed_sgd_rule;
SparseValueSGDRule* _embedx_sgd_rule;
};
} // namespace distributed
} // namespace paddle
......@@ -22,6 +22,7 @@
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_double_accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
......@@ -40,9 +41,11 @@ REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseTable);
REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrDoubleAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrDymfAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, SparseAccessor);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
......
......@@ -35,6 +35,9 @@ cc_test(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS ${COMMON_DEPS} bo
set_source_files_properties(ctr_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost table)
set_source_files_properties(ctr_dymf_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(ctr_dymf_accessor_test SRCS ctr_dymf_accessor_test.cc DEPS ${COMMON_DEPS} boost table)
set_source_files_properties(memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS ${COMMON_DEPS} boost table)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include <cmath>
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
TableAccessorParameter gen_param() {
TableAccessorParameter param;
param.set_accessor_class("CtrDymfAccessor");
param.set_fea_dim(11);
param.set_embedx_dim(8);
param.mutable_ctr_accessor_param()->set_nonclk_coeff(0.2);
param.mutable_ctr_accessor_param()->set_click_coeff(1);
param.mutable_ctr_accessor_param()->set_base_threshold(0.5);
param.mutable_ctr_accessor_param()->set_delta_threshold(0.2);
param.mutable_ctr_accessor_param()->set_delta_keep_days(16);
param.mutable_ctr_accessor_param()->set_show_click_decay_rate(0.99);
/*
param.mutable_embed_sgd_param()->set_name("naive");
auto* naive_param = param.mutable_embed_sgd_param()->mutable_naive();
naive_param->set_learning_rate(0.1);
naive_param->set_initial_range(0.3);
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
*/
param.mutable_embed_sgd_param()->set_name("StdAdaGradSGDRule");
auto* adagrad_param = param.mutable_embed_sgd_param()->mutable_adagrad();
adagrad_param->set_learning_rate(0.1);
adagrad_param->set_initial_range(0.3);
adagrad_param->set_initial_g2sum(0.0);
adagrad_param->add_weight_bounds(-10.0);
adagrad_param->add_weight_bounds(10.0);
param.mutable_embedx_sgd_param()->set_name("SparseNaiveSGDRule");
auto* naive_param = param.mutable_embedx_sgd_param()->mutable_naive();
naive_param->set_learning_rate(0.1);
naive_param->set_initial_range(0.3);
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
return param;
}
TEST(downpour_feature_value_accessor_test, test_shrink) {
TableAccessorParameter parameter = gen_param();
CtrDymfAccessor* acc = new CtrDymfAccessor();
ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0);
VLOG(3) << "size of struct: " << acc->common_feature_value.embed_sgd_dim
<< " " << acc->common_feature_value.embedx_dim << " "
<< acc->common_feature_value.embedx_sgd_dim << " "
<< acc->common_feature_value.Dim() << "\n";
float* value = new float[acc->GetAccessorInfo().dim];
for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
value[i] = i * 1.0;
}
ASSERT_TRUE(!acc->Shrink(value));
// set unseen_days too long
value[0] = 1000;
// set delta score too small
value[1] = 0.001;
ASSERT_TRUE(acc->Shrink(value));
}
TEST(downpour_feature_value_accessor_test, test_save) {
TableAccessorParameter parameter = gen_param();
CtrDymfAccessor* acc = new CtrDymfAccessor();
ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0);
float* value = new float[acc->GetAccessorInfo().dim];
for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
value[i] = i * 1.0;
}
// save all feature
ASSERT_TRUE(acc->Save(value, 0));
// save delta feature
ASSERT_TRUE(acc->Save(value, 1));
// save base feature with time decay
ASSERT_TRUE(acc->Save(value, 2));
VLOG(3) << "test_save:";
for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) {
VLOG(3) << value[i];
}
}
TEST(downpour_feature_value_accessor_test, test_create) {
TableAccessorParameter parameter = gen_param();
CtrDymfAccessor* acc = new CtrDymfAccessor();
ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0);
const int field_size = 8 + 8;
const int item_size = 10;
float** value = new float*[item_size];
for (auto i = 0u; i < item_size; ++i) {
value[i] = new float[field_size];
}
ASSERT_EQ(acc->Create(value, item_size), 0);
for (auto i = 0u; i < item_size; ++i) {
for (auto j = 0u; j < field_size; ++j) {
VLOG(3) << value[i][j] << " ";
// ASSERT_FLOAT_EQ(value[i][j], 0);
}
VLOG(3) << "\n";
}
}
TEST(downpour_feature_value_accessor_test, test_show_click_score) {
TableAccessorParameter parameter = gen_param();
CtrDymfAccessor* acc = new CtrDymfAccessor();
ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0);
float show = 10;
float click = 6;
ASSERT_FLOAT_EQ(acc->ShowClickScore(show, click), 6.8);
}
TEST(downpour_feature_value_accessor_test, test_string_related) {
TableAccessorParameter parameter = gen_param();
CtrDymfAccessor* acc = new CtrDymfAccessor();
ASSERT_EQ(acc->Configure(parameter), 0);
ASSERT_EQ(acc->Initialize(), 0);
const int field_size = 16;
float* value = new float[field_size];
for (auto i = 0u; i < field_size; ++i) {
value[i] = i;
}
auto str = acc->ParseToString(value, 0);
VLOG(0) << "test_string_related" << str << std::endl;
str = "0 1 2 3 4 5 6 7";
ASSERT_NE(acc->ParseFromString(str, value), 0);
// make sure init_zero=true
}
} // namespace distributed
} // namespace paddle
......@@ -534,7 +534,7 @@ class DistributedStrategy(object):
support_sparse_accessor_class = [
'DownpourSparseValueAccessor', 'DownpourCtrAccessor',
'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor',
'DownpourDoubleUnitAccessor'
'DownpourDoubleUnitAccessor', 'DownpourCtrDymfAccessor'
]
from google.protobuf.descriptor import FieldDescriptor
table_param = self.strategy.downpour_table_param
......@@ -616,6 +616,8 @@ class DistributedStrategy(object):
if accessor_class.find("Double") >= 0:
table_data.accessor.accessor_class = 'CtrDoubleAccessor'
elif accessor_class.find("Dymf") >= 0:
table_data.accessor.accessor_class = 'CtrDymfAccessor'
else:
table_data.accessor.accessor_class = 'CtrCommonAccessor'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册