未验证 提交 d479ae17 编写于 作者: C Chengmo 提交者: GitHub

【Paddle.Fleet】Support local save sparse param (#30175)

* add save tensor support
Co-authored-by: seiriosPlus <tangwei12@baidu.com>
上级 113810c5
......@@ -459,6 +459,16 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
}
}
// Asks the worker client held by the Communicator singleton to receive table
// `table_id` from the server and save it under `path`.
//
// A non-zero status from the client is only logged here; it is not propagated
// to the caller (this function returns void).
void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
                                    const std::string& path) {
  auto* comm = Communicator::GetInstance();
  const auto status = comm->_worker_ptr->recv_and_save_table(table_id, path);
  if (status == 0) {
    return;
  }
  LOG(ERROR) << "save model of table id: " << table_id
             << ", to path: " << path << " failed";
}
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->print_table_stat(table_id);
......
......@@ -198,6 +198,10 @@ class FleetWrapper {
// mode = 1, save delta feature, which means save diff
void SaveModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// Receive the table `table_id` from the server and save it under `path` as a
// LoDTensor. Returns nothing; the definition logs an error when the
// underlying client call reports a non-zero status.
void RecvAndSaveTable(const uint64_t table_id, const std::string& path);
// clear all models, release their memory
void ClearModel();
// clear one table
......
......@@ -14,6 +14,7 @@
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
......@@ -21,6 +22,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
const static int max_port = 65535;
......@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
// Forward declarations of framework/platform types, so they can be named
// without relying on their full definitions being visible at this point.
// NOTE(review): the client header appears to include scope.h / lod_tensor.h
// already, which may make these redundant - confirm before removing.
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
......@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
return fut;
}
int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
const std::string &path) {
// get var information
std::string var_name = "";
int64_t var_num = 0;
int64_t var_shape = 0;
const auto &worker_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
if (worker_param.downpour_table_param(i).table_id() == table_id) {
var_name = worker_param.downpour_table_param(i).common().table_name();
var_num = worker_param.downpour_table_param(i).accessor().fea_dim();
var_shape = worker_param.downpour_table_param(i).accessor().embedx_dim();
break;
}
}
PADDLE_ENFORCE_NE(
var_name, "",
platform::errors::InvalidArgument(
"Cannot find table id %d to save variables.", table_id));
std::string var_store = string::Sprintf("%s", path);
MkDirRecursively(var_store.c_str());
// pull sparse from server
std::vector<float> save_huge_vec(var_num * var_shape);
std::vector<uint64_t> save_key(var_num);
std::vector<float *> save_vec;
for (size_t i = 0; i < save_key.size(); ++i) {
save_key[i] = i;
save_vec.push_back(save_huge_vec.data() + i * var_shape);
}
auto status = pull_sparse((float **)save_vec.data(), table_id,
save_key.data(), save_key.size());
status.wait();
// create lod tensor
std::shared_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
auto place = platform::CPUPlace();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::Variable *var = scope->Var(var_name);
framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
std::vector<int64_t> vec_dim = {var_num, var_shape};
var_tensor->Resize(framework::make_ddim(vec_dim));
// copy and save
float *tensor_data = var_tensor->mutable_data<float>(place);
memcpy(tensor_data, save_huge_vec.data(),
var_num * var_shape * sizeof(float));
std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
std::ofstream fout(file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", file_name));
framework::SerializeToStream(fout, *var_tensor, dev_ctx);
fout.close();
return 0;
}
} // namespace distributed
} // namespace paddle
......@@ -22,6 +22,9 @@
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
......@@ -148,6 +151,10 @@ class BrpcPsClient : public PSClient {
virtual std::future<int32_t> send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) override;
// for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
private:
virtual int32_t initialize() override;
......
......@@ -134,6 +134,11 @@ class PSClient {
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done) = 0;
// Receive the table `table_id` from the server and save it under `path` as a
// LoDTensor. Pure virtual: each client implementation supplies the transport.
// The visible caller (FleetWrapper::RecvAndSaveTable) treats a non-zero
// return value as failure.
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path) = 0;
virtual void finalize_worker() = 0;
// client to client message sending
virtual std::future<int32_t> send_client2client_msg(int msg_type,
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
#define PSERVER_SAVE_SUFFIX "_txt"
namespace paddle {
namespace distributed {
......@@ -290,7 +291,8 @@ int32_t CommonSparseTable::save(const std::string& dirname,
VLOG(0) << "sparse table save: " << dirname << " mode: " << mode;
auto varname = _config.common().table_name();
std::string var_store = string::Sprintf("%s/%s", dirname, varname);
std::string var_store =
string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
MkDirRecursively(var_store.c_str());
VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
......
......@@ -58,6 +58,7 @@ void BindDistFleetWrapper(py::module* m) {
.def("pull_dense_params", &FleetWrapper::PullDenseVarsSync)
.def("save_all_model", &FleetWrapper::SaveModel)
.def("save_one_model", &FleetWrapper::SaveModelOneTable)
.def("recv_and_save_model", &FleetWrapper::RecvAndSaveTable)
.def("sparse_table_stat", &FleetWrapper::PrintTableStat)
.def("stop_server", &FleetWrapper::StopServer)
.def("stop_worker", &FleetWrapper::FinalizeWorker)
......
......@@ -545,7 +545,7 @@ class Fleet(object):
executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment)
def save_persistables(self, executor, dirname, main_program=None, mode=1):
def save_persistables(self, executor, dirname, main_program=None, mode=0):
"""
saves all persistable tensors from :code:`main_program` to
......
......@@ -64,12 +64,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
_main = compiled_config.origin_main_program.clone()
_startup = compiled_config.origin_startup_program.clone()
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
_add_lr_decay_table_pass(
_main, compiled_config,
self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
if not compiled_config.is_geo_mode():
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
_add_lr_decay_table_pass(
_main, compiled_config,
self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
# for main program
_main = worker.delete_optimizer_pass(_main, compiled_config)
_main = worker.distributed_ops_pass(_main, compiled_config)
......
......@@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
return is_valid
def _save_sparse_params(self, executor, dirname, context, main_program):
def _save_sparse_params(self, executor, dirname, context, main_program,
mode):
from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
distributed_varnames = get_sparse_tablenames(
self.compiled_strategy.origin_main_program, True)
values = []
for id, names in context.items():
if names not in distributed_varnames:
# only save sparse param to local
self._worker.recv_and_save_model(id, dirname)
# save sparse & distributed param on server
self._worker.save_one_model(id, dirname, mode)
values.extend(names)
self._worker.save_one_model(id, dirname, 0)
return values
def _save_distributed_persistables(self, executor, dirname, main_program,
mode):
def _save_distributed_persistables(self,
executor,
dirname,
main_program,
mode=0):
denses = self.compiled_strategy.get_the_one_recv_context(
is_dense=True,
......@@ -870,14 +881,14 @@ class TheOnePSRuntime(RuntimeBase):
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)
recv_sparse_varnames = self._save_sparse_params(executor, dirname,
sparses, main_program)
sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
main_program, mode)
recv_dense_varnames = []
for id, names in denses.items():
recv_dense_varnames.extend(names)
saved_varnames = recv_sparse_varnames
saved_varnames = sparse_varnames
remaining_vars = list(
filter(
......@@ -925,6 +936,7 @@ class TheOnePSRuntime(RuntimeBase):
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
)
# Todo(MrChengmo): Save optimizer status
self._save_distributed_persistables(executor, dirname, main_program,
mode)
......@@ -971,8 +983,7 @@ class TheOnePSRuntime(RuntimeBase):
program = Program.parse_from_string(program_desc_str)
program._copy_dist_param_info_from(fluid.default_main_program())
self._ps_inference_save_persistables(
executor, dirname, program, mode=0)
self._ps_inference_save_persistables(executor, dirname, program)
def _save_inference_model(self, *args, **kwargs):
self._ps_inference_save_inference_model(*args, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册