Unverified commit 29c6bcbf authored by zhaocaibei123, committed by GitHub

memory sparse table & brpc communication upgrade dependency (#36734)

Parent 249081b6
......@@ -11,6 +11,7 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
add_subdirectory(common)
add_subdirectory(service)
add_subdirectory(table)
add_subdirectory(test)
......
cc_library(afs_wrapper SRCS afs_warpper.cc DEPS fs ps_framework_proto)
#set_property(GLOBAL PROPERTY COMMON_DEPS afs_warpper)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace distributed {
// AfsClient impl
int AfsClient::initialize(const FsClientParameter& fs_client_param) {
// temporarily implemented with hdfs-client
return initialize(fs_client_param.hadoop_bin(), fs_client_param.uri(),
fs_client_param.user(), fs_client_param.passwd(),
fs_client_param.buffer_size());
}
int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& user, const std::string& passwd,
int buffer_size_param) {
return initialize(hadoop_bin, uri, paddle::string::format_string(
"%s,%s", user.c_str(), passwd.c_str()),
buffer_size_param);
}
int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& ugi, int buffer_size_param) {
// temporarily implemented with hdfs-client
size_t buffer_size = 1L << 25; // 32MB
if (buffer_size_param > static_cast<int>(buffer_size)) {
buffer_size = buffer_size_param;
}
paddle::framework::hdfs_set_buffer_size(buffer_size);
paddle::framework::hdfs_set_command(paddle::string::format_string(
"2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s "
"-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000",
hadoop_bin.c_str(), uri.c_str(), ugi.c_str()));
return 0;
}
// open file in 'w' or 'r'
std::shared_ptr<FsReadChannel> AfsClient::open_r(const FsChannelConfig& config,
uint32_t buffer_size,
int* err_no) {
std::shared_ptr<FsReadChannel> channel =
std::make_shared<FsReadChannel>(buffer_size);
std::shared_ptr<FILE> fp =
paddle::framework::fs_open_read(config.path, err_no, config.deconverter);
channel->open(fp, config);
return channel;
}
std::shared_ptr<FsWriteChannel> AfsClient::open_w(const FsChannelConfig& config,
uint32_t buffer_size,
int* err_no) {
std::shared_ptr<FsWriteChannel> channel =
std::make_shared<FsWriteChannel>(buffer_size);
std::shared_ptr<FILE> fp =
paddle::framework::fs_open_write(config.path, err_no, config.converter);
channel->open(fp, config);
return channel;
}
// remove file(s) at path; path may be a wildcard pattern such as 'part-000-*'
void AfsClient::remove(const std::string& path) {
return paddle::framework::fs_remove(path);
}
void AfsClient::remove_dir(const std::string& dir) {
return paddle::framework::fs_remove(dir);
}
// list files under path; path may be a directory with a wildcard pattern
std::vector<std::string> AfsClient::list(const std::string& path) {
return paddle::framework::fs_list(path);
}
// exist or not
bool AfsClient::exist(const std::string& dir) {
return paddle::framework::fs_exists(dir);
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
struct FsDataConverter {
std::string converter;
std::string deconverter;
};
struct FsChannelConfig {
std::string path; // path of file
std::string converter; // data converter
std::string deconverter;
};
class FsReadChannel {
public:
FsReadChannel() : _buffer_size(0) {}
explicit FsReadChannel(uint32_t buffer_size) : _buffer_size(buffer_size) {}
virtual ~FsReadChannel() {}
FsReadChannel(FsReadChannel&&) = delete;
FsReadChannel(const FsReadChannel&) = delete;
int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config) {
_file = fp;
return 0;
}
inline int close() {
_file.reset();
return 0;
}
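// reads one line into line_data (trailing '\n' stripped); returns 0 on
// success, and -1 (wrapped into uint32_t) when nothing is read before EOF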
inline uint32_t read_line(std::string& line_data) { // NOLINT
line_data.clear();
char buffer = '\0';
size_t read_count = 0;
while (1 == fread(&buffer, 1, 1, _file.get()) && buffer != '\n') {
++read_count;
line_data.append(&buffer, 1);
}
if (read_count == 0 && buffer != '\n') {
return -1;
}
return 0;
}
private:
uint32_t _buffer_size;
FsChannelConfig _config;
std::shared_ptr<FILE> _file;
};
class FsWriteChannel {
public:
FsWriteChannel() : _buffer_size(0) {}
explicit FsWriteChannel(uint32_t buffer_size) : _buffer_size(buffer_size) {}
virtual ~FsWriteChannel() {}
FsWriteChannel(FsWriteChannel&&) = delete;
FsWriteChannel(const FsWriteChannel&) = delete;
int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config) {
_file = fp;
// the buffer has already been set in fs.cc
// if (_buffer_size != 0) {
// _buffer = std::shared_ptr<char>(new char[_buffer_size]);
// CHECK(0 == setvbuf(&*_file, _buffer.get(), _IOFBF, _buffer_size));
//}
return 0;
}
inline void flush() { return; }
inline int close() {
flush();
_file.reset();
return 0;
}
inline uint32_t write_line(const char* data, uint32_t size) {
size_t write_count = fwrite_unlocked(data, 1, size, _file.get());
if (write_count != size) {
return -1;
}
write_count = fwrite_unlocked("\n", 1, 1, _file.get());
if (write_count != 1) {
return -1;
}
return 0;
}
inline uint32_t write_line(const std::string& data) {
return write_line(data.c_str(), data.size());
}
private:
uint32_t _buffer_size;
FsChannelConfig _config;
std::shared_ptr<FILE> _file;
std::shared_ptr<char> _buffer;
};
class AfsClient {
public:
AfsClient() {}
virtual ~AfsClient() {}
AfsClient(AfsClient&&) = delete;
AfsClient(const AfsClient&) = delete;
int initialize(const FsClientParameter& fs_client_param);
int initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& user, const std::string& passwd,
int buffer_size_param = (1L << 25));
int initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& ugi, int buffer_size_param = (1L << 25));
// open file in 'w' or 'r'
std::shared_ptr<FsReadChannel> open_r(const FsChannelConfig& config,
uint32_t buffer_size = 0,
int* err_no = nullptr);
std::shared_ptr<FsWriteChannel> open_w(const FsChannelConfig& config,
uint32_t buffer_size = 0,
int* err_no = nullptr);
// remove file(s) at path; path may be a wildcard pattern such as 'part-000-*'
void remove(const std::string& path);
void remove_dir(const std::string& dir);
// list files under path; path may be a directory with a wildcard pattern
std::vector<std::string> list(const std::string& path);
// exist or not
bool exist(const std::string& dir);
};
} // namespace distributed
} // namespace paddle
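For orientation, here is a minimal usage sketch of the AfsClient declared above; the hadoop binary, URI, UGI, and path values are hypothetical placeholders, and error handling is elided:

#include "paddle/fluid/distributed/common/afs_warpper.h"

void demo_afs_io() {
  paddle::distributed::AfsClient client;
  // initialize(hadoop_bin, uri, ugi); all values below are assumptions
  client.initialize("hadoop", "hdfs://nameservice0", "demo_user,demo_passwd");

  paddle::distributed::FsChannelConfig config;
  config.path = "/user/demo/part-000";  // hypothetical path
  config.converter = "";                // no (de)converter
  config.deconverter = "";

  int err_no = 0;
  auto w = client.open_w(config, 0, &err_no);  // write channel
  w->write_line("hello");
  w->close();

  auto r = client.open_r(config, 0, &err_no);  // read channel
  std::string line;
  if (r->read_line(line) == 0) {
    // line now holds "hello"
  }
  r->close();
}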
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "butil/time.h"
#include "bvar/latency_recorder.h"
#include "glog/logging.h"
namespace paddle {
namespace distributed {
struct CostProfilerNode {
std::shared_ptr<bvar::LatencyRecorder> recorder;
};
class CostProfiler {
public:
~CostProfiler() {}
static CostProfiler& instance() {
static CostProfiler profiler;
return profiler;
}
void register_profiler(const std::string& label) {
if (_cost_profiler_map.find(label) != _cost_profiler_map.end()) {
return;
}
auto profiler_node = std::make_shared<CostProfilerNode>();
profiler_node->recorder.reset(
new bvar::LatencyRecorder("cost_profiler", label));
_cost_profiler_map[label] = profiler_node;
}
CostProfilerNode* profiler(const std::string& label) {
auto itr = _cost_profiler_map.find(label);
if (itr != _cost_profiler_map.end()) {
return itr->second.get();
}
return NULL;
}
private:
CostProfiler() {}
std::unordered_map<std::string, std::shared_ptr<CostProfilerNode>>
_cost_profiler_map;
};
class CostTimer {
public:
explicit CostTimer(const std::string& label) {
_label = label;
auto& profiler = CostProfiler::instance();
_profiler_node = profiler.profiler(label);
// if the label is not registered in the profiler, log the elapsed time instead
_is_print_cost = _profiler_node == NULL;
_start_time_ms = butil::gettimeofday_ms();
}
explicit CostTimer(CostProfilerNode& profiler_node) { // NOLINT
_is_print_cost = false;
_profiler_node = &profiler_node;
_start_time_ms = butil::gettimeofday_ms();
}
~CostTimer() {
if (_is_print_cost) {
LOG(INFO) << "CostTimer label:" << _label
<< ", cost:" << butil::gettimeofday_ms() - _start_time_ms
<< "ms";
} else {
*(_profiler_node->recorder) << butil::gettimeofday_ms() - _start_time_ms;
}
}
private:
std::string _label;
bool _is_print_cost;
uint64_t _start_time_ms;
CostProfilerNode* _profiler_node;
};
} // namespace distributed
} // namespace paddle
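A brief usage sketch of the classes above: register a label once, then time scopes with CostTimer; timings for unregistered labels fall back to LOG(INFO). The label names here are hypothetical:

#include "paddle/fluid/distributed/common/cost_timer.h"

void demo_cost_timer() {
  auto& profiler = paddle::distributed::CostProfiler::instance();
  profiler.register_profiler("pull_sparse_local");  // hypothetical label

  {
    // recorded into the bvar::LatencyRecorder registered for this label
    paddle::distributed::CostTimer timer("pull_sparse_local");
    // ... timed work ...
  }
  {
    // unregistered label: elapsed time is printed via LOG(INFO) instead
    paddle::distributed::CostTimer timer("unregistered_label");
    // ... timed work ...
  }
}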
......@@ -52,6 +52,20 @@ inline void ADD(int n, const T* x, const T y, T* z) {
}
}
template <typename T>
inline void DIV(int n, const T x, const T* y, T* z) {
for (int i = 0; i < n; ++i) {
z[i] = x / y[i];
}
}
template <typename T>
inline void ELE_MUL(int n, const T* x, const T* y, T* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
static bool StartWith(const std::string& str, const std::string& substr) {
return str.find(substr) == 0;
}
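Note the convention in the new helpers: DIV divides a scalar by each vector element, while ELE_MUL is an elementwise product. A small illustration (values chosen arbitrarily):

void demo_helpers() {
  float y[3] = {1.f, 2.f, 4.f};
  float z[3];
  DIV<float>(3, 1.f, y, z);    // z = {1.0, 0.5, 0.25}: scalar divided by vector
  float x[3] = {2.f, 3.f, 4.f};
  ELE_MUL<float>(3, x, y, z);  // z = {2.0, 6.0, 16.0}: elementwise product
}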
......@@ -91,5 +105,6 @@ inline double GetCurrentUS() {
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
} // namespace distributed
} // namespace paddle
......@@ -144,8 +144,8 @@ class PSEnvironment {
virtual std::vector<uint64_t> get_client_info() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_sign_set) {
client_info.push_back(i);
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_uint64());
}
return client_info;
}
......@@ -250,7 +250,7 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_clients(std::vector<std::string> *host_sign_list,
virtual int32_t set_ps_clients(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
......@@ -265,6 +265,7 @@ class PaddlePSEnvironment : public PSEnvironment {
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
VLOG(1) << "env.set_ps_clients done\n";
return 0;
}
......
......@@ -20,11 +20,13 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {
......@@ -35,7 +37,7 @@ using paddle::distributed::PsResponseMessage;
typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure {
public:
PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
explicit PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
virtual ~PSClientClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
......@@ -43,12 +45,17 @@ class PSClientClosure : public google::protobuf::Closure {
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) { // NOLINT
_promises.push_back(promise);
}
void add_timer(std::shared_ptr<CostTimer> &timer) { // NOLINT
_timers.push_back(timer);
}
protected:
PSClientCallBack _callback;
std::vector<std::shared_ptr<CostTimer>> _timers;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
......@@ -59,11 +66,11 @@ class PSClient {
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t configure(
virtual int32_t configure( // NOLINT
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, size_t client_id) final;
PSEnvironment &_env, size_t client_id) final; // NOLINT
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms,
......@@ -86,7 +93,7 @@ class PSClient {
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
//clear table data
// clear table data
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
......@@ -98,7 +105,7 @@ class PSClient {
// the server extracts and returns the configured dimension of the parameter regions
// the returned data is unpacked and written into the accumulated buffers
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; //reserved
size_t table_id) = 0; // reserved
// firstly push dense param for parameter server
// this is necessary because dense weight initialized in trainer on cold
......@@ -107,6 +114,9 @@ class PSClient {
size_t region_num,
size_t table_id) = 0;
// virtual std::future<int32_t> push_dense(const Region *regions,
// size_t region_num,
// size_t table_id) = 0;
// issue a pull request with keys; the results are filled into values
// keys and values both hold num entries; each value occupies select_size bytes
// the keys and values buffers must not be reused before the future completes
......@@ -212,6 +222,10 @@ class PSClient {
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
// virtual std::future<int32_t> push_sparse(size_t table_id,
// const uint64_t *keys,
// const float **update_values,
// size_t num) = 0;
protected:
virtual int32_t initialize() = 0;
......@@ -222,8 +236,42 @@ class PSClient {
PSEnvironment *_env;
std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; //handle client2client messages
_msg_handler_map; // handle client2client messages
};
template <class T>
class AsyncRequestTask {
public:
AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
AsyncRequestTask(T &data, size_t table_id, std::shared_ptr<CostTimer> &timer)
: _table_id(table_id),
_timer(timer),
_promise(std::make_shared<std::promise<int32_t>>()) {
_data = std::move(data);
}
AsyncRequestTask(AsyncRequestTask &data) // NOLINT
: _table_id(data.table_id()),
_timer(data.timer()),
_promise(data.promise()) {
_data = std::move(data.data());
}
~AsyncRequestTask() {}
inline T &data() { return _data; }
inline size_t table_id() { return _table_id; }
inline std::shared_ptr<CostTimer> &timer() { return _timer; }
inline std::future<int32_t> get_future() { return _promise->get_future(); }
inline std::shared_ptr<std::promise<int32_t>> &promise() { return _promise; }
private:
T _data;
size_t _table_id;
std::shared_ptr<CostTimer> _timer;
std::shared_ptr<std::promise<int32_t>> _promise;
};
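A sketch of how AsyncRequestTask might be used to carry a payload together with its completion future; the payload type, table id, and timer label are assumptions:

void demo_async_task() {
  std::vector<uint64_t> keys = {1, 2, 3};
  auto timer = std::make_shared<paddle::distributed::CostTimer>("push_task");
  paddle::distributed::AsyncRequestTask<std::vector<uint64_t>> task(
      keys, /*table_id=*/0, timer);  // keys is moved into the task

  std::future<int32_t> fut = task.get_future();
  // a worker thread would eventually complete the request:
  task.promise()->set_value(0);
  int32_t status = fut.get();  // 0 on success
  (void)status;
}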
REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory {
......
......@@ -17,15 +17,12 @@
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace paddle {
namespace distributed {
struct FsDataConverter {
std::string converter;
std::string deconverter;
};
struct Region {
Region() : data(NULL), size(0) {}
......@@ -50,8 +47,8 @@ struct DataConverter {
class ValueAccessor {
public:
explicit ValueAccessor(){};
virtual ~ValueAccessor(){};
ValueAccessor() {}
virtual ~ValueAccessor() {}
virtual int configure(const TableAccessorParameter& parameter) {
_config = parameter;
......
......@@ -183,5 +183,159 @@ class DAdam : public DenseOptimizer {
float epsilon;
};
// adam optimizer for dense tensor
class DAdamD2Sum : public DenseOptimizer {
public:
explicit DAdamD2Sum(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {
lr_hardcode = 5e-6;
auto& names = accessor.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "LearningRate") {
learning_rate = (*values)[x].data();
}
if (names[x] == "Param") {
param = (*values)[x].data();
}
if (names[x] == "Moment") {
mom_velocity = (*values)[x].data();
}
if (names[x] == "G2Sum") {
ada_g2sum = (*values)[x].data();
}
if (names[x] == "D2Sum") {
ada_d2sum = (*values)[x].data();
}
if (names[x] == "MomentDecayRate") {
mom_decay_rate = (*values)[x].data();
}
if (names[x] == "AdaDecayRate") {
ada_decay_rate = (*values)[x].data();
}
if (names[x] == "AdaEpsilon") {
ada_epsilon = (*values)[x].data();
}
}
}
void update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
/*
// for debug
std::cout << "before update:\n";
for (int i = 0; i < 3; ++ i) {
std::cout << "param: " << i << " " << *(param+begin+i) <<
"grad: " << *(update_values+begin+i) << "\n";
}*/
std::vector<float> grad, grad2, scale;
grad.resize(update_numel);
grad2.resize(update_numel);
scale.resize(update_numel);
auto blas = GetBlas<float>();
// copy grad
blas.VCOPY(update_numel, update_values + begin, grad.data());
blas.VCOPY(update_numel, update_values + begin, grad2.data());
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "copy grad: " << i << " " << *(grad.data()+begin+i) <<
"copy grad2: " << *(grad2.data()+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "d2sum before: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}*/
// d2sum
blas.SCAL(update_numel, ada_decay_rate[0], ada_d2sum + begin);
ADD<float>(update_numel, ada_d2sum + begin, 1, ada_d2sum + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "d2sum update: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "g2sum before: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}*/
// g2sum
blas.SCAL(update_numel, ada_decay_rate[0], ada_g2sum + begin);
blas.VSQUARE(update_numel, grad2.data(), grad2.data());
blas.VADD(update_numel, ada_g2sum + begin, grad2.data(), ada_g2sum + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "g2sum update: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "mom before: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}*/
// mom
blas.SCAL(update_numel, mom_decay_rate[0], mom_velocity + begin);
blas.SCAL(update_numel, 1 - mom_decay_rate[0], grad.data());
blas.VADD(update_numel, mom_velocity + begin, grad.data(),
mom_velocity + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "mom update: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "scale before: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
// scale
float* scale_ = scale.data();
blas.VDIV(update_numel, ada_g2sum + begin, ada_d2sum + begin, scale_);
ADD<float>(update_numel, scale_, ada_epsilon[0], scale_);
DIV<float>(update_numel, 1 + ada_epsilon[0], scale_, scale_);
SQRT<float>(update_numel, scale_, scale_);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas.SCAL(update_numel, learning_rate[0], scale_);
// TODO(zhaocaibei123): check if there exists elementwise_multiply in blas
// TODO(zhaocaibei123): blas.VMUL
ELE_MUL<float>(update_numel, scale_, mom_velocity + begin, scale_);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update2: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas.VSUB(update_numel, param + begin, scale_, param + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "param update " << i << " " << *(param+begin+i) << "\n";
}*/
}
float* learning_rate;
float lr_hardcode;
float* param;
float* mom_velocity;
float* ada_g2sum;
float* ada_d2sum;
float* mom_decay_rate;
float* ada_decay_rate;
float* ada_epsilon;
};
} // namespace distributed
} // namespace paddle
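For reference, the update implemented by DAdamD2Sum::update above can be written compactly, with d = ada_d2sum, n = ada_g2sum, m = mom_velocity, w = param, g the incoming gradient, \lambda_{ada} = ada_decay_rate, \lambda_{mom} = mom_decay_rate, \varepsilon = ada_epsilon, \eta = learning_rate, and \odot the elementwise product:

\begin{aligned}
d_t &= \lambda_{ada}\, d_{t-1} + 1 \\
n_t &= \lambda_{ada}\, n_{t-1} + g_t \odot g_t \\
m_t &= \lambda_{mom}\, m_{t-1} + (1 - \lambda_{mom})\, g_t \\
s_t &= \eta \sqrt{\frac{1 + \varepsilon}{n_t / d_t + \varepsilon}} \\
w_t &= w_{t-1} - s_t \odot m_t
\end{aligned}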
......@@ -173,6 +173,68 @@ message TensorParallelConfig {
optional int32 tensor_init_seed = 2 [ default = -1 ];
}
enum TableType {
PS_SPARSE_TABLE = 0;
PS_DENSE_TABLE = 1;
}
message TableParameter {
optional uint64 table_id = 1;
optional string table_class = 2;
optional uint64 shard_num = 3;
optional TableType type = 4;
optional TableAccessorParameter accessor = 5;
}
message TableAccessorParameter {
optional string accessor_class = 1;
optional SGDParameter embed_sgd_param = 2;
optional SGDParameter embedx_sgd_param = 3;
optional uint32 fea_dim = 4; // for sparse table, this means field size of one
// value; for dense table, this means total value
// num
optional uint32 embedx_dim = 5; // embedx feature size
optional uint32 embedx_threshold = 6; // embedx feature create threshold
optional CtrAccessorParameter ctr_accessor_param = 7;
}
// TODO(guanqun): add NaiveSGD/Adam...
message SGDParameter {
optional string name = 1;
optional SGDRuleParameter adagrad = 2;
}
message SGDRuleParameter {
optional double learning_rate = 1;
optional double initial_g2sum = 2;
optional double initial_range = 3 [ default = 0 ];
repeated float weight_bounds = 4;
}
message CtrAccessorParameter {
optional float nonclk_coeff = 1; // to calculate show_click_score
optional float click_coeff = 2; // to calculate show_click_score
optional float base_threshold =
3; // show_click_score > base_threshold, this feature can be saved
optional float delta_threshold =
4; // delta_score > delta_threshold, this feature can be saved
optional float delta_keep_days =
5; // unseen_day < delta_keep_days, this feature can be saved
optional float show_click_decay_rate = 6; // show/click will update to
// show/click *
// show_click_decay_rate after a day
optional float delete_threshold = 7; // threshold to shrink a feasign
optional float delete_after_unseen_days = 8;
optional int32 ssd_unseenday_threshold = 9;
}
message FsClientParameter {
optional string uri = 1;
optional string user = 2;
optional string passwd = 3;
optional string hadoop_bin = 4;
}
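A sketch of how these fields might be wired into the AfsClient shown earlier. Note this is an assumption-laden illustration with placeholder values: afs_warpper.h actually consumes the FsClientParameter generated from ps.proto, which may carry additional fields (e.g. buffer_size) beyond the four declared above:

void demo_fs_client_config() {
  paddle::distributed::FsClientParameter fs_param;
  fs_param.set_uri("hdfs://nameservice0");  // hypothetical values throughout
  fs_param.set_user("demo_user");
  fs_param.set_passwd("demo_passwd");
  fs_param.set_hadoop_bin("hadoop");

  paddle::distributed::AfsClient client;
  client.initialize(fs_param);  // expands into the hdfs command built in afs_warpper.cc
}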
message DistributedStrategy {
// bool options
optional Mode mode = 1 [ default = COLLECTIVE ];
......@@ -210,6 +272,7 @@ message DistributedStrategy {
optional bool asp = 33 [ default = false ];
optional bool fuse_grad_merge = 34 [ default = false ];
optional bool semi_auto = 35 [ default = false ];
optional bool adam_d2sum = 36 [ default = true ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......@@ -225,6 +288,9 @@ message DistributedStrategy {
optional HybridConfig hybrid_configs = 112;
optional TensorParallelConfig tensor_parallel_configs = 113;
optional TrainerDescConfig trainer_desc_configs = 114;
optional TableParameter downpour_table_param = 115;
optional FsClientParameter fs_client_param = 116;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
optional GradientScaleConfig gradient_scale_configs = 203;
......