Unverified · commit 202bfab1 · authored by: T tangwei12 · committed by: GitHub

Feature/large scale kv save base/delta (#27470)

* add size method for large scale

* add large scale UT

* add UT for checkpoint
Parent aa3b4ed7
@@ -56,7 +56,7 @@ endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
-        DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op scale_op)
+        DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op checkpoint_notify_op scale_op)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
@@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
 VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                                const std::string& dirname,
                                                const std::string& varname,
+                                               const int mode,
                                                int64_t time_out) {
   sendrecv::VariableMessage req;
   req.set_varname(varname);
......
@@ -103,7 +103,7 @@ class BRPCClient : public RPCClient {
   VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dirname,
-      const std::string& varname,
+      const std::string& varname, const int mode,
       int64_t time_out = FLAGS_rpc_deadline) override;

   bool Wait() override;
......
@@ -422,6 +422,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
 VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                                const std::string& dirname,
                                                const std::string& varname,
+                                               const int mode,
                                                int64_t time_out) {
   const auto ch = GetChannel(ep);
@@ -435,6 +436,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(varname);
+  req.set_table_name(std::to_string(mode));
   req.set_out_varname(dirname);
   platform::RecordRPCEvent record_event(method);
......
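[Note] VariableMessage has no dedicated field for the save mode, so the client smuggles it through the existing table_name string field (req.set_table_name(std::to_string(mode))) and the server parses it back with std::stoi. A minimal Python stand-in of that round trip, with a plain dict playing the role of the protobuf message (names here are illustrative, not the real API):

    # hypothetical stand-in for sendrecv::VariableMessage
    def encode_checkpoint_request(varname, dirname, mode):
        return {"varname": varname, "out_varname": dirname,
                "table_name": str(mode)}  # mode rides in table_name as a string

    def decode_mode(request):
        return int(request["table_name"])  # mirrors std::stoi(table_name)

    assert decode_mode(encode_checkpoint_request("emb.block0", "/tmp/ckpt", 2)) == 2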
@@ -258,7 +258,7 @@ class GRPCClient : public RPCClient {
   VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dirname,
-      const std::string& varname,
+      const std::string& varname, const int mode,
       int64_t time_out = FLAGS_rpc_deadline) override;

   VarHandlePtr AsyncDistributeNotify(
......
@@ -398,12 +398,13 @@ class RequestCheckpointNotify final : public RequestBase {
     std::string checkpoint_notify = request_->Varname();
     std::string checkpoint_dir = request_->OutVarname();
     int trainer_id = request_->GetTrainerId();
+    std::string table_name = request_->TableName();

     VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
             << ", dir: " << checkpoint_dir;

     request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
-                             trainer_id, checkpoint_dir);
+                             trainer_id, checkpoint_dir, table_name);
     Finish(reply_, &responder_);
   }
......
@@ -245,6 +245,7 @@ struct VALUE {
   std::vector<std::string> names_;
   int count_;
+  bool seen_after_last_save_;
   int unseen_days_;
   bool is_entry_;
   std::vector<std::vector<float>> values_;
@@ -321,6 +322,7 @@ class ValueBlock {
     auto value = new VALUE(value_names_);
     value->set(values);
+    value->seen_after_last_save_ = true;
     value->count_ = count;
     values_[id] = value;
   }
@@ -589,9 +591,9 @@ class SparseVariable {
     }
   }

-  void Save(const std::string &dirname) {
+  void Save(const std::string &dirname, const int mode = 0) {
     rwlock_->WRLock();
-    VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " begin";
+    VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " begin";

     MkDirRecursively(dirname.c_str());
@@ -600,22 +602,15 @@ class SparseVariable {
       auto filename = string::Sprintf("%s/%s", dirname, value_name);
       filenames.push_back(filename);
     }
-    SaveToSelectedRows(filenames, meta_.value_names);
-
-    // // save sparse to text
-    // std::vector<std::string> txt_filenames;
-    // for (auto &value_name : meta_.value_names) {
-    //   auto filename = string::Sprintf("%s/%s.txt", dirname, value_name);
-    //   txt_filenames.push_back(filename);
-    // }
-    // SaveToText(txt_filenames, meta_.value_names);
-
-    VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " done";
+    SaveToSelectedRows(filenames, meta_.value_names, mode);
+    VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " done";
     rwlock_->UNLock();
   }

   void SaveToSelectedRows(const std::vector<std::string> &filenames,
-                          const std::vector<std::string> &valuenames) {
+                          const std::vector<std::string> &valuenames,
+                          const int mode) {
     for (auto &value_name : valuenames) {
       auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(),
                           value_name);
@@ -629,14 +624,34 @@ class SparseVariable {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);

-    int64_t ids_num = 0;
+    std::vector<int64_t> ids;

     for (auto &block : shard_blocks_) {
-      ids_num += block->values_.size();
+      for (auto value : block->values_) {
+        if (mode == 0) {
+          ids.push_back(value.first);
+        } else {
+          bool id_need_save = false;
+          // save all params
+          if (mode == 1) {
+            id_need_save = true;
+          } else {
+            id_need_save = value.second->seen_after_last_save_;
+          }
+
+          if (id_need_save) {
+            ids.push_back(value.first);
+          }
+          value.second->seen_after_last_save_ = false;
+        }
+      }
     }

+    VLOG(3) << "save " << ids.size() << " feasigns for " << meta_.name
+            << " with mode: " << mode;
+
     std::vector<std::shared_ptr<framework::Variable>> variables;
     std::vector<float *> tensors;
-    std::vector<int64_t> ids;
     std::vector<int64_t> dims;

     for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
@@ -645,7 +660,7 @@ class SparseVariable {
       auto *slr = var->GetMutable<framework::SelectedRows>();
       auto *src_t = slr->mutable_value();

-      src_t->Resize({ids_num, dim});
+      src_t->Resize({static_cast<int64_t>(ids.size()), dim});
       auto *value = src_t->mutable_data<float>(place);

       dims.push_back(dim);
@@ -653,20 +668,17 @@ class SparseVariable {
       tensors.push_back(value);
     }

-    int64_t offset = 0;
-    for (auto &block : shard_blocks_) {
-      for (auto value : block->values_) {
-        ids.push_back(value.first);
-        std::vector<std::vector<float> *> vss = value.second->get(valuenames);
-        for (int i = 0; i < static_cast<int>(vss.size()); i++) {
-          auto &vs = vss[i];
-          std::memcpy(tensors[i] + offset * dims[i], vs->data(),
-                      sizeof(float) * dims[i]);
-        }
-        offset += 1;
-      }
-    }
+    std::vector<std::vector<std::vector<float> *>> values;
+    Get(ids, valuenames, &values);
+
+    int64_t offset = 0;
+    for (auto &vss : values) {
+      for (int i = 0; i < static_cast<int>(vss.size()); i++) {
+        auto &vs = vss[i];
+        std::memcpy(tensors[i] + offset * dims[i], vs->data(),
+                    sizeof(float) * dims[i]);
+      }
+      offset += 1;
+    }

     for (auto &var : variables) {
......
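[Note] The id-selection loop above is the heart of base/delta saving: mode 0 keeps every id and touches no bookkeeping, mode 1 saves the full base, mode 2 saves only ids seen since the previous save, and both 1 and 2 reset the flag. A minimal Python mirror, for illustration only (`values` maps id to a record with a seen_after_last_save attribute; not the real API):

    def select_ids_to_save(values, mode):
        ids = []
        for id_, value in values.items():
            if mode == 0:
                # plain save: keep every id, leave the flags untouched
                ids.append(id_)
            else:
                # mode 1: full base; otherwise (mode 2): delta only
                need_save = True if mode == 1 else value.seen_after_last_save
                if need_save:
                    ids.append(id_)
                value.seen_after_last_save = False  # modes 1 and 2 reset the flag
        return ids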
@@ -274,13 +274,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname,
                                       const int trainer_id,
                                       const std::string &out_var_name,
                                       const std::string &table_name) {
-  VLOG(4) << "receive save var " << varname << " with path " << out_var_name;
+  VLOG(4) << "receive save var " << varname << " with path " << out_var_name
+          << " mode " << table_name;
+
+  int mode = std::stoi(table_name);

   auto *ins = distributed::LargeScaleKV::GetInstance();
-  ins->Get(varname)->Save(out_var_name);
-  // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
-  // paddle::platform::CPUPlace cpu_place;
-  // checkpoint_op->Run(*scope_, cpu_place);
+  ins->Get(varname)->Save(out_var_name, mode);

   return true;
 }
......
@@ -87,7 +87,8 @@ class RPCClient {
   virtual VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dirname,
-      const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
+      const std::string& varname, const int mode,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;

   virtual VarHandlePtr AsyncDistributeNotify(
       const std::string& ep, const platform::DeviceContext& ctx,
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <stdlib.h>
 #include <unistd.h>
+#include <chrono>  // NOLINT
 #include <memory>
 #include <string>
 #include <thread>  // NOLINT
@@ -26,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
+#include "paddle/fluid/operators/distributed/large_scale_kv.h"
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/distributed/rpc_client.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"
@@ -35,6 +37,7 @@ namespace platform = paddle::platform;
 namespace distributed = paddle::operators::distributed;
 USE_NO_KERNEL_OP(lookup_sparse_table_read);
+USE_NO_KERNEL_OP(checkpoint_notify);
 USE_OP(scale);
 std::unique_ptr<distributed::RPCServer> g_rpc_service;
@@ -122,7 +125,7 @@ void StartServer(const std::string& rpc_name) {
   g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
-  distributed::HeartBeatMonitor::Init(2, true, "w@grad");
+  // distributed::HeartBeatMonitor::Init(1, true, "w@grad");
   g_req_handler->SetRPCServer(g_rpc_service.get());
@@ -232,3 +235,110 @@ TEST(SENDANDRECV, CPU) {
   g_rpc_service.reset(nullptr);
   g_req_handler.reset(nullptr);
 }
+
+void StartCheckpointServer(const std::string& rpc_name) {
+  framework::ProgramDesc program;
+  framework::Scope scope;
+  platform::CPUPlace place;
+  framework::Executor exe(place);
+  platform::CPUDeviceContext ctx(place);
+
+  std::vector<distributed::SparseMeta> metas;
+
+  auto meta = distributed::SparseMeta();
+  meta.name = "embedding.block0";
+  meta.value_names = {"Param"};
+  meta.value_dims = {64};
+  meta.mode = distributed::Mode::training;
+  meta.grad_name = "embedding@Grad";
+  meta.cached_varnames = {"kSparseIds"};
+  meta.initializer_attrs = {"fill_constant&1.0"};
+  meta.entry = "none";
+
+  metas.push_back(meta);
+  distributed::LargeScaleKV::Init(metas);
+
+  auto* ins = distributed::LargeScaleKV::GetInstance();
+  ins->Get("embedding.block0")->Init({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>
+      prefetch_var_name_to_prepared;
+
+  g_req_handler->SetProgram(&program);
+  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
+  g_req_handler->SetDevCtx(&ctx);
+  g_req_handler->SetScope(&scope);
+  g_req_handler->SetExecutor(&exe);
+
+  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
+  g_req_handler->SetRPCServer(g_rpc_service.get());
+
+  std::thread server_thread(
+      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
+
+  server_thread.join();
+}
+
+TEST(LARGE_SCALE_CHECKPOINT, CPU) {
+  setenv("http_proxy", "", 1);
+  setenv("https_proxy", "", 1);
+
+  paddle::framework::Scope scope;
+  paddle::platform::CPUPlace place;
+
+  g_req_handler.reset(new distributed::RequestCheckpointHandler(
+      distributed::DistributedMode::kAsync));
+  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
+
+  distributed::RPCClient* client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+
+  PADDLE_ENFORCE_NE(client, nullptr,
+                    platform::errors::InvalidArgument(
+                        "Client Start Fail, Check Your Code & Env"));
+
+  std::thread server_thread(StartCheckpointServer,
+                            distributed::kRequestCheckpoint);
+  g_rpc_service->WaitServerReady();
+
+  int port = g_rpc_service->GetSelectedPort();
+  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
+
+  auto save_path =
+      paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base",
+                              "embedding", "embedding.block0");
+  int mode = 0;
+  client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
+  client->Wait();
+
+  save_path =
+      paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta",
+                              "embedding", "embedding.block0");
+  mode = 1;
+  client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
+  client->Wait();
+
+  paddle::framework::AttributeMap attrs;
+
+  std::vector<std::string> eps = {ep};
+  attrs["endpoints"] = eps;
+  attrs["dirname"] = std::string("/tmp/large_scale_table/delta1");
+  attrs["varname"] = std::string("embedding");
+  attrs["mode"] = 2;
+
+  std::vector<std::string> slices = {"embedding.block0"};
+  attrs["slice_varnames"] = slices;
+
+  std::vector<std::string> remotes = {"embedding.block0"};
+  attrs["remote_varnames"] = remotes;
+
+  auto ops =
+      framework::OpRegistry::CreateOp("checkpoint_notify", {}, {}, attrs, true);
+  ops->Run(scope, place);
+
+  g_rpc_service->ShutDown();
+  server_thread.join();
+
+  LOG(INFO) << "begin reset";
+  g_rpc_service.reset(nullptr);
+  g_req_handler.reset(nullptr);
+}
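[Note] Since Save() writes one file per value name under the notified path, the UT above should leave a layout like the one below on disk (paths derived from the test's Sprintf calls and value_names = {"Param"}; the check itself is only an illustrative sketch):

    import os

    expected = [
        "/tmp/large_scale_table/base/embedding/embedding.block0/Param",
        "/tmp/large_scale_table/delta/embedding/embedding.block0/Param",
        "/tmp/large_scale_table/delta1/embedding/embedding.block0/Param",
    ]
    for path in expected:
        print(path, os.path.exists(path))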
@@ -42,8 +42,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
         Attr<std::vector<std::string>>("endpoints");
     std::string dirname = Attr<std::string>("dirname");
     std::string varname = Attr<std::string>("varname");
-    auto is_slice = Attr<bool>("is_slice");
-    VLOG(1) << "is_slice: " << is_slice;
+    auto mode = Attr<int>("mode");
+
+    if (mode != 0 && mode != 1 && mode != 2) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "mode expected in [0/1/2], but got %d", mode));
+    }

     std::vector<std::string> slice_varnames =
         Attr<std::vector<std::string>>("slice_varnames");
@@ -58,11 +62,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
       auto save_path =
           string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);

-      rpc_client->AsyncCheckpointNotify(epmap[i], save_path,
-                                        remote_varnames[i]);
+      rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i],
+                                        mode);

       VLOG(3) << "checkpoint notify sending with path: " << save_path
-              << " and var:" << slice_varnames[i] << " to " << epmap[i];
+              << " and var:" << slice_varnames[i] << " to " << epmap[i]
+              << " with mode " << mode;
     }

     PADDLE_ENFORCE_EQ(
         rpc_client->Wait(), true,
@@ -85,9 +90,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
         "slice_varnames", "(string vector) the slice vars need to be saved");
     AddAttr<std::vector<std::string>>(
         "remote_varnames", "(string vector) the slice vars need to be saved");
-    AddAttr<bool>(
-        "is_slice",
-        "is_slice=True means the var has been slice by parameter server");
+    AddAttr<int>("mode", "mode=0/1/2 means nothing/save base/save delta")
+        .SetDefault(0);
     AddComment(R"DOC(
CheckpointNotify operator

This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
......
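[Note] With is_slice replaced by the validated mode attribute, a Python caller assembles the op the same way the runtime code further down does. A hedged sketch (endpoint and paths are placeholders):

    import paddle.fluid as fluid

    prog = fluid.Program()
    prog.global_block().append_op(
        type='checkpoint_notify',
        attrs={
            "varname": "embedding",
            "mode": 2,  # 0: plain save, 1: save base, 2: save delta
            "slice_varnames": ["embedding.block0"],
            "remote_varnames": ["embedding.block0"],
            "endpoints": ["127.0.0.1:6174"],  # placeholder endpoint
            "dirname": "/tmp/large_scale_table/delta",
        })
    # fluid.Executor(fluid.CPUPlace()).run(prog)  # fires one RPC per endpoint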
@@ -109,10 +109,15 @@ void BindLargeScaleKV(py::module* m) {
             auto* sparse_variable = self.Get(table_name);
             sparse_variable->Load(dir);
           })
-      .def("save", [](LargeScaleKV& self, const std::string& table_name,
-                      const std::string& dir) {
+      .def("save",
+           [](LargeScaleKV& self, const std::string& table_name,
+              const std::string& dir) {
+             auto* sparse_variable = self.Get(table_name);
+             sparse_variable->Save(dir);
+           })
+      .def("size", [](LargeScaleKV& self, const std::string& table_name) {
         auto* sparse_variable = self.Get(table_name);
-        sparse_variable->Save(dir);
+        return sparse_variable->Size();
       });
 }
 }  // namespace pybind
......
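[Note] From Python the new size binding pairs naturally with save; the pybind save still calls Save(dir) with its default mode = 0. A short usage sketch, assuming the table has already been initialized in this process (table name and path are illustrative):

    from paddle.fluid.communicator import LargeScaleKV

    kv = LargeScaleKV()
    # persist one shard of the sparse table (mode defaults to 0 here),
    # then query how many ids the shard holds
    kv.save("embedding.block0",
            "/tmp/large_scale_table/embedding/embedding.block0")
    print(kv.size("embedding.block0"))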
@@ -507,7 +507,7 @@ class Fleet(object):
             executor, dirname, feeded_var_names, target_vars, main_program,
             export_for_deployment)

-    def save_persistables(self, executor, dirname, main_program=None):
+    def save_persistables(self, executor, dirname, main_program=None, mode=1):
         """

         saves all persistable variables from :code:`main_program` to
@@ -548,7 +548,8 @@ class Fleet(object):
         """

-        self._runtime_handle._save_persistables(executor, dirname, main_program)
+        self._runtime_handle._save_persistables(executor, dirname, main_program,
+                                                mode)

     def distributed_optimizer(self, optimizer, strategy=None):
         """
......
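[Note] With the new parameter (default mode=1), a base-then-delta checkpoint cycle from user code would look roughly like this; `fleet` and `exe` are assumed to be an already-initialized fleet instance and executor, and the dirnames are placeholders:

    # full base checkpoint: every id in the table
    fleet.save_persistables(executor=exe, dirname="/tmp/model/base", mode=1)

    # ... more training ...

    # delta checkpoint: only ids touched since the last save
    fleet.save_persistables(executor=exe, dirname="/tmp/model/delta", mode=2)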
@@ -521,8 +521,7 @@ class ParameterServerRuntime(RuntimeBase):
         executor.run(prog)
         return context.keys()

-    def _save_distributed_params(self, executor, dirname, context,
-                                 main_program):
+    def _save_distributed_params(self, executor, dirname, context, mode):
         prog = Program()
         block = prog.global_block()
@@ -531,7 +530,7 @@ class ParameterServerRuntime(RuntimeBase):
                 type='checkpoint_notify',
                 attrs={
                     "varname": name,
-                    "is_slice": True,
+                    "mode": mode,
                     "slice_varnames": var_ctx.split_varnames(),
                     "remote_varnames": var_ctx.split_varnames(),
                     "endpoints": var_ctx.split_endpoints(),
@@ -541,7 +540,8 @@ class ParameterServerRuntime(RuntimeBase):
         executor.run(prog)
         return context.keys()

-    def _save_distributed_persistables(self, executor, dirname, main_program):
+    def _save_distributed_persistables(self, executor, dirname, main_program,
+                                       mode):
         dense_ctx = self.compiled_strategy.get_communicator_recv_context(
             recv_type=1, use_origin_program=True)
@@ -558,7 +558,7 @@ class ParameterServerRuntime(RuntimeBase):
             executor, dirname, sparse_ctx, main_program)

         recv_distributed_varnames = self._save_distributed_params(
-            executor, dirname, distributed_ctx, main_program)
+            executor, dirname, distributed_ctx, mode)

         saved_varnames = recv_dense_varnames + list(
             recv_sparse_varnames) + list(recv_distributed_varnames)
@@ -578,6 +578,7 @@ class ParameterServerRuntime(RuntimeBase):
                                executor,
                                dirname,
                                main_program=None,
+                               mode=0,
                                **kwargs):
         """
         This function filters out all variables with `persistable==True` from the
@@ -608,7 +609,8 @@ class ParameterServerRuntime(RuntimeBase):
                 "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
             )

-        self._save_distributed_persistables(executor, dirname, main_program)
+        self._save_distributed_persistables(executor, dirname, main_program,
+                                            mode)

     def _ps_inference_save_inference_model(self,
                                            executor,
@@ -654,7 +656,8 @@ class ParameterServerRuntime(RuntimeBase):
         program = Program.parse_from_string(program_desc_str)
         program._copy_dist_param_info_from(fluid.default_main_program())

-        self._ps_inference_save_persistables(executor, dirname, program)
+        self._ps_inference_save_persistables(
+            executor, dirname, program, mode=0)

     def _save_inference_model(self, *args, **kwargs):
         self._ps_inference_save_inference_model(*args, **kwargs)
......
@@ -162,3 +162,6 @@ class LargeScaleKV(object):

     def load(self, varname, dirname):
         self.scale_kv.load(varname, dirname)
+
+    def size(self, varname):
+        return self.scale_kv.size(varname)
@@ -183,8 +183,12 @@ class TestPSPassWithBow(unittest.TestCase):
         from paddle.fluid.communicator import LargeScaleKV
         kv = LargeScaleKV()
         kv.save("__emb__.block0",
                 os.path.join(model_dir, "__emb__", "__emb__.block0"))
+
+        kv.size("__emb__.block0")
+
         fluid.framework.switch_main_program(fluid.Program())
         fleet.init_server(model_dir)
         shutil.rmtree(model_dir)
......