Unverified commit d479ae17, authored by Chengmo, committed by GitHub

【Paddle.Fleet】Support local save sparse param (#30175)

* add save tensor support
Co-authored-by: seiriosPlus <tangwei12@baidu.com>
Parent 113810c5
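Taken together, the diffs below wire a local-save path for sparse parameters from the Python Fleet API down to the brpc PS client: fleet.save_persistables (whose default mode changes from 1 to 0) reaches _save_sparse_params, which now saves non-distributed sparse tables to the local filesystem through the new recv_and_save_model binding, in addition to the server-side save_one_model call. A minimal usage sketch under a standard parameter-server setup; the output directory and the elided training code are placeholders, not part of this commit:

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=False)
# ... build the network, enable a_sync in DistributedStrategy,
#     wrap the optimizer with fleet.distributed_optimizer(...), train ...

if fleet.is_first_worker():
    exe = paddle.static.Executor(paddle.CPUPlace())
    # mode=0 saves the full parameters (the new default); mode=1 saves the delta.
    # Non-distributed sparse tables are additionally pulled from the servers
    # and written locally as LoDTensor files.
    fleet.save_persistables(exe, "./saved_model", mode=0)
fleet.stop_worker()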
@@ -459,6 +459,16 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
   }
 }
+void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
+                                    const std::string& path) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret = communicator->_worker_ptr->recv_and_save_table(table_id, path);
+  if (ret != 0) {
+    LOG(ERROR) << "save model of table id: " << table_id
+               << ", to path: " << path << " failed";
+  }
+}
 void FleetWrapper::PrintTableStat(const uint64_t table_id) {
   auto* communicator = Communicator::GetInstance();
   auto ret = communicator->_worker_ptr->print_table_stat(table_id);
......
@@ -198,6 +198,10 @@ class FleetWrapper {
   // mode = 1, save delta feature, which means save diff
   void SaveModelOneTable(const uint64_t table_id, const std::string& path,
                          const int mode);
+  // recv table from server and save it in LodTensor
+  void RecvAndSaveTable(const uint64_t table_id, const std::string& path);
   // clear all models, release their memory
   void ClearModel();
   // clear one table
......
@@ -14,6 +14,7 @@
 #include <algorithm>
 #include <memory>
+#include <sstream>
 #include <string>
 #include <vector>
@@ -21,6 +22,7 @@
 #include "paddle/fluid/distributed/service/brpc_ps_client.h"
 #include "paddle/fluid/distributed/table/table.h"
 #include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/string/string_helper.h"
 const static int max_port = 65535;
@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
 DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
+namespace paddle {
+namespace framework {
+class Scope;
+class Variable;
+}  // namespace framework
+namespace platform {
+class DeviceContext;
+}  // namespace platform
+}  // namespace paddle
 namespace paddle {
 namespace distributed {
@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
   return fut;
 }
+int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
+                                          const std::string &path) {
+  // get var information
+  std::string var_name = "";
+  int64_t var_num = 0;
+  int64_t var_shape = 0;
+  const auto &worker_param = _config.worker_param().downpour_worker_param();
+  for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
+    if (worker_param.downpour_table_param(i).table_id() == table_id) {
+      var_name = worker_param.downpour_table_param(i).common().table_name();
+      var_num = worker_param.downpour_table_param(i).accessor().fea_dim();
+      var_shape = worker_param.downpour_table_param(i).accessor().embedx_dim();
+      break;
+    }
+  }
+  PADDLE_ENFORCE_NE(
+      var_name, "",
+      platform::errors::InvalidArgument(
+          "Cannot find table id %d to save variables.", table_id));
+  std::string var_store = string::Sprintf("%s", path);
+  MkDirRecursively(var_store.c_str());
+  // pull sparse from server
+  std::vector<float> save_huge_vec(var_num * var_shape);
+  std::vector<uint64_t> save_key(var_num);
+  std::vector<float *> save_vec;
+  for (size_t i = 0; i < save_key.size(); ++i) {
+    save_key[i] = i;
+    save_vec.push_back(save_huge_vec.data() + i * var_shape);
+  }
+  auto status = pull_sparse((float **)save_vec.data(), table_id,
+                            save_key.data(), save_key.size());
+  status.wait();
+  // create lod tensor
+  std::shared_ptr<framework::Scope> scope;
+  scope.reset(new framework::Scope());
+  auto place = platform::CPUPlace();
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
+  framework::Variable *var = scope->Var(var_name);
+  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
+  std::vector<int64_t> vec_dim = {var_num, var_shape};
+  var_tensor->Resize(framework::make_ddim(vec_dim));
+  // copy and save
+  float *tensor_data = var_tensor->mutable_data<float>(place);
+  memcpy(tensor_data, save_huge_vec.data(),
+         var_num * var_shape * sizeof(float));
+  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
+  std::ofstream fout(file_name, std::ios::binary);
+  PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
+                    platform::errors::Unavailable(
+                        "Cannot open %s to save variables.", file_name));
+  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
+  fout.close();
+  return 0;
+}
 }  // namespace distributed
 }  // namespace paddle
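recv_and_save_table above serializes the pulled table as one LoDTensor of shape [fea_dim, embedx_dim] to the file {path}/{table_name} with framework::SerializeToStream, which is the same on-disk format the load op and fluid.io.load_vars consume. A rough read-back sketch, assuming the table name and dimensions from the PS config are known; all names and sizes below are placeholders:

import paddle
import paddle.fluid as fluid

paddle.enable_static()

var_name = "embedding_0.w_0"     # placeholder: table_name from the worker config
fea_dim, embedx_dim = 100000, 8  # placeholders: accessor fea_dim / embedx_dim

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    emb = paddle.static.create_parameter(
        shape=[fea_dim, embedx_dim], dtype="float32", name=var_name)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
# load_vars reads "<dirname>/<var.name>", i.e. the SerializeToStream file
# written by recv_and_save_table.
fluid.io.load_vars(exe, dirname="./saved_model", main_program=main_prog, vars=[emb])
emb_tensor = fluid.global_scope().find_var(var_name).get_tensor()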
@@ -22,6 +22,9 @@
 #include "brpc/controller.h"
 #include "brpc/server.h"
 #include "paddle/fluid/distributed/service/ps_client.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor_util.h"
 namespace paddle {
 namespace distributed {
@@ -148,6 +151,10 @@ class BrpcPsClient : public PSClient {
   virtual std::future<int32_t> send_client2client_msg(
       int msg_type, int to_client_id, const std::string &msg) override;
+  // for local save sparse
+  virtual int32_t recv_and_save_table(const uint64_t table_id,
+                                      const std::string &path);
  private:
   virtual int32_t initialize() override;
......
@@ -134,6 +134,11 @@ class PSClient {
   virtual std::future<int32_t> push_global_step(int table_id,
                                                 int64_t *total_send_data,
                                                 void *done) = 0;
+  // recv table from server and save it in LodTensor
+  virtual int32_t recv_and_save_table(const uint64_t table_id,
+                                      const std::string &path) = 0;
   virtual void finalize_worker() = 0;
   // client to client, message sending
   virtual std::future<int32_t> send_client2client_msg(int msg_type,
......
@@ -21,6 +21,7 @@
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/string/string_helper.h"
+#define PSERVER_SAVE_SUFFIX "_txt"
 namespace paddle {
 namespace distributed {
@@ -290,7 +291,8 @@ int32_t CommonSparseTable::save(const std::string& dirname,
   VLOG(0) << "sparse table save: " << dirname << " mode: " << mode;
   auto varname = _config.common().table_name();
-  std::string var_store = string::Sprintf("%s/%s", dirname, varname);
+  std::string var_store =
+      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
   MkDirRecursively(var_store.c_str());
   VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
......
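With PSERVER_SAVE_SUFFIX, a server-side dump of table X now goes to the directory {dirname}/X_txt, while the worker's locally saved LoDTensor from recv_and_save_table is the plain file {dirname}/X; presumably the suffix keeps the two outputs from colliding when they land under the same directory. A small illustration with hypothetical paths:

import os

dirname = "./saved_model"
table_name = "embedding_0.w_0"  # placeholder: table_name from the config

local_tensor = os.path.join(dirname, table_name)          # worker: recv_and_save_table output (LoDTensor file)
server_dump = os.path.join(dirname, table_name + "_txt")  # pserver: CommonSparseTable::save output directory
print("local tensor file:", local_tensor, os.path.isfile(local_tensor))
print("server dump dir:  ", server_dump, os.path.isdir(server_dump))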
@@ -58,6 +58,7 @@ void BindDistFleetWrapper(py::module* m) {
       .def("pull_dense_params", &FleetWrapper::PullDenseVarsSync)
       .def("save_all_model", &FleetWrapper::SaveModel)
       .def("save_one_model", &FleetWrapper::SaveModelOneTable)
+      .def("recv_and_save_model", &FleetWrapper::RecvAndSaveTable)
       .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
       .def("stop_server", &FleetWrapper::StopServer)
       .def("stop_worker", &FleetWrapper::FinalizeWorker)
......
@@ -545,7 +545,7 @@ class Fleet(object):
             executor, dirname, feeded_var_names, target_vars, main_program,
             export_for_deployment)
-    def save_persistables(self, executor, dirname, main_program=None, mode=1):
+    def save_persistables(self, executor, dirname, main_program=None, mode=0):
         """
         saves all persistable tensors from :code:`main_program` to
......
@@ -64,12 +64,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         _main = compiled_config.origin_main_program.clone()
         _startup = compiled_config.origin_startup_program.clone()
+        if not compiled_config.is_geo_mode():
             from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
             _add_lr_decay_table_pass(
                 _main, compiled_config,
                 self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
-        if not compiled_config.is_geo_mode():
             # for main program
             _main = worker.delete_optimizer_pass(_main, compiled_config)
             _main = worker.distributed_ops_pass(_main, compiled_config)
......
@@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
         return is_valid
-    def _save_sparse_params(self, executor, dirname, context, main_program):
+    def _save_sparse_params(self, executor, dirname, context, main_program,
+                            mode):
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
+        distributed_varnames = get_sparse_tablenames(
+            self.compiled_strategy.origin_main_program, True)
         values = []
         for id, names in context.items():
+            if names not in distributed_varnames:
+                # only save sparse param to local
+                self._worker.recv_and_save_model(id, dirname)
+            # save sparse & distributed param on server
+            self._worker.save_one_model(id, dirname, mode)
             values.extend(names)
-            self._worker.save_one_model(id, dirname, 0)
         return values
-    def _save_distributed_persistables(self, executor, dirname, main_program,
-                                       mode):
+    def _save_distributed_persistables(self,
+                                       executor,
+                                       dirname,
+                                       main_program,
+                                       mode=0):
         denses = self.compiled_strategy.get_the_one_recv_context(
             is_dense=True,
@@ -870,14 +881,14 @@ class TheOnePSRuntime(RuntimeBase):
             split_dense_table=self.role_maker._is_heter_parameter_server_mode,
             use_origin_program=True)
-        recv_sparse_varnames = self._save_sparse_params(executor, dirname,
-                                                        sparses, main_program)
+        sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
+                                                   main_program, mode)
         recv_dense_varnames = []
         for id, names in denses.items():
             recv_dense_varnames.extend(names)
-        saved_varnames = recv_sparse_varnames
+        saved_varnames = sparse_varnames
         remaining_vars = list(
             filter(
@@ -925,6 +936,7 @@ class TheOnePSRuntime(RuntimeBase):
                 "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
             )
+        # Todo(MrChengmo): Save optimizer status
         self._save_distributed_persistables(executor, dirname, main_program,
                                             mode)
@@ -971,8 +983,7 @@ class TheOnePSRuntime(RuntimeBase):
         program = Program.parse_from_string(program_desc_str)
         program._copy_dist_param_info_from(fluid.default_main_program())
-        self._ps_inference_save_persistables(
-            executor, dirname, program, mode=0)
+        self._ps_inference_save_persistables(executor, dirname, program)
     def _save_inference_model(self, *args, **kwargs):
         self._ps_inference_save_inference_model(*args, **kwargs)
......