Unverified commit 202bfab1, authored by tangwei12, committed by GitHub

Feature/large scale kv save base/delta (#27470)

* add size method for large scale

* add large scale UT

* add ut for checkpoint
Parent: aa3b4ed7
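In short, this patch threads a save mode from the Python fleet API all the way down to the parameter server's large-scale KV tables: mode 0 is a plain save, mode 1 saves a full base and clears each id's seen-since-last-save flag, and mode 2 saves only the ids touched since the last save (a delta). A minimal usage sketch, assuming `fleet` and `exe` are an already-initialized fleet instance and executor (paths are illustrative):

    # Hedged sketch: fleet/exe setup (init, distributed_optimizer, ...) omitted.
    fleet.save_persistables(exe, "/tmp/model/base", mode=1)   # full base save
    # ... keep training; some ids get updated ...
    fleet.save_persistables(exe, "/tmp/model/delta", mode=2)  # only ids seen since the base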
@@ -56,7 +56,7 @@ endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
-        DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op scale_op)
+        DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op checkpoint_notify_op scale_op)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
@@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
 VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                                const std::string& dirname,
                                                const std::string& varname,
+                                               const int mode,
                                                int64_t time_out) {
   sendrecv::VariableMessage req;
   req.set_varname(varname);
......
@@ -103,7 +103,7 @@ class BRPCClient : public RPCClient {
   VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dirname,
-      const std::string& varname,
+      const std::string& varname, const int mode,
       int64_t time_out = FLAGS_rpc_deadline) override;
   bool Wait() override;
......
@@ -422,6 +422,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
 VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                                const std::string& dirname,
                                                const std::string& varname,
+                                               const int mode,
                                                int64_t time_out) {
   const auto ch = GetChannel(ep);
@@ -435,6 +436,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(varname);
+  req.set_table_name(std::to_string(mode));
   req.set_out_varname(dirname);
   platform::RecordRPCEvent record_event(method);
......
@@ -258,7 +258,7 @@ class GRPCClient : public RPCClient {
   VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dirname,
-      const std::string& varname,
+      const std::string& varname, const int mode,
       int64_t time_out = FLAGS_rpc_deadline) override;
   VarHandlePtr AsyncDistributeNotify(
......
@@ -398,12 +398,13 @@ class RequestCheckpointNotify final : public RequestBase {
     std::string checkpoint_notify = request_->Varname();
     std::string checkpoint_dir = request_->OutVarname();
     int trainer_id = request_->GetTrainerId();
+    std::string table_name = request_->TableName();
     VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
             << ", dir: " << checkpoint_dir;
     request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
-                             trainer_id, checkpoint_dir);
+                             trainer_id, checkpoint_dir, table_name);
     Finish(reply_, &responder_);
   }
......
@@ -245,6 +245,7 @@ struct VALUE {
   std::vector<std::string> names_;
   int count_;
+  bool seen_after_last_save_;
   int unseen_days_;
   bool is_entry_;
   std::vector<std::vector<float>> values_;
@@ -321,6 +322,7 @@ class ValueBlock {
     auto value = new VALUE(value_names_);
     value->set(values);
+    value->seen_after_last_save_ = true;
     value->count_ = count;
     values_[id] = value;
   }
@@ -589,9 +591,9 @@ class SparseVariable {
     }
   }
-  void Save(const std::string &dirname) {
+  void Save(const std::string &dirname, const int mode = 0) {
     rwlock_->WRLock();
-    VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " begin";
+    VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " begin";
     MkDirRecursively(dirname.c_str());
@@ -600,22 +602,15 @@
       auto filename = string::Sprintf("%s/%s", dirname, value_name);
       filenames.push_back(filename);
     }
-    SaveToSelectedRows(filenames, meta_.value_names);
-    // // save sparse to text
-    // std::vector<std::string> txt_filenames;
-    // for (auto &value_name : meta_.value_names) {
-    //   auto filename = string::Sprintf("%s/%s.txt", dirname, value_name);
-    //   txt_filenames.push_back(filename);
-    // }
-    // SaveToText(txt_filenames, meta_.value_names);
-    VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " done";
+    SaveToSelectedRows(filenames, meta_.value_names, mode);
+    VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " done";
     rwlock_->UNLock();
   }
   void SaveToSelectedRows(const std::vector<std::string> &filenames,
-                          const std::vector<std::string> &valuenames) {
+                          const std::vector<std::string> &valuenames,
+                          const int mode) {
     for (auto &value_name : valuenames) {
       auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(),
                           value_name);
@@ -629,14 +624,34 @@
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
-    int64_t ids_num = 0;
+    std::vector<int64_t> ids;
     for (auto &block : shard_blocks_) {
-      ids_num += block->values_.size();
+      for (auto value : block->values_) {
+        if (mode == 0) {
+          ids.push_back(value.first);
+        } else {
+          bool id_need_save = false;
+          // save all params
+          if (mode == 1) {
+            id_need_save = true;
+          } else {
+            id_need_save = value.second->seen_after_last_save_;
+          }
+          if (id_need_save) {
+            ids.push_back(value.first);
+          }
+          value.second->seen_after_last_save_ = false;
+        }
+      }
     }
+    VLOG(3) << "save " << ids.size() << " feasigns for " << meta_.name
+            << " with mode: " << mode;
     std::vector<std::shared_ptr<framework::Variable>> variables;
     std::vector<float *> tensors;
-    std::vector<int64_t> ids;
     std::vector<int64_t> dims;
     for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
@@ -645,7 +660,7 @@
       auto *slr = var->GetMutable<framework::SelectedRows>();
       auto *src_t = slr->mutable_value();
-      src_t->Resize({ids_num, dim});
+      src_t->Resize({static_cast<int64_t>(ids.size()), dim});
       auto *value = src_t->mutable_data<float>(place);
       dims.push_back(dim);
@@ -653,20 +668,17 @@
       tensors.push_back(value);
     }
-    int64_t offset = 0;
-    for (auto &block : shard_blocks_) {
-      for (auto value : block->values_) {
-        ids.push_back(value.first);
-        std::vector<std::vector<float> *> vss = value.second->get(valuenames);
-        for (int i = 0; i < static_cast<int>(vss.size()); i++) {
-          auto &vs = vss[i];
-          std::memcpy(tensors[i] + offset * dims[i], vs->data(),
-                      sizeof(float) * dims[i]);
-        }
-        offset += 1;
-      }
-    }
+    std::vector<std::vector<std::vector<float> *>> values;
+    Get(ids, valuenames, &values);
+    int64_t offset = 0;
+    for (auto &vss : values) {
+      for (int i = 0; i < static_cast<int>(vss.size()); i++) {
+        auto &vs = vss[i];
+        std::memcpy(tensors[i] + offset * dims[i], vs->data(),
+                    sizeof(float) * dims[i]);
+      }
+      offset += 1;
+    }
     for (auto &var : variables) {
......
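The id filter above is the core of base/delta saving: mode 0 dumps every id and leaves the seen_after_last_save_ flags untouched, mode 1 dumps every id and then clears the flags (a base), and mode 2 dumps only flagged ids before clearing them (a delta). Since ValueBlock marks every newly created entry as seen, fresh ids are guaranteed to land in the next delta. A minimal Python sketch of the same rule, with a hypothetical Record standing in for VALUE:

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class Record:  # hypothetical stand-in for VALUE
        values: List[float] = field(default_factory=list)
        seen_after_last_save: bool = True  # new ids join the next delta

    def select_ids_to_save(table: Dict[int, Record], mode: int) -> List[int]:
        ids = []
        for id_, rec in table.items():
            if mode == 0:  # plain save: everything, flags untouched
                ids.append(id_)
            else:
                if mode == 1 or rec.seen_after_last_save:
                    ids.append(id_)  # base saves all; delta saves flagged ids
                rec.seen_after_last_save = False  # base and delta both reset
        return ids

    table = {i: Record() for i in range(10)}
    assert select_ids_to_save(table, 1) == list(range(10))  # base: all ids
    assert select_ids_to_save(table, 2) == []  # nothing touched since the base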
@@ -274,13 +274,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname,
                                       const int trainer_id,
                                       const std::string &out_var_name,
                                       const std::string &table_name) {
-  VLOG(4) << "receive save var " << varname << " with path " << out_var_name;
+  VLOG(4) << "receive save var " << varname << " with path " << out_var_name
+          << " mode " << table_name;
+  int mode = std::stoi(table_name);
   auto *ins = distributed::LargeScaleKV::GetInstance();
-  ins->Get(varname)->Save(out_var_name);
-  // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
-  // paddle::platform::CPUPlace cpu_place;
-  // checkpoint_op->Run(*scope_, cpu_place);
+  ins->Get(varname)->Save(out_var_name, mode);
   return true;
 }
......
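One wiring detail worth flagging: sendrecv.VariableMessage has no dedicated field for the save mode, so the client smuggles it through the table_name field as a decimal string (req.set_table_name(std::to_string(mode)) in the GRPC client above), and this handler recovers it with std::stoi. The contract, spelled out as a trivial sketch:

    def encode_mode(mode: int) -> str:  # client side: std::to_string(mode)
        return str(mode)

    def decode_mode(table_name: str) -> int:  # server side: std::stoi(table_name)
        return int(table_name)

    assert decode_mode(encode_mode(2)) == 2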
@@ -87,7 +87,8 @@ class RPCClient {
   virtual VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dirname,
-      const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
+      const std::string& varname, const int mode,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;
   virtual VarHandlePtr AsyncDistributeNotify(
       const std::string& ep, const platform::DeviceContext& ctx,
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <stdlib.h>
 #include <unistd.h>
 #include <chrono>  // NOLINT
+#include <memory>
 #include <string>
 #include <thread>  // NOLINT
@@ -26,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
+#include "paddle/fluid/operators/distributed/large_scale_kv.h"
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/distributed/rpc_client.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"
@@ -35,6 +37,7 @@ namespace platform = paddle::platform;
 namespace distributed = paddle::operators::distributed;
 USE_NO_KERNEL_OP(lookup_sparse_table_read);
+USE_NO_KERNEL_OP(checkpoint_notify);
 USE_OP(scale);
 std::unique_ptr<distributed::RPCServer> g_rpc_service;
@@ -122,7 +125,7 @@ void StartServer(const std::string& rpc_name) {
   g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
-  distributed::HeartBeatMonitor::Init(2, true, "w@grad");
+  // distributed::HeartBeatMonitor::Init(1, true, "w@grad");
   g_req_handler->SetRPCServer(g_rpc_service.get());
@@ -232,3 +235,110 @@ TEST(SENDANDRECV, CPU) {
   g_rpc_service.reset(nullptr);
   g_req_handler.reset(nullptr);
 }
+void StartCheckpointServer(const std::string& rpc_name) {
+  framework::ProgramDesc program;
+  framework::Scope scope;
+  platform::CPUPlace place;
+  framework::Executor exe(place);
+  platform::CPUDeviceContext ctx(place);
+  std::vector<distributed::SparseMeta> metas;
+  auto meta = distributed::SparseMeta();
+  meta.name = "embedding.block0";
+  meta.value_names = {"Param"};
+  meta.value_dims = {64};
+  meta.mode = distributed::Mode::training;
+  meta.grad_name = "embedding@Grad";
+  meta.cached_varnames = {"kSparseIds"};
+  meta.initializer_attrs = {"fill_constant&1.0"};
+  meta.entry = "none";
+  metas.push_back(meta);
+  distributed::LargeScaleKV::Init(metas);
+  auto* ins = distributed::LargeScaleKV::GetInstance();
+  ins->Get("embedding.block0")->Init({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>
+      prefetch_var_name_to_prepared;
+  g_req_handler->SetProgram(&program);
+  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
+  g_req_handler->SetDevCtx(&ctx);
+  g_req_handler->SetScope(&scope);
+  g_req_handler->SetExecutor(&exe);
+  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
+  g_req_handler->SetRPCServer(g_rpc_service.get());
+  std::thread server_thread(
+      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
+  server_thread.join();
+}
+TEST(LARGE_SCALE_CHECKPOINT, CPU) {
+  setenv("http_proxy", "", 1);
+  setenv("https_proxy", "", 1);
+  paddle::framework::Scope scope;
+  paddle::platform::CPUPlace place;
+  g_req_handler.reset(new distributed::RequestCheckpointHandler(
+      distributed::DistributedMode::kAsync));
+  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
+  distributed::RPCClient* client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+  PADDLE_ENFORCE_NE(client, nullptr,
+                    platform::errors::InvalidArgument(
+                        "Client Start Fail, Check Your Code & Env"));
+  std::thread server_thread(StartCheckpointServer,
+                            distributed::kRequestCheckpoint);
+  g_rpc_service->WaitServerReady();
+  int port = g_rpc_service->GetSelectedPort();
+  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
+  auto save_path =
+      paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base",
+                              "embedding", "embedding.block0");
+  int mode = 0;
+  client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
+  client->Wait();
+  save_path =
+      paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta",
+                              "embedding", "embedding.block0");
+  mode = 1;
+  client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
+  client->Wait();
+  paddle::framework::AttributeMap attrs;
+  std::vector<std::string> eps = {ep};
+  attrs["endpoints"] = eps;
+  attrs["dirname"] = std::string("/tmp/large_scale_table/delta1");
+  attrs["varname"] = std::string("embedding");
+  attrs["mode"] = 2;
+  std::vector<std::string> slices = {"embedding.block0"};
+  attrs["slice_varnames"] = slices;
+  std::vector<std::string> remotes = {"embedding.block0"};
+  attrs["remote_varnames"] = remotes;
+  auto ops =
+      framework::OpRegistry::CreateOp("checkpoint_notify", {}, {}, attrs, true);
+  ops->Run(scope, place);
+  g_rpc_service->ShutDown();
+  server_thread.join();
+  LOG(INFO) << "begin reset";
+  g_rpc_service.reset(nullptr);
+  g_req_handler.reset(nullptr);
+}
@@ -42,8 +42,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
         Attr<std::vector<std::string>>("endpoints");
     std::string dirname = Attr<std::string>("dirname");
     std::string varname = Attr<std::string>("varname");
-    auto is_slice = Attr<bool>("is_slice");
-    VLOG(1) << "is_slice: " << is_slice;
+    auto mode = Attr<int>("mode");
+    if (mode != 0 && mode != 1 && mode != 2) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "mode expected in [0/1/2], but got %d", mode));
+    }
     std::vector<std::string> slice_varnames =
         Attr<std::vector<std::string>>("slice_varnames");
@@ -58,11 +62,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
       auto save_path =
           string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);
-      rpc_client->AsyncCheckpointNotify(epmap[i], save_path,
-                                        remote_varnames[i]);
+      rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i],
+                                        mode);
       VLOG(3) << "checkpoint notify sending with path: " << save_path
-              << " and var:" << slice_varnames[i] << " to " << epmap[i];
+              << " and var:" << slice_varnames[i] << " to " << epmap[i]
+              << " with mode " << mode;
     }
     PADDLE_ENFORCE_EQ(
         rpc_client->Wait(), true,
@@ -85,9 +90,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
         "slice_varnames", "(string vector) the slice vars need to be saved");
     AddAttr<std::vector<std::string>>(
         "remote_varnames", "(string vector) the slice vars need to be saved");
-    AddAttr<bool>(
-        "is_slice",
-        "is_slice=True means the var has been slice by parameter server");
+    AddAttr<int>("mode", "mode=0/1/2 means nothing/save base/save delta")
+        .SetDefault(0);
     AddComment(R"DOC(
 CheckpointNotify operator
 This operator will send lookup table and its checkpoint directory to listen_and_serve op at
......
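Callers now pass mode where they previously passed is_slice. A hedged Python sketch of appending the op with the new attribute set (the values are illustrative; the attr names match what the op reads above):

    block.append_op(
        type='checkpoint_notify',
        attrs={
            "varname": "embedding",
            "mode": 1,  # 0: plain save, 1: save base, 2: save delta
            "slice_varnames": ["embedding.block0"],
            "remote_varnames": ["embedding.block0"],
            "endpoints": ["127.0.0.1:6174"],
            "dirname": "/tmp/model",
        })

The ParameterServerRuntime change further down in this diff does exactly this with the real slice and endpoint contexts.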
@@ -109,10 +109,15 @@ void BindLargeScaleKV(py::module* m) {
             auto* sparse_variable = self.Get(table_name);
             sparse_variable->Load(dir);
           })
-      .def("save", [](LargeScaleKV& self, const std::string& table_name,
-                      const std::string& dir) {
+      .def("save",
+           [](LargeScaleKV& self, const std::string& table_name,
+              const std::string& dir) {
             auto* sparse_variable = self.Get(table_name);
             sparse_variable->Save(dir);
           })
+      .def("size", [](LargeScaleKV& self, const std::string& table_name) {
+        auto* sparse_variable = self.Get(table_name);
+        return sparse_variable->Size();
+      });
 }
 }  // namespace pybind
......
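The new binding is reachable from Python through the communicator wrapper (extended later in this diff with a size() passthrough); a usage sketch modeled on the unit test at the end of this diff:

    from paddle.fluid.communicator import LargeScaleKV

    kv = LargeScaleKV()
    kv.save("__emb__.block0", "/tmp/model/__emb__/__emb__.block0")  # dump one shard
    num_ids = kv.size("__emb__.block0")  # number of ids (feasigns) in the table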
@@ -507,7 +507,7 @@ class Fleet(object):
             executor, dirname, feeded_var_names, target_vars, main_program,
             export_for_deployment)
-    def save_persistables(self, executor, dirname, main_program=None):
+    def save_persistables(self, executor, dirname, main_program=None, mode=1):
         """
         saves all persistable variables from :code:`main_program` to
@@ -548,7 +548,8 @@
         """
-        self._runtime_handle._save_persistables(executor, dirname, main_program)
+        self._runtime_handle._save_persistables(executor, dirname, main_program,
+                                                mode)
     def distributed_optimizer(self, optimizer, strategy=None):
         """
......
@@ -521,8 +521,7 @@ class ParameterServerRuntime(RuntimeBase):
         executor.run(prog)
         return context.keys()
-    def _save_distributed_params(self, executor, dirname, context,
-                                 main_program):
+    def _save_distributed_params(self, executor, dirname, context, mode):
         prog = Program()
         block = prog.global_block()
@@ -531,7 +530,7 @@
             type='checkpoint_notify',
             attrs={
                 "varname": name,
-                "is_slice": True,
+                "mode": mode,
                 "slice_varnames": var_ctx.split_varnames(),
                 "remote_varnames": var_ctx.split_varnames(),
                 "endpoints": var_ctx.split_endpoints(),
@@ -541,7 +540,8 @@
         executor.run(prog)
         return context.keys()
-    def _save_distributed_persistables(self, executor, dirname, main_program):
+    def _save_distributed_persistables(self, executor, dirname, main_program,
+                                       mode):
         dense_ctx = self.compiled_strategy.get_communicator_recv_context(
             recv_type=1, use_origin_program=True)
@@ -558,7 +558,7 @@
             executor, dirname, sparse_ctx, main_program)
         recv_distributed_varnames = self._save_distributed_params(
-            executor, dirname, distributed_ctx, main_program)
+            executor, dirname, distributed_ctx, mode)
         saved_varnames = recv_dense_varnames + list(
             recv_sparse_varnames) + list(recv_distributed_varnames)
@@ -578,6 +578,7 @@
                             executor,
                             dirname,
                             main_program=None,
+                            mode=0,
                             **kwargs):
         """
         This function filters out all variables with `persistable==True` from the
@@ -608,7 +609,8 @@
             "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
         )
-        self._save_distributed_persistables(executor, dirname, main_program)
+        self._save_distributed_persistables(executor, dirname, main_program,
+                                            mode)
     def _ps_inference_save_inference_model(self,
                                            executor,
@@ -654,7 +656,8 @@
             program = Program.parse_from_string(program_desc_str)
             program._copy_dist_param_info_from(fluid.default_main_program())
-            self._ps_inference_save_persistables(executor, dirname, program)
+            self._ps_inference_save_persistables(
+                executor, dirname, program, mode=0)
     def _save_inference_model(self, *args, **kwargs):
         self._ps_inference_save_inference_model(*args, **kwargs)
......
@@ -162,3 +162,6 @@ class LargeScaleKV(object):
     def load(self, varname, dirname):
         self.scale_kv.load(varname, dirname)
+    def size(self, varname):
+        return self.scale_kv.size(varname)
@@ -183,8 +183,12 @@ class TestPSPassWithBow(unittest.TestCase):
+            from paddle.fluid.communicator import LargeScaleKV
+            kv = LargeScaleKV()
+            kv.save("__emb__.block0",
+                    os.path.join(model_dir, "__emb__", "__emb__.block0"))
+            kv.size("__emb__.block0")
             fluid.framework.switch_main_program(fluid.Program())
             fleet.init_server(model_dir)
             shutil.rmtree(model_dir)
......