Unverified · Commit e7842ba6 · Authored by: wangguanqun · Committed by: GitHub

save/load in ps runtime (the_one_ps) (#36097)

* add trainer desc config to distributed strategy

* code style modified

* data_feed set lod

* fix bug

* code style

* fix bug

* save load

* save load

* save unittest

* add unittest of the_one_ps

* unittest

* add todo in communicator sendsparse
Parent commit: ef76f664
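Taken together, the diff below adds a trainer-side pull of dense parameters (Communicator::PullDense, exposed to Python as pull_dense) and a load path for the parameter-server runtime (load_sparse / load_one_table / load_model), so a PS job can write a checkpoint with fleet.save and restore it with fleet.load_model. The outline below is an illustrative sketch, not part of this commit: it mirrors the unit test at the bottom of the diff and assumes the usual fleet parameter-server setup (role maker, DistributedStrategy, network build and optimizer.minimize()); the surrounding setup calls are standard fleet APIs rather than additions of this change.

# Illustrative trainer/server outline of the new save/load flow, assuming a
# parameter-server fleet job has already been configured and optimized.
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init()  # parameter-server role maker assumed
# ... build the network, configure DistributedStrategy, optimizer.minimize() ...

if fleet.is_server():
    fleet.init_server()
    fleet.run_server()
else:
    fleet.init_worker()
    # ... run training ...
    fleet.fleet.save(dirname="/tmp")        # dense params via pull_dense + paddle.save,
                                            # sparse tables saved on the servers
    fleet.load_model(path="/tmp", mode=0)   # restore: load_one_table + paddle.load
    fleet.stop_worker()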
@@ -283,6 +283,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
    push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
  }
  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
  // if padding_idx == padding in datareader, the server will core.
  /*
  for (size_t i = 0; i < tensor->rows().size(); ++i) {
    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
    if (real_id != 0) {
      sparse_push_keys.push_back(real_id);
      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
    }
  }
  */
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
@@ -353,6 +365,17 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  return;
}
void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
  for (auto &iter : recv_varname_to_ctx) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param to table " << table_id
            << " from 0' trainer done";
  }
  return;
}
void Communicator::RpcProfilerControl() {
  if (trainer_id_ == 0) {
    if (!do_server_profiler_ && platform::IsProfileEnabled()) {
......
@@ -271,6 +271,8 @@ class Communicator {
  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

  virtual void Start() = 0;
  virtual void Stop() = 0;
......
@@ -279,18 +279,25 @@ int32_t CommonSparseTable::set_global_lr(float* lr) {
  return 0;
}
-int32_t CommonSparseTable::load(const std::string& path,
+int32_t CommonSparseTable::load(const std::string& dirname,
                                 const std::string& param) {
  auto begin = GetCurrentUS();
  rwlock_->WRLock();
-  LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
-               &shard_values_);
+  auto varname = _config.common().table_name();
+  std::string var_store =
+      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
+  std::string shard_var_pre =
+      string::Sprintf("%s.block%d", varname, _shard_idx);
+  std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
+  std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
+  LoadFromText(value_, meta_, _shard_idx, _shard_num, task_pool_size_,
+               &shard_values_);
  rwlock_->UNLock();
  auto end = GetCurrentUS();
-  auto varname = _config.common().table_name();
-  VLOG(0) << "load " << varname << " with value: " << path
-          << " , meta: " << param
+  VLOG(0) << "load " << varname << " with value: " << value_
+          << " , meta: " << meta_
          << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
  return 0;
......
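For orientation, the rewritten CommonSparseTable::load above no longer takes explicit value/meta file paths: it derives them from dirname, the table name returned by _config.common().table_name(), and the shard index. The helper below is a hypothetical Python mirror of that path arithmetic, shown only to make the on-disk layout concrete; save_suffix stands in for PSERVER_SAVE_SUFFIX, whose concrete value is defined elsewhere in the codebase.

# Hypothetical helper mirroring the Sprintf-based path construction in
# CommonSparseTable::load; save_suffix stands in for PSERVER_SAVE_SUFFIX.
def sparse_shard_paths(dirname, table_name, shard_idx, save_suffix):
    var_store = "{}/{}{}".format(dirname, table_name, save_suffix)
    shard_var_pre = "{}.block{}".format(table_name, shard_idx)
    value_path = "{}/{}.txt".format(var_store, shard_var_pre)
    meta_path = "{}/{}.meta".format(var_store, shard_var_pre)
    return value_path, meta_path

# e.g. shard 0 of a table named "embedding_0.w_0" saved under /tmp:
# sparse_shard_paths("/tmp", "embedding_0.w_0", 0, save_suffix)
# -> ("/tmp/embedding_0.w_0<suffix>/embedding_0.w_0.block0.txt",
#     "/tmp/embedding_0.w_0<suffix>/embedding_0.w_0.block0.meta")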
@@ -158,7 +158,8 @@ void BindDistCommunicator(py::module* m) {
      .def("start", &Communicator::Start)
      .def("push_sparse_param", &Communicator::RpcSendSparseParam)
      .def("is_running", &Communicator::IsRunning)
-      .def("init_params", &Communicator::InitParams);
+      .def("init_params", &Communicator::InitParams)
+      .def("pull_dense", &Communicator::PullDense);
  // .def("recv", &Communicator::RecvNoBarrier);
}
......
@@ -868,11 +868,11 @@ class TheOnePSRuntime(RuntimeBase):
        for var_name in load_varnames:
            table_id = sparse_table_maps[var_name]
-            path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
-                                "{}.block{}.txt".format(var_name, pserver_id))
-            meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
-                                "{}.block{}.meta".format(var_name, pserver_id))
-            self._server.load_sparse(path, meta, table_id)
+            # path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
+            #                     "{}.block{}.txt".format(var_name, pserver_id))
+            # meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
+            #                     "{}.block{}.meta".format(var_name, pserver_id))
+            self._server.load_sparse(dirname, "0", table_id)
    def _run_server(self):
        if self.role_maker._is_heter_worker():
@@ -967,8 +967,12 @@ class TheOnePSRuntime(RuntimeBase):
                TheOnePSRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars()))

        self._communicator.pull_dense(denses)

        import paddle
        for var in remaining_vars:
            if var.name not in recv_dense_varnames:
                continue
            tensor = var.get_value()
            paddle.save(
                tensor, os.path.join(dirname, var.name), use_binary_format=True)
@@ -1063,8 +1067,64 @@ class TheOnePSRuntime(RuntimeBase):
    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)
    def _load_sparse_params(self, dirname, context, main_program, mode):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
        distributed_varnames = get_sparse_tablenames(
            self.compiled_strategy.origin_main_program, True)
        values = []
        for id, names in context.items():
            if names[0] not in distributed_varnames:
                # TODO: only load sparse param from local
                warnings.warn("varname is not in distributed_varnames, pass")
            # load sparse & distributed param on server
            self._worker.load_one_table(id, dirname, mode)
            values.extend(names)
        return values

    def _load_distributed_persistables(self, dirname, main_program=None,
                                       mode=0):
        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        denses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=True,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True)
        sparses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=False,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True)

        sparse_varnames = self._load_sparse_params(dirname, sparses,
                                                   main_program, mode)

        recv_dense_varnames = []
        for id, names in denses.items():
            recv_dense_varnames.extend(names)

        loaded_varnames = sparse_varnames

        remaining_vars = list(
            filter(
                TheOnePSRuntime.__exclude_vars(loaded_varnames),
                main_program.list_vars()))

        import paddle
        for var in remaining_vars:
            if var.name not in recv_dense_varnames:
                continue
            tensor = paddle.load(os.path.join(dirname, var.name))
            var.set_value(tensor)

        self._communicator.init_params(denses)
    def load_model(self, path, mode):
-        self._worker.load_model(path, mode)
+        self._load_distributed_persistables(path, mode=mode)
    def _shrink(self, threshold):
        import paddle.distributed.fleet as fleet
......
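On the dense side, the save path above writes each recv-dense variable with paddle.save(..., use_binary_format=True) to dirname/<var_name>, and _load_distributed_persistables reads it back with paddle.load followed by var.set_value. The sketch below shows that per-variable round trip in isolation; dygraph mode and the variable name "fc_0.w_0" are used purely for illustration, not taken from this commit.

# Minimal sketch of the per-variable dense round trip used by the runtime:
# one binary tensor file per variable, named after the variable itself.
import os
import numpy as np
import paddle

dirname = "/tmp/ps_dense_demo"
os.makedirs(dirname, exist_ok=True)

# stand-in for a dense parameter that pull_dense refreshed from the servers
tensor = paddle.to_tensor(np.ones([8, 4], dtype="float32"))
paddle.save(tensor, os.path.join(dirname, "fc_0.w_0"), use_binary_format=True)

# later, during load_model / _load_distributed_persistables:
restored = paddle.load(os.path.join(dirname, "fc_0.w_0"))
print(restored.shape)  # [8, 4]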
@@ -161,6 +161,9 @@ class Communicator(object):
    def init_params(self, context):
        self.communicator_.init_params(context)

    def pull_dense(self, context):
        self.communicator_.pull_dense(context)

    def push_sparse_param(self, var_name, table_id=-1, scope=global_scope()):
        if not self.is_running():
            raise ValueError(
@@ -36,8 +36,13 @@ class TestFleetBase(unittest.TestCase):
        input_x = paddle.fluid.layers.data(
            name="x", shape=[32], dtype='float32')
        input_slot = paddle.fluid.layers.data(
            name="slot", shape=[1], dtype='int64')
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
        emb = paddle.fluid.layers.embedding(
            input=input_slot, size=[10, 9], is_sparse=True)
        input_x = paddle.concat(x=[input_x, emb], axis=1)
        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
@@ -63,11 +68,14 @@ class TestFleetBase(unittest.TestCase):
        compiled_prog = fluid.compiler.CompiledProgram(
            fluid.default_main_program())

        fleet.init_worker()
        fleet.fleet.save(dirname="/tmp", feed=['x', 'y'], fetch=[avg_cost])
        fleet.fleet.save(
            dirname="/tmp", feed=[input_x, input_y], fetch=[avg_cost])
        fleet.fleet.save(dirname="/tmp")
        fleet.load_model(path="/tmp", mode=0)

        self.assertRaises(
            Exception,
            fleet.save_inference_model,
......