未验证 提交 cca57c4a 编写于 作者: Z zhaocaibei123 提交者: GitHub

Ssd sparse table (#41812)

* [cherry-pick2.3]fix compile bug of windows cuda11.5 (#41464)

cherry-pick

fix compile bug of windows cuda11.5 #41433

* fix bug of missing boost when compile cache.cc (#41449)

【chery-pick #41430】fix bug of random compile failure, due to incorrect compile order of dependencies

* Fix eager try catch (#41438) (#41477)

[Cherry-Pick]Fix eager try catch (#41438)

* Cherry-pick-PR41407, fix device_id bug for final_state op in multiprocess testcase (#41407) (#41475)

Cherry-pick PR #41407

* [BugFix] Add error hint for one_hot gpu version (#41335) (#41495)

* add one_hot gpu hint

* move allow_out_of_range judgement

* delete useless unittest

* fix bugs of reshape double grad infermeta (#41459) (#41493)

* [cherrypick-2.3] modify infer gpu memory strategy (#41427), remove cudnn_deterministic=True (#41341)  (#41491)
Co-authored-by: NJingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>

* [Cherry-pick][ROCm] fix dcu error in device event base, test=develop (#41523)

Cherry-pick of #41521

* [Cherry-Pick]Cherry pick PR41200, PR41474, PR41382 (#41509)

* Use `self`as a parameter of _hash_with_id function to avoid error caused by hash_id reuse (#41200)

* Add fill_constant_batch_size YAML and UT (#41474)

* Switch some dy2st UT to eager mode (#41382)

* Sitch some dy2st UT to eager mode

* Fix test_lstm and remove test_transformer

* Run test_resnet_v2 in old dy mode

* Unittest recover (#41431)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix

* remove some interface

* fix

* remove

* code style

* recover

* fix

* remove code unused

* remove some unused table & accessor & CommonDenseTable => MemoryDenseTable

* fix

* fix

* fix

* recover

* remove unused code

* recover unittest

* fix

* remove

* fix

* remove code unuseful

* remove

* fix

* recover

* remove
Co-authored-by: Nesythan <esythan@126.com>

* add ssd sparse table

* fix

* add cache shuffle

* fix

* fix

* fix

* fix

* fix

* fix

* add unit test

* fix
Co-authored-by: zhouweiwei2014's avatarZhou Wei <1183042833@qq.com>
Co-authored-by: NSing_chan <51314274+betterpig@users.noreply.github.com>
Co-authored-by: N0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: Npangyoki <pangyoki@126.com>
Co-authored-by: NSiming Dai <908660116@qq.com>
Co-authored-by: NYuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: NZhang Jun <ewalker@live.cn>
Co-authored-by: NJingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: NQi Li <qili93@qq.com>
Co-authored-by: Nesythan <esythan@126.com>
上级 4fd190d5
......@@ -357,10 +357,8 @@ if (WITH_PSCORE)
include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)
if (WITH_HETERPS)
include(external/rocksdb) # download, build, install libmct
list(APPEND third_party_deps extern_rocksdb)
endif()
include(external/rocksdb) # download, build, install libmct
list(APPEND third_party_deps extern_rocksdb)
endif()
if(WITH_XBYAK)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <unordered_map>
namespace paddle {
namespace distributed {
class TopkCalculator {
public:
TopkCalculator(int shard_num, size_t k)
: _shard_num(shard_num), _total_max_size(k) {
_shard_max_size = _total_max_size / shard_num;
_shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1;
for (int i = 0; i < shard_num; ++i) {
_mpq.emplace(i, std::priority_queue<double, std::vector<double>,
std::greater<double>>());
}
}
~TopkCalculator() {}
bool push(int shard_id, double value) {
if (_mpq.find(shard_id) == _mpq.end()) {
return false;
}
auto &pq = _mpq[shard_id];
if (pq.size() < _shard_max_size) {
pq.push(value);
} else {
if (pq.top() < value) {
pq.pop();
pq.push(value);
}
}
return true;
}
// TODO 再进行一次堆排序merge各个shard的结果
int top() {
double total = 0;
for (const auto &item : _mpq) {
auto &pq = item.second;
if (!pq.empty()) {
total += pq.top();
}
}
return total / _shard_num;
}
private:
std::unordered_map<int, std::priority_queue<double, std::vector<double>,
std::greater<double>>>
_mpq;
int _shard_num;
size_t _total_max_size;
size_t _shard_max_size;
};
} // namespace distributed
} // namespace paddle
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
if(WITH_HETERPS)
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb)
else()
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
endif()
brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
......
......@@ -429,6 +429,82 @@ std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold});
}
std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
std::vector<std::string> param;
param.push_back(path);
param.push_back(mode);
param.push_back(cache_threshold);
for (size_t i = 0; i < tables.size(); i++) {
param.push_back(std::to_string(tables[i]));
}
return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}
std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) {
return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode});
}
std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
int cmd_id = PS_GET_CACHE_THRESHOLD;
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
std::string cur_res = closure->get_response(i, cmd_id);
cache_thresholds[i] = std::stod(cur_res);
}
double sum_threshold = 0.0;
int count = 0;
for (auto t : cache_thresholds) {
if (t >= 0) {
sum_threshold += t;
++count;
}
}
if (count == 0) {
cache_threshold = 0;
} else {
cache_threshold = sum_threshold / count;
}
VLOG(1) << "client get cache threshold: " << cache_threshold;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms(10800000);
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::Clear() {
return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
......
......@@ -219,6 +219,20 @@ class BrpcPsClient : public PSClient {
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);
std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) override;
std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold);
std::future<int32_t> SaveCache(uint32_t table_id, const std::string &path,
const std::string &mode) override;
std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) override;
void PrintQueueSize();
void PrintQueueSizeThread();
......
......@@ -28,6 +28,13 @@ class RpcController;
} // namespace protobuf
} // namespace google
DEFINE_int32(pserver_timeout_ms_s2s, 10000,
"pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000,
"pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s, "pooled",
"pserver connection_type[pooled:single]");
namespace paddle {
namespace distributed {
......@@ -93,6 +100,84 @@ uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
return host.rank;
}
int32_t BrpcPsServer::StartS2S() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
options.connection_type = FLAGS_pserver_connection_type_s2s;
options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
options.max_retry = 3;
std::vector<PSHost> pserver_list = _environment->GetPsServers();
_pserver_channels.resize(pserver_list.size());
VLOG(2) << "pserver start s2s server_list size: " << _pserver_channels.size();
std::ostringstream os;
std::string server_ip_port;
for (size_t i = 0; i < pserver_list.size(); ++i) {
server_ip_port.assign(pserver_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(pserver_list[i].port));
_pserver_channels[i].reset(new brpc::Channel());
if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
<< " Failed!";
}
os << server_ip_port << ",";
}
LOG(INFO) << "pserver connect success: " << os.str();
return 0;
}
std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future();
if (to_pserver_id >= _pserver_channels.size()) {
LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
<< _pserver_channels.size();
promise->set_value(-1);
return fut;
}
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done;
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
closure->add_promise(promise);
closure->request(0)->set_cmd_id(101);
closure->request(0)->set_client_id(_rank);
closure->request(0)->set_table_id(0);
closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, int pserver_id,
const std::string &msg) {
if (msg.length() == 0) {
LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
return 0;
}
paddle::framework::BinaryArchive ar;
ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
if (ar.Cursor() == ar.Finish()) {
LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
return 0;
}
std::vector<std::pair<uint64_t, std::string>> data;
while (ar.Cursor() < ar.Finish()) {
data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
}
CHECK(ar.Cursor() == ar.Finish());
this->_shuffled_ins->Write(std::move(data));
return 0;
}
int32_t BrpcPsServer::Port() { return _server.listen_address().port; }
int32_t BrpcPsService::Initialize() {
......@@ -117,6 +202,14 @@ int32_t BrpcPsService::Initialize() {
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
// for save cache
_service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
&BrpcPsService::SaveCacheTable;
_service_handler_map[PS_GET_CACHE_THRESHOLD] =
&BrpcPsService::GetCacheThreshold;
_service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense");
......@@ -168,19 +261,29 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
response->set_err_msg("");
auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
if (request->cmd_id() < 100) {
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
} else {
int service_ret = _server->HandlePServer2PServerMsg(
request->cmd_id(), request->client_id(), request->data());
if (service_ret != 0) {
response->set_err_code(-1);
response->set_err_msg("handle_pserver2pserver_msg failed");
}
}
}
......@@ -561,6 +664,90 @@ int32_t BrpcPsService::SaveAllTable(Table *table,
return 0;
}
int32_t BrpcPsService::SaveCacheTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 3, path&mode");
return -1;
}
table->Flush();
int32_t feasign_size = 0;
// if (_server->_shuffled_ins->size() <= 0) {
// LOG(WARNING) << "shuffled ins size <= 0";
//}
feasign_size = table->SaveCache(request.params(0), request.params(1),
_server->_shuffled_ins);
if (feasign_size < 0) {
set_response_code(response, -1, "table save failed");
return -1;
}
return feasign_size;
}
int32_t BrpcPsService::CacheShuffle(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// start cache shuffle
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(response, -1,
"PsRequestMessage.datas is requeired at least 3, "
"path&mode&cache_threshold");
return -1;
}
table->Flush();
double cache_threshold = std::stod(request.params(2));
LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
// auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t,
// std::string>>();
// shuffled_ins->set_block_size(80000);
_server->StartS2S();
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
const std::string &msg)>
send_msg_func = [this](int msg_type, int to_pserver_id,
const std::string &msg) -> std::future<int32_t> {
return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
};
std::vector<Table *> table_ptrs;
for (size_t i = 3; i < request.params_size(); ++i) {
int table_id = std::stoi(request.params(i));
Table *table_ptr = _server->GetTable(table_id);
table_ptrs.push_back(table_ptr);
}
if (table_ptrs.empty()) {
table_ptrs.push_back(table);
}
table->CacheShuffle(request.params(0), request.params(1), cache_threshold,
send_msg_func, _server->_shuffled_ins, table_ptrs);
return 0;
}
int32_t BrpcPsService::GetCacheThreshold(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->Flush();
double cache_threshold = 0.0;
cache_threshold = table->GetCacheThreshold();
if (cache_threshold < 0) {
LOG(WARNING) << "wrong threshold: " << cache_threshold;
}
std::stringstream ss;
ss << std::setprecision(15) << cache_threshold;
std::string cache_threshold_str = ss.str();
response.set_data(cache_threshold_str);
return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......
......@@ -53,6 +53,12 @@ class BrpcPsServer : public PSServer {
}
int32_t Port();
virtual int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
const std::string &msg) override;
private:
virtual int32_t Initialize();
mutable std::mutex mutex_;
......@@ -123,6 +129,16 @@ class BrpcPsService : public PsBaseService {
int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t CacheShuffle(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
......
......@@ -198,6 +198,7 @@ class PSClient {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
......@@ -239,6 +240,46 @@ class PSClient {
const float **update_values,
size_t num) = 0;
// for save cache
virtual std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
protected:
virtual int32_t Initialize() = 0;
size_t _client_id;
......
......@@ -65,6 +65,8 @@ enum PsCmdID {
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
}
message PsRequestMessage {
......
......@@ -67,6 +67,8 @@ int32_t PSServer::Configure(
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param();
......
......@@ -89,6 +89,45 @@ class PSServer {
return &_table_map;
}
// for cache
virtual int32_t StartS2S() { return 0; }
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int RegistePServer2PServerMsgHandler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int HandlePServer2PServerMsg(int msg_type, int from_pserver_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
if (msg_type == 101) {
return ReceiveFromPServer(msg_type, from_pserver_id, msg);
} else {
LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
return -1;
}
}
return itr->second(msg_type, from_pserver_id, msg);
}
virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
const std::string &msg) {
LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
return -1;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected:
virtual int32_t Initialize() = 0;
......@@ -97,6 +136,7 @@ class PSServer {
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected:
std::shared_ptr<framework::Scope> scope_;
......
......@@ -18,17 +18,12 @@ include_directories(${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmc
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(EXTERN_DEP "")
if(WITH_HETERPS)
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
set(EXTERN_DEP rocksdb)
else()
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
endif()
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
#set(EXTERN_DEP rocksdb)
cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS}
${RPC_DEPS} graph_edge graph_node device_context string_helper
simple_threadpool xxhash generator ${EXTERN_DEP})
simple_threadpool xxhash generator)
set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......@@ -41,13 +36,13 @@ set_source_files_properties(ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DI
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
cc_library(sparse_table SRCS memory_sparse_table.cc ssd_sparse_table.cc memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table rocksdb)
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
cc_library(table SRCS table.cc DEPS sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp)
......@@ -117,6 +117,11 @@ class ValueAccessor {
virtual bool Save(float* value, int param) = 0;
// update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) {}
// 判断该value是否保存到ssd
virtual bool SaveSSD(float* value) = 0;
//
virtual bool SaveCache(float* value, int param,
double global_cache_threshold) = 0;
// keys不存在时,为values生成随机值
virtual int32_t Create(float** value, size_t num) = 0;
......
......@@ -38,13 +38,13 @@
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle {
......
......@@ -34,6 +34,8 @@ int CtrCommonAccessor::Initialize() {
common_feature_value.embedx_dim = _config.embedx_dim();
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold();
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
......@@ -77,6 +79,25 @@ bool CtrCommonAccessor::Shrink(float* value) {
return false;
}
bool CtrCommonAccessor::SaveCache(float* value, int param,
double global_cache_threshold) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
return common_feature_value.Show(value) > global_cache_threshold;
}
return false;
}
bool CtrCommonAccessor::SaveSSD(float* value) {
if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) {
return true;
}
return false;
}
bool CtrCommonAccessor::Save(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......
......@@ -148,6 +148,9 @@ class CtrCommonAccessor : public ValueAccessor {
// param = 1, save delta feature
// param = 2, save xbox base feature
bool Save(float* value, int param) override;
bool SaveCache(float* value, int param,
double global_cache_threshold) override;
bool SaveSSD(float* value) override;
// update delta_score and unseen_days after save
void UpdateStatAfterSave(float* value, int param) override;
// keys不存在时,为values生成随机值
......
......@@ -74,25 +74,26 @@ bool CtrDoubleAccessor::Shrink(float* value) {
}
return false;
}
bool CtrDoubleAccessor::SaveSSD(float* value) {
if (CtrDoubleFeatureValue::UnseenDays(value) > _ssd_unseenday_threshold) {
return true;
}
return false;
}
// bool CtrDoubleAccessor::save_cache(
// float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
// CtrDoubleFeatureValue::Click(value)) >= base_threshold
// && CtrDoubleFeatureValue::UnseenDays(value) <=
// delta_keep_days) {
// return CtrDoubleFeatureValue::Show(value) >
// global_cache_threshold;
// }
// return false;
// }
bool CtrDoubleAccessor::SaveCache(float* value, int param,
double global_cache_threshold) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
CtrDoubleFeatureValue::Click(value)) >= base_threshold &&
CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
return CtrDoubleFeatureValue::Show(value) > global_cache_threshold;
}
return false;
}
bool CtrDoubleAccessor::Save(float* value, int param) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......
......@@ -167,6 +167,8 @@ class CtrDoubleAccessor : public ValueAccessor {
// param = 1, save delta feature
// param = 3, save all feature with time decay
virtual bool Save(float* value, int param) override;
bool SaveCache(float* value, int param,
double global_cache_threshold) override;
// update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) override;
// 判断该value是否保存到ssd
......
......@@ -11,9 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h>
#include <rocksdb/db.h>
#include <rocksdb/filter_policy.h>
......@@ -154,6 +153,5 @@ class RocksDBHandler {
std::vector<rocksdb::ColumnFamilyHandle*> _handles;
rocksdb::DB* _db;
};
}
}
#endif
} // distributed
} // paddle
......@@ -23,14 +23,17 @@
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_bool(pserver_print_missed_key_num_every_push, false,
"pserver_print_missed_key_num_every_push");
DEFINE_bool(pserver_create_value_when_push, true,
"pserver create value when push");
DEFINE_bool(pserver_enable_create_feasign_randomly, false,
"pserver_enable_create_feasign_randomly");
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");
namespace paddle {
namespace distributed {
// TODO(zhaocaibei123): configure
bool FLAGS_pserver_create_value_when_push = true;
int FLAGS_pserver_table_save_max_retry = 3;
bool FLAGS_pserver_enable_create_feasign_randomly = false;
int32_t MemorySparseTable::Initialize() {
_shards_task_pool.resize(_task_pool_size);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
......@@ -142,7 +145,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
}
if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
exit(-1);
}
......@@ -213,7 +216,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
<< file_list[file_start_idx + i]
<< " , retry_num=" << retry_num;
}
if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
exit(-1);
}
......@@ -293,7 +296,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
exit(-1);
}
......
......@@ -62,9 +62,11 @@ class MemorySparseTable : public Table {
int32_t InitializeShard() override { return 0; }
int32_t InitializeValue();
int32_t Load(const std::string& path, const std::string& param) override;
virtual int32_t Load(const std::string& path,
const std::string& param) override;
int32_t Save(const std::string& path, const std::string& param) override;
virtual int32_t Save(const std::string& path,
const std::string& param) override;
int32_t LoadLocalFS(const std::string& path, const std::string& param);
int32_t SaveLocalFS(const std::string& path, const std::string& param,
......@@ -83,7 +85,7 @@ class MemorySparseTable : public Table {
int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
int32_t Flush() override;
int32_t Shrink(const std::string& param) override;
virtual int32_t Shrink(const std::string& param) override;
void Clear() override;
void* GetShard(size_t shard_idx) override {
......
......@@ -135,6 +135,11 @@ class SparseAccessor : public ValueAccessor {
// param = 1, save delta feature
// param = 2, save xbox base feature
bool Save(float* value, int param) override;
bool SaveCache(float* value, int param, double global_cache_threshold) {
return false;
}
bool SaveSSD(float* value) { return false; }
// update delta_score and unseen_days after save
void UpdateStatAfterSave(float* value, int param) override;
// keys不存在时,为values生成随机值
......
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
namespace paddle {
namespace distributed {
class SSDSparseTable : public MemorySparseTable {
public:
typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
SSDSparseTable() {}
virtual ~SSDSparseTable() {}
int32_t Initialize() override;
int32_t InitializeShard() override;
// exchange data
int32_t UpdateTable();
int32_t Pull(TableContext& context) override {
CHECK(context.value_type == Sparse);
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_);
}
int32_t Push(TableContext& context) override {
const uint64_t* keys = context.push_context.keys;
const float* values = context.push_context.values;
size_t num = context.num;
return PushSparse(keys, values, num);
}
virtual int32_t PullSparse(float* pull_values, const uint64_t* keys,
size_t num);
virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num);
int32_t Flush() override { return 0; }
virtual int32_t Shrink(const std::string& param) override;
virtual void Clear() override {
for (size_t i = 0; i < _real_local_shard_num; ++i) {
_local_shards[i].clear();
}
}
virtual int32_t Save(const std::string& path,
const std::string& param) override;
virtual int32_t SaveCache(
const std::string& path, const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) override;
virtual double GetCacheThreshold() override { return _local_show_threshold; }
virtual int64_t CacheShuffle(
const std::string& path, const std::string& param, double cache_threshold,
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
std::string& msg)>
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) override;
//加载path目录下数据
virtual int32_t Load(const std::string& path,
const std::string& param) override;
//加载path目录下数据[start_idx, end_idx)
virtual int32_t Load(size_t start_idx, size_t end_idx,
const std::vector<std::string>& file_list,
const std::string& param);
int64_t LocalSize();
private:
RocksDBHandler* _db;
int64_t _cache_tk_size;
double _local_show_threshold{0.0};
};
} // namespace distributed
} // namespace paddle
......@@ -25,6 +25,7 @@
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/tensor_accessor.h"
#include "paddle/fluid/distributed/ps/table/tensor_table.h"
......@@ -37,6 +38,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseTable);
REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor);
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -107,6 +108,26 @@ class Table {
// 指定保存路径
virtual int32_t Save(const std::string &path,
const std::string &converter) = 0;
// for cache
virtual int32_t SaveCache(
const std::string &path, const std::string &param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel) {
return 0;
}
virtual int64_t CacheShuffle(
const std::string &path, const std::string &param, double cache_threshold,
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
std::string &msg)>
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel,
const std::vector<Table *> &table_ptrs) {
return 0;
}
virtual double GetCacheThreshold() { return 0.0; }
virtual int32_t SetShard(size_t shard_idx, size_t shard_num) {
_shard_idx = shard_idx;
......
......@@ -38,6 +38,12 @@ class CommMergeAccessor : public ValueAccessor {
// param作为参数用于标识save阶段,如downpour的xbox与batch_model
virtual bool Save(float * /*value*/, int /*param*/);
bool SaveCache(float *value, int param, double global_cache_threshold) {
return false;
}
bool SaveSSD(float *value) { return false; }
// keys不存在时,为values生成随机值
virtual int32_t Create(float **value, size_t num);
// 从values中选取到select_values中
......
......@@ -754,6 +754,46 @@ std::future<int32_t> FleetWrapper::SendClientToClientMsg(
return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg);
}
double FleetWrapper::GetCacheThreshold(int table_id) {
double cache_threshold = 0.0;
auto ret = worker_ptr_->Flush();
ret.wait();
ret = worker_ptr_->GetCacheThreshold(table_id, cache_threshold);
ret.wait();
if (cache_threshold < 0) {
LOG(ERROR) << "get cache threshold failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return cache_threshold;
}
void FleetWrapper::CacheShuffle(int table_id, const std::string& path,
const int mode, const double cache_threshold) {
auto ret = worker_ptr_->CacheShuffle(table_id, path, std::to_string(mode),
std::to_string(cache_threshold));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "cache shuffle failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
const int mode) {
auto ret = worker_ptr_->SaveCache(table_id, path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
......
......@@ -259,6 +259,11 @@ class FleetWrapper {
// for init worker
void InitGFlag(const std::string& gflags);
double GetCacheThreshold(int table_id);
void CacheShuffle(int table_id, const std::string& path, const int mode,
const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode);
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
......@@ -116,6 +116,10 @@ message TableParameter {
optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ];
optional GraphParameter graph_parameter = 9;
// for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
}
message TableAccessorParameter {
......
......@@ -78,7 +78,11 @@ void BindDistFleetWrapper(py::module* m) {
.def("set_clients", &FleetWrapper::SetClients)
.def("get_client_info", &FleetWrapper::GetClientsInfo)
.def("create_client2client_connection",
&FleetWrapper::CreateClient2ClientConnection);
&FleetWrapper::CreateClient2ClientConnection)
.def("client_flush", &FleetWrapper::ClientFlush)
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache);
}
void BindPSHost(py::module* m) {
......
......@@ -100,6 +100,14 @@ inline int str_to_float(const char* str, float* v) {
return index;
}
inline float* str_to_float(std::string& str) {
return (float*)const_cast<char*>(str.c_str());
}
inline float* str_to_float(const char* str) {
return (float*)const_cast<char*>(str);
}
// checks whether the test string is a suffix of the input string.
bool ends_with(std::string const& input, std::string const& test);
......
......@@ -77,6 +77,7 @@ stop_worker = fleet.stop_worker
distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model
save_persistables = fleet.save_persistables
save_cache_model = fleet.save_cache_model
load_model = fleet.load_model
minimize = fleet.minimize
distributed_model = fleet.distributed_model
......
......@@ -869,6 +869,11 @@ class Fleet(object):
self._runtime_handle._save_persistables(executor, dirname, main_program,
mode)
@is_non_distributed_check
@inited_runtime_handler
def save_cache_model(self, dirname, **configs):
return self._runtime_handle._save_cache_model(dirname, **configs)
def shrink(self, threshold=None):
self._runtime_handle._shrink(threshold)
......
......@@ -1315,6 +1315,30 @@ class TheOnePSRuntime(RuntimeBase):
def _save_persistables(self, *args, **kwargs):
self._ps_inference_save_persistables(*args, **kwargs)
def _save_cache_model(self, dirname, **kwargs):
mode = kwargs.get("mode", 0)
table_id = kwargs.get("table_id", 0)
self._worker.client_flush()
fleet.util.barrier()
cache_threshold = 0.0
if self.role_maker._is_first_worker():
cache_threshold = self._worker.get_cache_threshold(table_id)
#check cache threshold right or not
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._worker.cache_shuffle(table_id, dirname, mode, cache_threshold)
fleet.util.barrier()
feasign_num = -1
if self.role_maker._is_first_worker():
feasign_num = self._worker.save_cache(table_id, dirname, mode)
fleet.util.barrier()
return feasign_num
def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
......
......@@ -339,5 +339,9 @@ class TestDistCTR2x2(FleetDistRunnerBase):
if dirname:
fleet.save_persistables(exe, dirname=dirname)
cache_dirname = os.getenv("SAVE_CACHE_DIRNAME", None)
if cache_dirname:
fleet.save_cache_model(cache_dirname)
if __name__ == "__main__":
runtime_main(TestDistCTR2x2)
......@@ -39,6 +39,8 @@ class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase):
"http_proxy": "",
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"SAVE_CACHE_DIRNAME":
"/tmp/TestDistMnistAsyncInMemoryDataset2x2/cache_model",
"LOG_PREFIX": self.__class__.__name__,
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册