From 202bfab1be0f3cbaa8f5b7117502a532a332eba0 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 15 Oct 2020 17:13:21 +0800 Subject: [PATCH] Feature/large scale kv save base/delta (#27470) * add size method for large scale * add large scale UT * add ut for checkpoint --- .../operators/distributed/CMakeLists.txt | 2 +- .../operators/distributed/brpc/brpc_client.cc | 1 + .../operators/distributed/brpc/brpc_client.h | 2 +- .../operators/distributed/grpc/grpc_client.cc | 2 + .../operators/distributed/grpc/grpc_client.h | 2 +- .../operators/distributed/grpc/grpc_server.cc | 3 +- .../operators/distributed/large_scale_kv.h | 70 ++++++----- .../distributed/request_handler_impl.cc | 10 +- .../fluid/operators/distributed/rpc_client.h | 3 +- .../operators/distributed/rpc_server_test.cc | 112 +++++++++++++++++- .../distributed_ops/checkpoint_notify_op.cc | 20 ++-- paddle/fluid/pybind/communicator_py.cc | 11 +- .../distributed/fleet/base/fleet_base.py | 5 +- .../fleet/runtime/parameter_server_runtime.py | 17 +-- python/paddle/fluid/communicator.py | 3 + .../tests/unittests/test_dist_fleet_ps2.py | 4 + 16 files changed, 207 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index e584e02508..47fbb42fd6 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -56,7 +56,7 @@ endif() cc_test(rpc_server_test SRCS rpc_server_test.cc - DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op scale_op) + DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op checkpoint_notify_op scale_op ) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc index cb93b8d910..b2a26089c8 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.cc @@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep, VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dirname, const std::string& varname, + const int mode, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(varname); diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h index 2ea90d560f..91f94b4c9d 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.h @@ -103,7 +103,7 @@ class BRPCClient : public RPCClient { VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, + const std::string& varname, const int mode, int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 9fd828bfa5..0320ef6595 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -422,6 +422,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dirname, const std::string& varname, + const int mode, int64_t time_out) { const auto ch = GetChannel(ep); @@ -435,6 +436,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, sendrecv::VariableMessage req; req.set_varname(varname); + req.set_table_name(std::to_string(mode)); req.set_out_varname(dirname); platform::RecordRPCEvent record_event(method); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index 22ca74a67e..7b269f4d80 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -258,7 +258,7 @@ class GRPCClient : public RPCClient { VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, + const std::string& varname, const int mode, int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncDistributeNotify( diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index e4216db186..912520d782 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -398,12 +398,13 @@ class RequestCheckpointNotify final : public RequestBase { std::string checkpoint_notify = request_->Varname(); std::string checkpoint_dir = request_->OutVarname(); int trainer_id = request_->GetTrainerId(); + std::string table_name = request_->TableName(); VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify << ", dir: " << checkpoint_dir; request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr, - trainer_id, checkpoint_dir); + trainer_id, checkpoint_dir, table_name); Finish(reply_, &responder_); } diff --git a/paddle/fluid/operators/distributed/large_scale_kv.h b/paddle/fluid/operators/distributed/large_scale_kv.h index b4388c0002..52b76b7bfe 100644 --- a/paddle/fluid/operators/distributed/large_scale_kv.h +++ b/paddle/fluid/operators/distributed/large_scale_kv.h @@ -245,6 +245,7 @@ struct VALUE { std::vector names_; int count_; + bool seen_after_last_save_; int unseen_days_; bool is_entry_; std::vector> values_; @@ -321,6 +322,7 @@ class ValueBlock { auto value = new VALUE(value_names_); value->set(values); + value->seen_after_last_save_ = true; value->count_ = count; values_[id] = value; } @@ -589,9 +591,9 @@ class SparseVariable { } } - void Save(const std::string &dirname) { + void Save(const std::string &dirname, const int mode = 0) { rwlock_->WRLock(); - VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " begin"; + VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " begin"; MkDirRecursively(dirname.c_str()); @@ -600,22 +602,15 @@ class SparseVariable { auto filename = string::Sprintf("%s/%s", dirname, value_name); filenames.push_back(filename); } - SaveToSelectedRows(filenames, meta_.value_names); - // // save sparse to text - // std::vector txt_filenames; - // for (auto &value_name : meta_.value_names) { - // auto filename = string::Sprintf("%s/%s.txt", dirname, value_name); - // txt_filenames.push_back(filename); - // } - // SaveToText(txt_filenames, meta_.value_names); - - VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " done"; + SaveToSelectedRows(filenames, meta_.value_names, mode); + VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " done"; rwlock_->UNLock(); } void SaveToSelectedRows(const std::vector &filenames, - const std::vector &valuenames) { + const std::vector &valuenames, + const int mode) { for (auto &value_name : valuenames) { auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(), value_name); @@ -629,14 +624,34 @@ class SparseVariable { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - int64_t ids_num = 0; + std::vector ids; + for (auto &block : shard_blocks_) { - ids_num += block->values_.size(); + for (auto value : block->values_) { + if (mode == 0) { + ids.push_back(value.first); + } else { + bool id_need_save = false; + // save all params + if (mode == 1) { + id_need_save = true; + } else { + id_need_save = value.second->seen_after_last_save_; + } + + if (id_need_save) { + ids.push_back(value.first); + } + value.second->seen_after_last_save_ = false; + } + } } + VLOG(3) << "save " << ids.size() << " feasigns for " << meta_.name + << " with mode: " << mode; + std::vector> variables; std::vector tensors; - std::vector ids; std::vector dims; for (int i = 0; i < static_cast(filenames.size()); i++) { @@ -645,7 +660,7 @@ class SparseVariable { auto *slr = var->GetMutable(); auto *src_t = slr->mutable_value(); - src_t->Resize({ids_num, dim}); + src_t->Resize({static_cast(ids.size()), dim}); auto *value = src_t->mutable_data(place); dims.push_back(dim); @@ -653,20 +668,17 @@ class SparseVariable { tensors.push_back(value); } - int64_t offset = 0; - for (auto &block : shard_blocks_) { - for (auto value : block->values_) { - ids.push_back(value.first); - std::vector *> vss = value.second->get(valuenames); - - for (int i = 0; i < static_cast(vss.size()); i++) { - auto &vs = vss[i]; - std::memcpy(tensors[i] + offset * dims[i], vs->data(), - sizeof(float) * dims[i]); - } + std::vector *>> values; + Get(ids, valuenames, &values); - offset += 1; + int64_t offset = 0; + for (auto &vss : values) { + for (int i = 0; i < static_cast(vss.size()); i++) { + auto &vs = vss[i]; + std::memcpy(tensors[i] + offset * dims[i], vs->data(), + sizeof(float) * dims[i]); } + offset += 1; } for (auto &var : variables) { diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 0d67fc0021..8c4f2ef57a 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -274,13 +274,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname, const int trainer_id, const std::string &out_var_name, const std::string &table_name) { - VLOG(4) << "receive save var " << varname << " with path " << out_var_name; + VLOG(4) << "receive save var " << varname << " with path " << out_var_name + << " mode " << table_name; + + int mode = std::stoi(table_name); auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Save(out_var_name); - // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name); - // paddle::platform::CPUPlace cpu_place; - // checkpoint_op->Run(*scope_, cpu_place); + ins->Get(varname)->Save(out_var_name, mode); return true; } diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 6a6a795a46..2c756a6f71 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -87,7 +87,8 @@ class RPCClient { virtual VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0; + const std::string& varname, const int mode, + int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncDistributeNotify( const std::string& ep, const platform::DeviceContext& ctx, diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index b6d4d59485..f592854000 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include // NOLINT #include #include #include // NOLINT @@ -26,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" +#include "paddle/fluid/operators/distributed/large_scale_kv.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_server.h" @@ -35,6 +37,7 @@ namespace platform = paddle::platform; namespace distributed = paddle::operators::distributed; USE_NO_KERNEL_OP(lookup_sparse_table_read); +USE_NO_KERNEL_OP(checkpoint_notify); USE_OP(scale); std::unique_ptr g_rpc_service; @@ -122,7 +125,7 @@ void StartServer(const std::string& rpc_name) { g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); - distributed::HeartBeatMonitor::Init(2, true, "w@grad"); + // distributed::HeartBeatMonitor::Init(1, true, "w@grad"); g_req_handler->SetRPCServer(g_rpc_service.get()); @@ -232,3 +235,110 @@ TEST(SENDANDRECV, CPU) { g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); } + +void StartCheckpointServer(const std::string& rpc_name) { + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + + std::vector metas; + + auto meta = distributed::SparseMeta(); + meta.name = "embedding.block0"; + meta.value_names = {"Param"}; + meta.value_dims = {64}; + meta.mode = distributed::Mode::training; + meta.grad_name = "embedding@Grad"; + meta.cached_varnames = {"kSparseIds"}; + meta.initializer_attrs = {"fill_constant&1.0"}; + meta.entry = "none"; + + metas.push_back(meta); + distributed::LargeScaleKV::Init(metas); + + auto* ins = distributed::LargeScaleKV::GetInstance(); + ins->Get("embedding.block0")->Init({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + std::unordered_map> + prefetch_var_name_to_prepared; + + g_req_handler->SetProgram(&program); + g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); + g_req_handler->SetDevCtx(&ctx); + g_req_handler->SetScope(&scope); + g_req_handler->SetExecutor(&exe); + + g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); + + g_req_handler->SetRPCServer(g_rpc_service.get()); + + std::thread server_thread( + std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); + + server_thread.join(); +} + +TEST(LARGE_SCALE_CHECKPOINT, CPU) { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + + g_req_handler.reset(new distributed::RequestCheckpointHandler( + distributed::DistributedMode::kAsync)); + g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); + + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(0); + + PADDLE_ENFORCE_NE(client, nullptr, + platform::errors::InvalidArgument( + "Client Start Fail, Check Your Code & Env")); + + std::thread server_thread(StartCheckpointServer, + distributed::kRequestCheckpoint); + g_rpc_service->WaitServerReady(); + + int port = g_rpc_service->GetSelectedPort(); + std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); + + auto save_path = + paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base", + "embedding", "embedding.block0"); + int mode = 0; + client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode); + client->Wait(); + + save_path = + paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta", + "embedding", "embedding.block0"); + mode = 1; + client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode); + client->Wait(); + + paddle::framework::AttributeMap attrs; + + std::vector eps = {ep}; + attrs["endpoints"] = eps; + attrs["dirname"] = std::string("/tmp/large_scale_table/delta1"); + attrs["varname"] = std::string("embedding"); + attrs["mode"] = 2; + std::vector slices = {"embedding.block0"}; + attrs["slice_varnames"] = slices; + std::vector remotes = {"embedding.block0"}; + attrs["remote_varnames"] = remotes; + + auto ops = + framework::OpRegistry::CreateOp("checkpoint_notify", {}, {}, attrs, true); + ops->Run(scope, place); + + g_rpc_service->ShutDown(); + server_thread.join(); + LOG(INFO) << "begin reset"; + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); +} diff --git a/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc index abc8d91284..051d9d65c7 100644 --- a/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc @@ -42,8 +42,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { Attr>("endpoints"); std::string dirname = Attr("dirname"); std::string varname = Attr("varname"); - auto is_slice = Attr("is_slice"); - VLOG(1) << "is_slice: " << is_slice; + auto mode = Attr("mode"); + + if (mode != 0 && mode != 1 && mode != 2) { + PADDLE_THROW(platform::errors::InvalidArgument( + "mode expected in [0/1/2], but got %d", mode)); + } std::vector slice_varnames = Attr>("slice_varnames"); @@ -58,11 +62,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { auto save_path = string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]); - rpc_client->AsyncCheckpointNotify(epmap[i], save_path, - remote_varnames[i]); + rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i], + mode); VLOG(3) << "checkpoint notify sending with path: " << save_path - << " and var:" << slice_varnames[i] << " to " << epmap[i]; + << " and var:" << slice_varnames[i] << " to " << epmap[i] + << " with mode " << mode; } PADDLE_ENFORCE_EQ( rpc_client->Wait(), true, @@ -85,9 +90,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { "slice_varnames", "(string vector) the slice vars need to be saved"); AddAttr>( "remote_varnames", "(string vector) the slice vars need to be saved"); - AddAttr( - "is_slice", - "is_slice=True means the var has been slice by parameter server"); + AddAttr("mode", "mode=0/1/2 means nothing/save base/save delta") + .SetDefault(0); AddComment(R"DOC( CheckpointNotify operator This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at diff --git a/paddle/fluid/pybind/communicator_py.cc b/paddle/fluid/pybind/communicator_py.cc index 6ac37a85c2..07ba706167 100644 --- a/paddle/fluid/pybind/communicator_py.cc +++ b/paddle/fluid/pybind/communicator_py.cc @@ -109,10 +109,15 @@ void BindLargeScaleKV(py::module* m) { auto* sparse_variable = self.Get(table_name); sparse_variable->Load(dir); }) - .def("save", [](LargeScaleKV& self, const std::string& table_name, - const std::string& dir) { + .def("save", + [](LargeScaleKV& self, const std::string& table_name, + const std::string& dir) { + auto* sparse_variable = self.Get(table_name); + sparse_variable->Save(dir); + }) + .def("size", [](LargeScaleKV& self, const std::string& table_name) { auto* sparse_variable = self.Get(table_name); - sparse_variable->Save(dir); + return sparse_variable->Size(); }); } } // namespace pybind diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 56a67599a4..c46911da0f 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -507,7 +507,7 @@ class Fleet(object): executor, dirname, feeded_var_names, target_vars, main_program, export_for_deployment) - def save_persistables(self, executor, dirname, main_program=None): + def save_persistables(self, executor, dirname, main_program=None, mode=1): """ saves all persistable variables from :code:`main_program` to @@ -548,7 +548,8 @@ class Fleet(object): """ - self._runtime_handle._save_persistables(executor, dirname, main_program) + self._runtime_handle._save_persistables(executor, dirname, main_program, + mode) def distributed_optimizer(self, optimizer, strategy=None): """ diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index 415e091680..887209d9de 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -521,8 +521,7 @@ class ParameterServerRuntime(RuntimeBase): executor.run(prog) return context.keys() - def _save_distributed_params(self, executor, dirname, context, - main_program): + def _save_distributed_params(self, executor, dirname, context, mode): prog = Program() block = prog.global_block() @@ -531,7 +530,7 @@ class ParameterServerRuntime(RuntimeBase): type='checkpoint_notify', attrs={ "varname": name, - "is_slice": True, + "mode": mode, "slice_varnames": var_ctx.split_varnames(), "remote_varnames": var_ctx.split_varnames(), "endpoints": var_ctx.split_endpoints(), @@ -541,7 +540,8 @@ class ParameterServerRuntime(RuntimeBase): executor.run(prog) return context.keys() - def _save_distributed_persistables(self, executor, dirname, main_program): + def _save_distributed_persistables(self, executor, dirname, main_program, + mode): dense_ctx = self.compiled_strategy.get_communicator_recv_context( recv_type=1, use_origin_program=True) @@ -558,7 +558,7 @@ class ParameterServerRuntime(RuntimeBase): executor, dirname, sparse_ctx, main_program) recv_distributed_varnames = self._save_distributed_params( - executor, dirname, distributed_ctx, main_program) + executor, dirname, distributed_ctx, mode) saved_varnames = recv_dense_varnames + list( recv_sparse_varnames) + list(recv_distributed_varnames) @@ -578,6 +578,7 @@ class ParameterServerRuntime(RuntimeBase): executor, dirname, main_program=None, + mode=0, **kwargs): """ This function filters out all variables with `persistable==True` from the @@ -608,7 +609,8 @@ class ParameterServerRuntime(RuntimeBase): "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed" ) - self._save_distributed_persistables(executor, dirname, main_program) + self._save_distributed_persistables(executor, dirname, main_program, + mode) def _ps_inference_save_inference_model(self, executor, @@ -654,7 +656,8 @@ class ParameterServerRuntime(RuntimeBase): program = Program.parse_from_string(program_desc_str) program._copy_dist_param_info_from(fluid.default_main_program()) - self._ps_inference_save_persistables(executor, dirname, program) + self._ps_inference_save_persistables( + executor, dirname, program, mode=0) def _save_inference_model(self, *args, **kwargs): self._ps_inference_save_inference_model(*args, **kwargs) diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index 814a70a10e..b203e2a80b 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -162,3 +162,6 @@ class LargeScaleKV(object): def load(self, varname, dirname): self.scale_kv.load(varname, dirname) + + def size(self, varname): + return self.scale_kv.size(varname) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py index 218eb77d0b..d9ef1cf50c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py @@ -183,8 +183,12 @@ class TestPSPassWithBow(unittest.TestCase): from paddle.fluid.communicator import LargeScaleKV kv = LargeScaleKV() + kv.save("__emb__.block0", os.path.join(model_dir, "__emb__", "__emb__.block0")) + + kv.size("__emb__.block0") + fluid.framework.switch_main_program(fluid.Program()) fleet.init_server(model_dir) shutil.rmtree(model_dir) -- GitLab