Unverified commit d64f7b3b authored by zhaocaibei123, committed by GitHub

add ctr table depends (#36465)

* add ctr table depends

* code style

* fix

* fix

* fix naming

* rename

* rename
Parent 72533986
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <time.h>
#include <atomic>
#include <random>
namespace paddle {
namespace distributed {
// Get time in seconds.
inline double current_realtime() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
return tp.tv_sec + tp.tv_nsec * 1e-9;
}
inline std::default_random_engine& local_random_engine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
static std::atomic<unsigned long> x(0); // NOLINT
std::seed_seq sseq = {
x++, x++, x++, (unsigned long)(current_realtime() * 1000)}; // NOLINT
engine.seed(sseq);
}
};
thread_local engine_wrapper_t r;
return r.engine;
}
template <class T = double>
std::uniform_real_distribution<T>& local_uniform_real_distribution() {
thread_local std::uniform_real_distribution<T> distr;
assert(distr.a() == 0.0 && distr.b() == 1.0);
return distr;
}
template <class T = double>
T uniform_real() {
return local_uniform_real_distribution<T>()(local_random_engine());
}
template <class T = double>
T uniform_real(T a, T b) {
if (a == b) {
return a;
}
return (T)(a + uniform_real<T>() * (b - a));
}
} // namespace distributed
} // namespace paddle
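For reference, a minimal usage sketch of the helpers above (the calling code is illustrative only; the header path matches the include used later in sparse_sgd_rule.h):

#include "paddle/fluid/distributed/common/local_random.h"

// Thread-local engine plus the uniform helpers defined above.
float r01 = paddle::distributed::uniform_real<float>();           // value in [0, 1)
double r = paddle::distributed::uniform_real<double>(-0.5, 0.5);  // value in [-0.5, 0.5)
auto& engine = paddle::distributed::local_random_engine();        // reusable per-thread engine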
@@ -119,10 +119,41 @@ message TableParameter {
message TableAccessorParameter {
optional string accessor_class = 1;
// optional SparseSGDRuleParameter sparse_sgd_param = 2;
optional uint32 fea_dim = 4 [ default = 11 ];
optional uint32 embedx_dim = 5 [ default = 8 ];
optional uint32 embedx_threshold = 6 [ default = 10 ];
optional CtrAccessorParameter ctr_accessor_param = 7;
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
// optional SparseCommonSGDRuleParameter sparse_commonsgd_param = 9;
optional SparseCommonSGDRuleParameter embed_sgd_param = 10;
optional SparseCommonSGDRuleParameter embedx_sgd_param = 11;
}
message CtrAccessorParameter {
// coefficients used to compute show_click_score
optional float nonclk_coeff = 1 [ default = 0.1 ];
optional float click_coeff = 2 [ default = 1 ];
// a feature can be saved only if show_click_score > base_threshold
optional float base_threshold = 3 [ default = 1.5 ];
// a feature can be saved only if delta_score > delta_threshold
optional float delta_threshold = 4 [ default = 0.25 ];
// a feature can be saved only if unseen_day < delta_keep_days
optional float delta_keep_days = 5 [ default = 16 ];
// show/click are updated to show/click * show_click_decay_rate once a day
optional float show_click_decay_rate = 6 [ default = 0.98 ];
// threshold to shrink a feasign
optional float delete_threshold = 7 [ default = 0.8 ];
// a feature is deleted in shrink_model if unseen_day > delete_after_unseen_days
optional float delete_after_unseen_days = 8 [ default = 30 ];
// threshold for saving to ssd
optional int32 ssd_unseenday_threshold = 9 [ default = 1 ];
}
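For context, nonclk_coeff and click_coeff weight non-click impressions against clicks when the accessor scores a feature against base_threshold / delta_threshold. The accessor itself is not part of this diff, so the following is only a sketch of the usual scoring rule under that assumption, not a quote of the implementation:

// Assumed scoring rule (not shown in this diff): show_click_score weights
// non-clicked impressions by nonclk_coeff and clicks by click_coeff.
float ShowClickScore(float show, float click, float nonclk_coeff,
                     float click_coeff) {
  return (show - click) * nonclk_coeff + click * click_coeff;
}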
message TensorAccessorParameter {
@@ -150,3 +181,40 @@ message TableAccessorSaveParameter {
optional string converter = 2;
optional string deconverter = 3;
}
// message SparseSGDRuleParameter {
// optional double learning_rate = 1 [default = 0.05];
// optional double initial_g2sum = 2 [default = 3.0];
// optional double initial_range = 3 [default = 0.0001];
// repeated float weight_bounds = 4;
//}
message SparseCommonSGDRuleParameter {
optional string name = 1;
optional SparseNaiveSGDRuleParameter naive = 2;
optional SparseAdagradSGDRuleParameter adagrad = 3;
optional SparseAdamSGDParameter adam = 4;
}
message SparseNaiveSGDRuleParameter { // SparseNaiveSGDRule
optional double learning_rate = 1 [ default = 0.05 ];
optional double initial_range = 2 [ default = 0.0001 ];
repeated float weight_bounds = 3;
}
message SparseAdagradSGDRuleParameter { // SparseAdaGradSGDRule|StdAdaGradSGDRule
optional double learning_rate = 1 [ default = 0.05 ];
optional double initial_g2sum = 2 [ default = 3.0 ];
optional double initial_range = 3 [ default = 0.0001 ];
repeated float weight_bounds = 4;
}
message SparseAdamSGDParameter { // SparseAdamSGDRule
optional double learning_rate = 1 [ default = 0.001 ];
optional double initial_range = 2 [ default = 0.0001 ];
optional double beta1_decay_rate = 3 [ default = 0.9 ];
optional double beta2_decay_rate = 4 [ default = 0.999 ];
optional double ada_epsilon = 5 [ default = 1e-08 ];
repeated float weight_bounds = 6;
}
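These messages are usually filled in programmatically; the new sparse_sgd_rule_test.cc below builds an adagrad rule like this:

SparseCommonSGDRuleParameter param;
param.set_name("adagrad");
auto* adagrad_param = param.mutable_adagrad();
adagrad_param->set_learning_rate(0.1);
adagrad_param->set_initial_g2sum(0.2);
adagrad_param->set_initial_range(0.3);
adagrad_param->add_weight_bounds(-10.0);
adagrad_param->add_weight_bounds(10.0);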
@@ -35,4 +35,8 @@ cc_library(tensor_accessor SRCS tensor_accessor.cc DEPS ${TABLE_DEPS} eigen3 ps_
cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto executor scope device_context tensor ${TABLE_DEPS})
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost sparse_sgd_rule)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "butil/object_pool.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/depends/initializers.h"
#include "paddle/fluid/distributed/thirdparty/round_robin.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
static const int CTR_SPARSE_SHARD_BUCKET_NUM_BITS = 6;
static const size_t CTR_SPARSE_SHARD_BUCKET_NUM =
static_cast<size_t>(1) << CTR_SPARSE_SHARD_BUCKET_NUM_BITS;
class FixedFeatureValue {
public:
FixedFeatureValue() {}
~FixedFeatureValue() {}
float *data() { return data_.data(); }
size_t size() { return data_.size(); }
void resize(size_t size) { data_.resize(size); }
void shrink_to_fit() { data_.shrink_to_fit(); }
private:
std::vector<float> data_;
};
class SparseTableShard {
public:
typedef typename robin_hood::unordered_map<uint64_t, FixedFeatureValue *>
map_type;
SparseTableShard() {}
~SparseTableShard() {}
FixedFeatureValue *Init(const uint64_t &id) {
size_t hash = hasher_(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
FixedFeatureValue *value = nullptr;
value = butil::get_object<FixedFeatureValue>();
table[id] = value;
return value;
}
// does not check Has(id); the caller must ensure the id exists
float *Get(const uint64_t &id) {
size_t hash = hasher_(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
// auto &value = table.at(id);
// return value->data_.data();
auto res = table.find(id);
FixedFeatureValue *value = res->second;
return value->data();
}
// used during load to reset count and unseen_days
FixedFeatureValue *GetValue(const uint64_t &id) {
size_t hash = hasher_(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto res = table.find(id);
return res->second;
}
void erase(uint64_t feasign) {
size_t hash = hasher_(feasign);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto iter = table.find(feasign);
if (iter != table.end()) {
butil::return_object(iter->second);
iter = table.erase(iter);
}
}
void clear() {}
size_t compute_bucket(size_t hash) {
if (CTR_SPARSE_SHARD_BUCKET_NUM == 1) {
return 0;
} else {
return hash >> (sizeof(size_t) * 8 - CTR_SPARSE_SHARD_BUCKET_NUM_BITS);
}
}
map_type::iterator end() {
return values_[CTR_SPARSE_SHARD_BUCKET_NUM - 1].end();
}
map_type::iterator Find(uint64_t id) {
size_t hash = hasher_(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto got = table.find(id);
if (got == table.end()) {
return end();
} else {
return got;
}
}
private:
bool Has(const uint64_t id) {
size_t hash = hasher_(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto got = table.find(id);
if (got == table.end()) {
return false;
} else {
return true;
}
}
public:
map_type values_[CTR_SPARSE_SHARD_BUCKET_NUM];
std::hash<uint64_t> hasher_;
};
} // namespace distributed
} // namespace paddle
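Typical shard usage, as exercised by feature_value_test.cc later in this change: Init allocates a FixedFeatureValue from the butil object pool, Find/Get look it up by feasign, and erase returns the object to the pool.

SparseTableShard shard;
uint64_t key = 1;
FixedFeatureValue* feature_value = shard.Init(key);  // allocate and register the value
feature_value->resize(4);
feature_value->data()[0] = 0.5f;
auto it = shard.Find(key);       // != shard.end() once the key is inserted
float* values = shard.Get(key);  // no existence check: caller must ensure the key exists
shard.erase(key);                // return the object to the pool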
@@ -31,8 +31,9 @@ struct PullSparseValue {
feasigns_(nullptr),
frequencies_(nullptr) {}
explicit PullSparseValue(std::vector<uint64_t> feasigns,
std::vector<uint32_t> frequencies, int dim) {
explicit PullSparseValue(std::vector<uint64_t>& feasigns, // NOLINT
std::vector<uint32_t>& frequencies, // NOLINT
int dim) {
numel_ = feasigns.size();
dim_ = dim;
is_training_ = true;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/sparse_sgd_rule.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace paddle {
namespace distributed {
void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto naive_param = param.naive();
learning_rate_ = naive_param.learning_rate();
_initial_range = naive_param.initial_range();
if (naive_param.weight_bounds_size() == 0) {
_min_bound = -std::numeric_limits<float>::max();
_max_bound = std::numeric_limits<float>::max();
} else {
CHECK(naive_param.weight_bounds_size() >= 2)
<< "invalid repeated size for weight_bounds:"
<< naive_param.weight_bounds_size();
_min_bound = naive_param.weight_bounds(0);
_max_bound = naive_param.weight_bounds(1);
}
}
void SparseNaiveSGDRule::update_value_work(float* w, float* sgd,
const float* push_value,
float scale) {
for (size_t i = 0; i < _embedding_dim; ++i) {
w[i] -= learning_rate_ * push_value[i];
bound_value(w[i]);
}
}
void SparseNaiveSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
if (zero_init) {
for (size_t i = 0; i < _embedding_dim; ++i) {
value[i] = 0;
}
} else {
for (size_t i = 0; i < _embedding_dim; ++i) {
value[i] =
(local_uniform_real_distribution<float>()(local_random_engine()) * 2 -
1) *
_initial_range;
bound_value(value[i]);
}
}
}
void SparseAdaGradSGDRule::load_config(
const SparseCommonSGDRuleParameter& param, size_t emb_dim) {
_embedding_dim = emb_dim;
auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate();
_initial_g2sum = adagrad_param.initial_g2sum();
_initial_range = adagrad_param.initial_range();
if (adagrad_param.weight_bounds_size() == 0) {
_min_bound = -std::numeric_limits<float>::max();
_max_bound = std::numeric_limits<float>::max();
} else {
CHECK(adagrad_param.weight_bounds_size() >= 2)
<< "invalid repeated size for weight_bounds:"
<< adagrad_param.weight_bounds_size();
_min_bound = adagrad_param.weight_bounds(0);
_max_bound = adagrad_param.weight_bounds(1);
}
}
void SparseAdaGradSGDRule::update_value_work(float* w, float* sgd,
const float* grad, float scale) {
float& g2sum = sgd[g2sum_index()];
double add_g2sum = 0;
for (int i = 0; i < _embedding_dim; i++) {
double scaled_grad = grad[i] / scale;
w[i] -= learning_rate_ * scaled_grad *
sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
bound_value(w[i]);
add_g2sum += scaled_grad * scaled_grad;
}
g2sum += add_g2sum / _embedding_dim;
}
void SparseAdaGradSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
bound_value(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
bound_value(value[i]);
}
}
sgd[g2sum_index()] = 0;
}
void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate();
_initial_g2sum = adagrad_param.initial_g2sum();
_initial_range = adagrad_param.initial_range();
if (adagrad_param.weight_bounds_size() == 0) {
_min_bound = -std::numeric_limits<float>::max();
_max_bound = std::numeric_limits<float>::max();
} else {
CHECK(adagrad_param.weight_bounds_size() >= 2)
<< "invalid repeated size for weight_bounds:"
<< adagrad_param.weight_bounds_size();
_min_bound = adagrad_param.weight_bounds(0);
_max_bound = adagrad_param.weight_bounds(1);
}
}
void StdAdaGradSGDRule::update_value_work(float* w, float* sgd,
const float* grad, float scale) {
for (int i = 0; i < _embedding_dim; i++) {
float& g2sum = sgd[g2sum_index() + i];
double scaled_grad = grad[i] / scale;
w[i] -= learning_rate_ * scaled_grad *
sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
bound_value(w[i]);
g2sum += scaled_grad * scaled_grad;
}
}
void StdAdaGradSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
bound_value(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
bound_value(value[i]);
}
sgd[g2sum_index() + i] = 0;
}
}
void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto adam_param = param.adam();
learning_rate_ = adam_param.learning_rate();
_initial_range = adam_param.initial_range();
_beta1_decay_rate = adam_param.beta1_decay_rate();
_beta2_decay_rate = adam_param.beta2_decay_rate();
_ada_epsilon = adam_param.ada_epsilon();
if (adam_param.weight_bounds_size() == 0) {
_min_bound = -std::numeric_limits<float>::max();
_max_bound = std::numeric_limits<float>::max();
} else {
CHECK(adam_param.weight_bounds_size() >= 2)
<< "invalid repeated size for weight_bounds:"
<< adam_param.weight_bounds_size();
_min_bound = adam_param.weight_bounds(0);
_max_bound = adam_param.weight_bounds(1);
}
}
void SparseAdamSGDRule::update_value_work(float* w, float* sgd,
const float* grad, float scale) {
float* gsum = sgd + gsum_index();
float* g2sum = sgd + g2sum_index();
float* beta1_pow = sgd + beta1_pow_index();
float* beta2_pow = sgd + beta2_pow_index();
const float* g = grad;
float lr = learning_rate_;
float beta1_pow_ = *beta1_pow;
float beta2_pow_ = *beta2_pow;
// lr does not change within one update
lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
for (int i = 0; i < _embedding_dim; i++) {
// update first/second moment estimates, then apply the step
gsum[i] = _beta1_decay_rate * gsum[i] + (1 - _beta1_decay_rate) * g[i];
g2sum[i] =
_beta2_decay_rate * g2sum[i] + (1 - _beta2_decay_rate) * g[i] * g[i];
w[i] = w[i] - lr * (gsum[i] / (sqrt(g2sum[i]) + _ada_epsilon));
bound_value(w[i]);
}
// update beta_pow_decay
(*beta1_pow) *= _beta1_decay_rate;
(*beta2_pow) *= _beta2_decay_rate;
}
void SparseAdamSGDRule::init_value_work(float* value, float* sgd,
bool zero_init) {
for (int i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
bound_value(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
bound_value(value[i]);
}
}
// init rule gsum and g2sum
for (int i = gsum_index(); i < beta1_pow_index(); i++) {
sgd[i] = 0.0;
}
// init beta1_pow and beta2_pow
*(sgd + beta1_pow_index()) = _beta1_decay_rate;
*(sgd + beta2_pow_index()) = _beta2_decay_rate;
}
} // namespace distributed
} // namespace paddle
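In formula form, SparseAdamSGDRule::update_value_work above performs the standard Adam step per embedding dimension, with $m$ = gsum, $v$ = g2sum, $\eta$ = learning_rate, $\epsilon$ = ada_epsilon, and $g$ the pushed gradient (scale is not applied in this rule):

$$
\begin{aligned}
\eta_t &= \eta \cdot \frac{\sqrt{1-\beta_2^{pow}}}{1-\beta_1^{pow}} \\
m_i &\leftarrow \beta_1 m_i + (1-\beta_1)\, g_i \\
v_i &\leftarrow \beta_2 v_i + (1-\beta_2)\, g_i^2 \\
w_i &\leftarrow w_i - \eta_t \cdot \frac{m_i}{\sqrt{v_i}+\epsilon}
\end{aligned}
$$

Afterwards beta1_pow and beta2_pow are multiplied by beta1_decay_rate and beta2_decay_rate respectively.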
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <thread>
#include <vector>
#include "glog/logging.h" // for CHECK
#include "paddle/fluid/distributed/common/local_random.h" // for local_uniform_real_distribution
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace paddle {
namespace distributed {
class SparseValueSGDRule {
public:
SparseValueSGDRule() {}
virtual ~SparseValueSGDRule() {}
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim) {
_embedding_dim = emb_dim;
_name = param.name();
}
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale) = 0;
virtual void init_value_work(float* value, float* sgd, bool zero_init) = 0;
virtual size_t dim() = 0;
const std::string& get_name() const { return _name; }
void init_value(float* value, float* sgd, bool zero_init = true) {
init_value_work(value, sgd, zero_init);
}
void update_value(float* w, float* sgd, const float* push_value,
float scale = 1) {
update_value_work(w, sgd, push_value, scale);
}
template <class T>
void bound_value(T& w) { // NOLINT
if (!(w >= _min_bound)) {
w = (T)_min_bound;
} else if (!(w <= _max_bound)) {
w = (T)_max_bound;
}
}
float& min_bound() { return _min_bound; }
float& max_bound() { return _max_bound; }
protected:
float _min_bound;
float _max_bound;
float _initial_range;
size_t _embedding_dim;
private:
std::string _name;
};
REGISTER_PSCORE_REGISTERER(SparseValueSGDRule);
class SparseNaiveSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return 0; }
private:
float learning_rate_;
};
class SparseAdaGradSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return 1; }
size_t g2sum_index() { return 0; }
private:
float learning_rate_;
float _initial_g2sum;
};
class StdAdaGradSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return _embedding_dim; }
size_t g2sum_index() { return 0; }
private:
float learning_rate_;
float _initial_g2sum;
};
class SparseAdamSGDRule : public SparseValueSGDRule {
public:
virtual void load_config(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void update_value_work(float* w, float* sgd, const float* push_value,
float scale);
virtual void init_value_work(float* value, float* sgd, bool zero_init);
virtual size_t dim() { return _embedding_dim * 2 + 2; }
size_t gsum_index() { return 0; }
size_t g2sum_index() { return gsum_index() + _embedding_dim; }
size_t beta1_pow_index() { return g2sum_index() + _embedding_dim; }
size_t beta2_pow_index() { return beta1_pow_index() + 1; }
protected:
float learning_rate_;
float _beta1_decay_rate;
float _beta2_decay_rate;
float _ada_epsilon;
};
} // namespace distributed
} // namespace paddle
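A brief note on the per-feature optimizer state implied by dim() and the index helpers above: each rule stores dim() extra floats in the sgd buffer passed alongside the embedding.

// sgd buffer layout per feature, derived from dim() and the *_index() helpers:
//   SparseNaiveSGDRule   : (empty)                                       dim = 0
//   SparseAdaGradSGDRule : [g2sum]                                       dim = 1
//   StdAdaGradSGDRule    : [g2sum_0 .. g2sum_{emb-1}]                    dim = emb_dim
//   SparseAdamSGDRule    : [gsum_0..gsum_{emb-1}, g2sum_0..g2sum_{emb-1},
//                           beta1_pow, beta2_pow]                        dim = 2*emb_dim + 2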
@@ -20,3 +20,9 @@ cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_funct
set_source_files_properties(graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
set_source_files_properties(sparse_sgd_rule_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS ${COMMON_DEPS} boost table)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/table/depends/feature_value.h"
namespace paddle {
namespace distributed {
TEST(BENCHMARK, LargeScaleKV) {
std::shared_ptr<SparseTableShard> shard =
std::make_shared<SparseTableShard>();
uint64_t key = 1;
auto itr = shard->Find(key);
ASSERT_TRUE(itr == shard->end());
std::vector<float> vec = {0.0, 0.1, 0.2, 0.3};
auto* feature_value = shard->Init(key);
feature_value->resize(vec.size());
memcpy(feature_value->data(), vec.data(), vec.size() * sizeof(float));
itr = shard->Find(key);
ASSERT_TRUE(itr != shard->end());
feature_value = itr->second;
float* value_data = feature_value->data();
ASSERT_FLOAT_EQ(value_data[0], 0.0);
ASSERT_FLOAT_EQ(value_data[1], 0.1);
ASSERT_FLOAT_EQ(value_data[2], 0.2);
ASSERT_FLOAT_EQ(value_data[3], 0.3);
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/table/sparse_sgd_rule.h"
#include <cmath>
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace paddle {
namespace distributed {
TEST(sparse_value_naive_sgd_test, init_and_update) {
SparseNaiveSGDRule rule;
SparseCommonSGDRuleParameter param;
param.set_name("naive");
auto* naive_param = param.mutable_naive();
naive_param->set_learning_rate(0.1);
naive_param->set_initial_range(0.3);
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
rule.load_config(param, 10);
// check init_value for zero
const int kItemSize = 10;
float w[kItemSize];
float grad[kItemSize];
rule.init_value(w, w + 9, true);
for (auto i = 0u; i < kItemSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0);
}
// check init_value for random
rule.init_value(w, w + 9, false);
for (auto i = 0u; i < kItemSize; ++i) {
ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound());
}
// check update_value for one field
for (auto i = 0u; i < kItemSize; ++i) {
w[i] = 0;
}
for (auto i = 0u; i < kItemSize; ++i) {
grad[i] = (i + 1) * 1.0;
}
float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, -0.500000,
-0.600000, -0.700000, -0.800000, -0.900000, -1.000000};
const float* ptr_grad = grad;
rule.update_value(w, w + 9, ptr_grad);
for (auto i = 0u; i < kItemSize; ++i) {
VLOG(3) << w[i] << "\n";
ASSERT_FLOAT_EQ(w[i], label[i]);
}
}
TEST(downpour_sparse_adagrad_test, test_init_and_update) {
SparseAdaGradSGDRule rule;
SparseCommonSGDRuleParameter param;
param.set_name("adagrad");
auto* adagrad_param = param.mutable_adagrad();
adagrad_param->set_learning_rate(0.1);
adagrad_param->set_initial_g2sum(0.2);
adagrad_param->set_initial_range(0.3);
adagrad_param->add_weight_bounds(-10.0);
adagrad_param->add_weight_bounds(10.0);
rule.load_config(param, 10);
// check init_value for zero
const int kValueSize = 11;
int kEmbSize = 10;
float w[kValueSize];
rule.init_value(w, w + 10, true);
for (auto i = 0u; i < kEmbSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0);
}
ASSERT_FLOAT_EQ(w[kEmbSize], 0);
// check init_value for random
rule.init_value(w, w + 10, false);
for (auto i = 0u; i < kEmbSize; ++i) {
ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound());
}
ASSERT_FLOAT_EQ(w[kEmbSize], 0);
// check update_value for one field
for (auto i = 0u; i < kEmbSize; ++i) {
w[i] = 0;
}
w[kEmbSize] = 0;
float grad[kEmbSize];
for (auto i = 0u; i < kEmbSize; ++i) {
grad[i] = (i + 1) * 1.0;
}
const float* ptr_grad = grad;
rule.update_value(w, w + 10, ptr_grad);
float label[] = {-0.100000, -0.200000, -0.300000, -0.400000,
-0.500000, -0.600000, -0.700000, -0.800000,
-0.900000, -1.000000, 38.500000};
for (auto i = 0u; i < kValueSize; ++i) {
ASSERT_FLOAT_EQ(w[i], label[i]);
}
}
TEST(downpour_sparse_adam_test, test_init_and_update) {
const int embed_dim = 10; // dims of parameters
SparseCommonSGDRuleParameter param;
param.set_name("adam");
auto* adam_param = param.mutable_adam();
adam_param->set_learning_rate(0.1);
adam_param->set_initial_range(0.3);
adam_param->set_beta1_decay_rate(0.9);
adam_param->set_beta2_decay_rate(0.999);
adam_param->set_ada_epsilon(1e-08);
adam_param->add_weight_bounds(-10.0);
adam_param->add_weight_bounds(10.0);
ASSERT_FLOAT_EQ(param.adam().learning_rate(), 0.1);
ASSERT_FLOAT_EQ(param.adam().initial_range(), 0.3);
ASSERT_FLOAT_EQ(param.adam().beta1_decay_rate(), 0.9);
ASSERT_FLOAT_EQ(param.adam().beta2_decay_rate(), 0.999);
ASSERT_FLOAT_EQ(param.adam().ada_epsilon(), 1e-08);
SparseAdamSGDRule rule;
rule.load_config(param, embed_dim);
// check init_value for zero
const int rule_dim =
rule.dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam
const int value_dim = embed_dim + rule_dim; // total dims of w + rule
float* value = new float[value_dim];
rule.init_value(value, value + embed_dim, true);
for (auto i = 0u; i < rule.beta1_pow_index(); ++i) {
ASSERT_FLOAT_EQ(value[i], 0);
}
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999);
// check init_value for random
rule.init_value(value, value + embed_dim, false);
for (auto i = 0u; i < embed_dim; ++i) {
ASSERT_TRUE(value[i] >= rule.min_bound() && value[i] <= rule.max_bound());
}
for (auto i = rule.gsum_index(); i < rule.beta1_pow_index(); ++i) {
ASSERT_FLOAT_EQ(value[i + embed_dim], 0);
}
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9);
ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999);
// check update_value
rule.init_value(value, value + embed_dim, true);
float* grad = new float[embed_dim];
for (auto i = 0u; i < embed_dim; ++i) {
grad[i] = (i + 1) * 1.0;
}
float label[] = {-0.0999999642, -0.099999994, -0.099999994, -0.099999994,
-0.099999994, -0.099999994, -0.099999994, -0.100000001,
-0.100000009, -0.100000001, 0.100000024, 0.200000048,
0.300000072, 0.400000095, 0.500000119, 0.600000143,
0.700000167, 0.800000191, 0.900000215, 1.00000024,
0.000999987125, 0.0039999485, 0.00899988413, 0.015999794,
0.0249996781, 0.0359995365, 0.0489993691, 0.063999176,
0.0809989572, 0.0999987125, 0.809999943, 0.998001039};
rule.update_value(value, value + embed_dim, grad);
for (auto i = 0u; i < value_dim; ++i) { // check update
ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i;
}
}
} // namespace distributed
} // namespace paddle