Unverified commit 7b29c89b authored by zhaocaibei123, committed by GitHub

refine save/load interface for distributed cpups (#44862)

* save load

* save load

* add unittest

* first commit

* second commit

* third commit

* remove SaveLocalFS in memory sparse table

* save dense param

* update

* push slot

* fix push show clk: int -> float

* add unittest

* fix sample

* unittest

* add AsExtra for op

* unittest

* modify fs.py

* modify fs.py

* fix some bugs

* add dataset hdfs config

* local change

* dataset uses a different hadoop ugi/fs_name

* add

* fix conflict

* fix

* remove logs

* code style

* fix

* code style

* code style

* fix

* code style

* save_dense_param

* fix

* fix

* fix

* fix

* change momentum in dense optimizer

* fix

* fix

* change fluid => paddle.static

* remove some unused code
Co-authored-by: esythan <esythan@126.com>
Parent 8185cecd
@@ -254,9 +254,9 @@ class DAdamD2Sum : public DenseOptimizer {
     scale = (mat_ada_d2sum + scale).cwiseQuotient(mat_ada_g2sum + scale);
     scale = scale.cwiseSqrt();
     mat_mom_velocity =
-        (mat_mom_velocity - mat_grad) * mom_decay_rate[0] + mat_grad;
+        (mat_mom_velocity + mat_grad) * mom_decay_rate[0] - mat_grad;
-    mat_w -= learning_rate[0] * mat_mom_velocity.cwiseProduct(scale);
+    mat_w += learning_rate[0] * mat_mom_velocity.cwiseProduct(scale);
}
float* learning_rate;
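For reference, the two momentum forms are algebraically equivalent up to the sign of the stored velocity: the old update maintains v_t = m*v_{t-1} + (1-m)*g_t and applies w -= lr * v_t * scale, while the new one maintains u_t = -v_t and applies w += lr * u_t * scale, so the weight trajectory is unchanged for a zero-initialized velocity. A quick standalone check (illustrative Python, not part of the patch; the elementwise scale factor is folded into lr):

import numpy as np

beta, lr = 0.99, 0.1
grads = np.random.default_rng(0).standard_normal(8)

v = u = 0.0            # velocities for the old and new forms
w_old = w_new = 1.0    # the same initial weight
for g in grads:
    v = (v - g) * beta + g      # old: v = beta*v + (1-beta)*g
    w_old -= lr * v
    u = (u + g) * beta - g      # new: u stays exactly equal to -v
    w_new += lr * u

assert np.isclose(w_old, w_new)  # identical weight trajectories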
@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer {
}
float* summary_decay_rate;
-  double summary_decay_rate_d = 0.9999999;
+  double summary_decay_rate_d = 0.999999;
float* param;
};
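If this rate is used as an EMA coefficient r, the effective averaging window is roughly 1/(1-r), so the change shortens it from about 10^7 updates (0.9999999) to about 10^6 (0.999999).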
......
@@ -47,7 +47,6 @@ void SparseAccessor::InitAccessorInfo() {
auto embedx_dim = _config.embedx_dim();
_accessor_info.select_dim = 1 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
-  ;
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
@@ -231,11 +230,13 @@ int32_t SparseAccessor::Update(float** update_values,
_embed_sgd_rule->UpdateValue(
update_value + sparse_feature_value.EmbedWIndex(),
update_value + sparse_feature_value.EmbedG2SumIndex(),
-        push_value + SparsePushValue::EmbedGIndex());
+        push_value + SparsePushValue::EmbedGIndex(),
+        push_show);
_embedx_sgd_rule->UpdateValue(
update_value + sparse_feature_value.EmbedxWIndex(),
update_value + sparse_feature_value.EmbedxG2SumIndex(),
-        push_value + SparsePushValue::EmbedxGIndex());
+        push_value + SparsePushValue::EmbedxGIndex(),
+        push_show);
}
return 0;
}
......
@@ -120,7 +120,7 @@ message TableParameter {
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
-  optional bool enable_revert = 13 [ default = true ];
+  optional bool enable_revert = 13 [ default = false ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
}
......
@@ -76,7 +76,12 @@ distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model
save_persistables = fleet.save_persistables
save_cache_model = fleet.save_cache_model
+check_save_pre_patch_done = fleet.check_save_pre_patch_done
+save_one_table = fleet.save_one_table
+save_dense_params = fleet.save_dense_params
load_model = fleet.load_model
+load_inference_model = fleet.load_inference_model
+load_one_table = fleet.load_one_table
minimize = fleet.minimize
distributed_model = fleet.distributed_model
step = fleet.step
......
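Together with the Fleet methods added below, the expanded save/load surface can be exercised roughly as follows (a minimal sketch; the paths and mode values are illustrative and mirror the unit test later in this diff):

import paddle
import paddle.distributed.fleet as fleet

fleet.init()
# build net, then fleet.distributed_optimizer(...).minimize(loss) and train

exe = paddle.static.Executor(paddle.CPUPlace())

fleet.save_persistables(exe, "ckpt")          # full distributed checkpoint
fleet.save_one_table(0, "table_0", 0)         # one parameter table, mode 0
fleet.save_dense_params(exe, "dense",
                        paddle.static.global_scope(),
                        paddle.static.default_main_program())

fleet.load_model("ckpt", mode=0)              # restore the checkpoint
fleet.load_one_table(0, "table_0", 0)
fleet.load_inference_model("inference", mode=1)
fleet.check_save_pre_patch_done()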
@@ -710,10 +710,60 @@ class Fleet(object):
# build net
# fleet.distributed_optimizer(...)
fleet.load_model("path", "mode")
fleet.load_model("path", mode=0)
"""
-        self._runtime_handle.load_model(path, mode)
+        self._runtime_handle._load_persistables(path, mode)
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def load_one_table(self, table_id, path, mode):
+        """
+        load fleet one table from path
+        Returns:
+            None
+        Examples:
+            .. code-block:: python
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+                # build net
+                # fleet.distributed_optimizer(...)
+                fleet.load_one_table(0, "path", mode=0)
+        """
+        self._runtime_handle._load_one_table(table_id, path, mode)
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def load_inference_model(self, path, mode):
+        """
+        load fleet inference model from path
+        Returns:
+            None
+        Examples:
+            .. code-block:: python
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+                # build net
+                # fleet.distributed_optimizer(...)
+                fleet.load_inference_model("path", mode=1)
+        """
+        self._runtime_handle._load_inference_model(path, mode)
@is_non_distributed_check
@inited_runtime_handler
@@ -906,6 +956,70 @@ class Fleet(object):
def save_cache_model(self, dirname, **configs):
return self._runtime_handle._save_cache_model(dirname, **configs)
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def check_save_pre_patch_done(self):
+        return self._runtime_handle._check_save_pre_patch_done()
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def save_one_table(self, table_id, path, mode):
+        """
+        save fleet one table from path
+        Returns:
+            None
+        Examples:
+            .. code-block:: python
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+                # build net
+                # fleet.distributed_optimizer(...)
+                fleet.save_one_table(0, "path", mode=0)
+        """
+        self._runtime_handle._save_one_table(table_id, path, mode)
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def save_dense_params(self,
+                          executor,
+                          dirname,
+                          scope,
+                          program,
+                          var_names=None):
+        """
+        save fleet one table from path
+        Returns:
+            None
+        Examples:
+            .. code-block:: python
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+                import paddle
+                place = paddle.fluid.CPUPlace()
+                exe = paddle.fluid.Executor(place)
+                # build net
+                # fleet.distributed_optimizer(...)
+                fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program())
+        """
+        self._runtime_handle._save_dense_params(executor, dirname, scope,
+                                                program, var_names)
def shrink(self, threshold=None):
self._runtime_handle._shrink(threshold)
......
@@ -486,6 +486,7 @@ class HDFSClient(FS):
time.sleep(retry_sleep_second)
if ret == 134:
raise FSShellCmdAborted(cmd)
+        return ret, output.splitlines()
@_handle_errors()
@@ -615,10 +616,12 @@ class HDFSClient(FS):
def _is_dir(self, fs_path):
cmd = "test -d {}".format(fs_path, redirect_stderr=True)
-        ret, lines = self._run_cmd(cmd)
+        ret, lines = self._run_cmd(cmd, retry_times=1)
if ret:
# other error
if self._test_match(lines):
+                print('raise exception: ')
+                print('\n'.join(lines))
raise ExecuteError(cmd)
return False
@@ -682,13 +685,10 @@ class HDFSClient(FS):
client = HDFSClient(hadoop_home, configs)
ret = client.is_exist("hdfs:/test_hdfs_client")
"""
cmd = "ls {} ".format(fs_path)
ret, out = self._run_cmd(cmd, redirect_stderr=True)
cmd = "test -e {} ".format(fs_path)
ret, out = self._run_cmd(cmd, redirect_stderr=True, retry_times=1)
if ret != 0:
-            for l in out:
-                if "No such file or directory" in l:
-                    return False
-            raise ExecuteError(cmd)
+            return False
return True
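`is_exist` now keys off the exit status of `test -e` (0 means the path exists; any non-zero status is treated as absent) instead of running `ls` and scanning its output for "No such file or directory", so unrelated shell failures no longer raise ExecuteError here. Call sites are unchanged (a sketch, names as in the docstring above):

client = HDFSClient(hadoop_home, configs)
if not client.is_exist("hdfs:/test_hdfs_client"):
    client.mkdirs("hdfs:/test_hdfs_client")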
@@ -712,7 +712,7 @@ class HDFSClient(FS):
self._try_upload(local_dir, dest_dir)
# can't retry
-    def upload(self, local_path, fs_path, multi_processes=1, overwrite=False):
+    def upload(self, local_path, fs_path, multi_processes=5, overwrite=False):
"""
Upload the local path to remote HDFS.
@@ -766,11 +766,7 @@ class HDFSClient(FS):
local = LocalFS()
if not local.is_exist(local_path):
raise FSFileNotExistsError("{} not exists".format(local_path))
-        # upload_dir
-        if local.is_dir(local_path):
-            self.upload_dir(local_path, fs_path, overwrite=overwrite)
-            return
# upload files
all_files = get_local_files(local_path)
if not all_files:
print("there are nothing need to upload, function exit")
@@ -805,7 +801,7 @@ class HDFSClient(FS):
raise e
# can't retry
-    def download(self, fs_path, local_path, multi_processes=1, overwrite=False):
+    def download(self, fs_path, local_path, multi_processes=5, overwrite=False):
"""
Download remote HDFS path to the local.
@@ -962,7 +958,7 @@ class HDFSClient(FS):
cmd = "mv {} {}".format(fs_src_path, fs_dst_path)
ret = 0
try:
-            ret, _ = self._run_cmd(cmd)
+            ret, _ = self._run_cmd(cmd, retry_times=1)
if ret != 0:
raise ExecuteError(cmd)
except Exception as e:
@@ -1090,7 +1086,7 @@ class HDFSClient(FS):
@_handle_errors()
def _try_cat(self, fs_path):
cmd = "cat {}".format(fs_path)
-        ret, output = self._run_cmd(cmd)
+        ret, output = self._run_cmd(cmd, retry_times=1)
if ret != 0:
raise ExecuteError(cmd)
return output
......
@@ -591,6 +591,10 @@ class SparseTable(Table):
table_proto.table_class = self.table_class
table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE
table_proto.shard_num = self.shard_num
+        if table_proto.sparse_table_cache_file_num > len(
+                get_ps_endpoints(self.context['role_maker'])):
+            table_proto.sparse_table_cache_file_num = len(
+                get_ps_endpoints(self.context['role_maker']))
self.common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]]
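Worked example: the proto default sparse_table_cache_file_num is 16 (see the TableParameter hunk above), so a job running 4 parameter servers now has the value clamped to len(endpoints) = 4, keeping the cache file count no larger than the server count.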
@@ -914,6 +918,7 @@ class TheOnePSRuntime(RuntimeBase):
self.ps_desc_builder = PsDescBuilder(self.context)
def _init_all_params(self, scopes, send_ctx, recv_map):
+        all_var_names = []
for name, ctx in send_ctx.items():
if ctx.is_sparse():
continue
@@ -923,8 +928,11 @@ class TheOnePSRuntime(RuntimeBase):
var_names = recv_map[table_id]
#print("init params:", idx, table_id, var_names)
self._worker.push_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names
def _pull_all_dense(self, scopes, send_ctx, recv_map):
+        all_var_names = []
for name, ctx in send_ctx.items():
if ctx.is_sparse():
continue
@@ -934,8 +942,11 @@ class TheOnePSRuntime(RuntimeBase):
var_names = recv_map[table_id]
#print("pull all dense:", idx, table_id, var_names)
self._worker.pull_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names
def _init_params(self, program, scope, send_ctx, recv_map):
+        all_var_names = []
for name, ctx in send_ctx.items():
if ctx.is_sparse():
continue
@@ -945,8 +956,11 @@ class TheOnePSRuntime(RuntimeBase):
var_names = recv_map[table_id]
# print("init params:", table_id, var_names)
self._worker.push_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names
def _pull_dense(self, program, scope, send_ctx, recv_map):
+        all_var_names = []
for name, ctx in send_ctx.items():
if ctx.is_sparse():
continue
@@ -956,6 +970,8 @@ class TheOnePSRuntime(RuntimeBase):
var_names = recv_map[table_id]
# print("pull dense:", table_id, var_names)
self._worker.pull_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names
def _init_worker(self, scopes=None):
worker_desc = self.ps_desc_builder.build_worker_desc()
@@ -1208,6 +1224,32 @@ class TheOnePSRuntime(RuntimeBase):
model_path = os.path.join(dirname, "dnn_plugin")
return model_path
+    def _ps_save_dense_params(self,
+                              executor,
+                              dirname,
+                              scope,
+                              program,
+                              var_names=None):
+        dense_map = get_the_one_recv_context(
+            self.context, split_dense_table=self.is_heter_ps_mode)
+        send_ctx = get_the_one_send_context(
+            self.context,
+            split_dense_table=self.is_heter_ps_mode,
+            use_origin_program=self.is_heter_ps_mode,
+            ep_list=self.endpoints)
+        if program is None or len(self.origin_main_programs) == 1:
+            program = self.origin_main_programs[0]
+        dense_var_names = self._pull_dense(program, scope, send_ctx, dense_map)
+        save_var_names = dense_var_names if var_names is None else var_names
+        vars = [program.global_block().var(i) for i in save_var_names]
+        import paddle
+        with paddle.static.scope_guard(scope):
+            paddle.static.save_vars(executor,
+                                    "./",
+                                    program,
+                                    vars=vars,
+                                    filename=dirname)
def _save_sparse_params(self, executor, dirname, context, main_program,
mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
@@ -1230,49 +1272,9 @@ class TheOnePSRuntime(RuntimeBase):
def _save_distributed_persistables(self,
executor,
dirname,
-                                       main_program,
-                                       mode=0):
-        denses = get_the_one_recv_context(
-            self.context,
-            is_dense=True,
-            split_dense_table=self.is_heter_ps_mode,
-            use_origin_program=True)
-        sparses = get_the_one_recv_context(
-            self.context,
-            is_dense=False,
-            split_dense_table=self.is_heter_ps_mode,
-            use_origin_program=True)
-        sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
-                                                   main_program, mode)
-        recv_dense_varnames = []
-        for id, names in denses.items():
-            recv_dense_varnames.extend(names)
-        self._communicator.pull_dense(denses)
-        saved_varnames = sparse_varnames
-        remaining_vars = list(
-            filter(TheOnePSRuntime.__exclude_vars(saved_varnames),
-                   main_program.list_vars()))
-        import paddle
-        for var in remaining_vars:
-            # if var.name not in recv_dense_varnames:
-            #     continue
-            tensor = var.get_value()
-            paddle.save(tensor,
-                        os.path.join(dirname, var.name),
-                        use_binary_format=True)
-    def _ps_inference_save_persistables(self,
-                                        executor,
-                                        dirname,
-                                        main_program=None,
-                                        mode=0,
-                                        **kwargs):
+                                       main_program=None,
+                                       mode=0,
+                                       **kwargs):
"""
This function filters out all variables with `persistable==True` from the
give `main_program` and then saves these variables to the folder `dirname`
@@ -1301,9 +1303,6 @@ class TheOnePSRuntime(RuntimeBase):
"in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
)
-        # Todo(MrChengmo): Save optimizer status
-        # self._save_distributed_persistables(executor, dirname, main_program,
-        #                                     mode)
self._worker.save_all_model(dirname, mode)
def _ps_inference_save_inference_model(self,
@@ -1384,14 +1383,8 @@ class TheOnePSRuntime(RuntimeBase):
os.path.join(model_path, var.name),
use_binary_format=True)
-    def _save_inference_model(self, *args, **kwargs):
-        self._ps_inference_save_inference_model(*args, **kwargs)
-    def _save_persistables(self, *args, **kwargs):
-        self._ps_inference_save_persistables(*args, **kwargs)
def _save_cache_model(self, dirname, **kwargs):
mode = kwargs.get("mode", 0)
mode = kwargs.get("mode", 1)
table_id = kwargs.get("table_id", 0)
self._worker.client_flush()
fleet.util.barrier()
@@ -1414,6 +1407,12 @@ class TheOnePSRuntime(RuntimeBase):
fleet.util.barrier()
return feasign_num
+    def _check_save_pre_patch_done(self):
+        fleet.util.barrier()
+        if self.role_maker._is_first_worker():
+            self._worker.check_save_pre_patch_done()
+        fleet.util.barrier()
def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
@@ -1469,10 +1468,7 @@ class TheOnePSRuntime(RuntimeBase):
filter(TheOnePSRuntime.__exclude_vars(loaded_varnames),
main_program.list_vars()))
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_path = self._get_inference_model_path(dirname)
import paddle
for var in remaining_vars:
if var.name not in recv_dense_varnames:
@@ -1482,14 +1478,40 @@ class TheOnePSRuntime(RuntimeBase):
self._init_params(main_program, scope, send_ctx, dense_map)
-    def _load_distributed_persistables(self, path, mode):
-        self._worker.load_model(path, mode)
def _save_one_table(self, table_id, path, mode):
if self.role_maker._is_first_worker():
self._worker.save_one_model(table_id, path, mode)
fleet.util.barrier()
-    def load_model(self, path, mode):
-        if mode == 0 or mode == 3:
-            self._load_distributed_persistables(path, mode)
-        else:
+    def _save_dense_params(self, *args, **kwargs):
+        if self.role_maker._is_first_worker():
+            self._ps_save_dense_params(*args, **kwargs)
+        fleet.util.barrier()
+    def _save_persistables(self, *args, **kwargs):
+        if self.role_maker._is_first_worker():
+            self._save_distributed_persistables(*args, **kwargs)
+        fleet.util.barrier()
+    def _save_inference_model(self, *args, **kwargs):
+        if self.role_maker._is_first_worker():
+            self._ps_inference_save_inference_model(*args, **kwargs)
+        fleet.util.barrier()
+    def _load_one_table(self, table_id, path, mode):
+        if self.role_maker._is_first_worker():
+            self._worker.load_one_table(table_id, path, mode)
+        fleet.util.barrier()
+    def _load_persistables(self, path, mode):
+        if self.role_maker._is_first_worker():
+            self._worker.load_model(path, mode)
+        fleet.util.barrier()
+    def _load_inference_model(self, path, mode):
+        if self.role_maker._is_first_worker():
+            self._ps_inference_load_inference_model(path, mode)
+        fleet.util.barrier()
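All of the save/load wrappers above share one collective pattern: only the first worker performs the storage I/O while every worker synchronizes on a barrier, so no trainer can run past an in-flight save or load. Schematically (a hypothetical helper, not part of the patch):

def _first_worker_only(self, fn, *args, **kwargs):
    # Rank 0 does the work; all workers wait before continuing.
    if self.role_maker._is_first_worker():
        fn(*args, **kwargs)
    fleet.util.barrier()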
def _shrink(self, threshold=None):
if threshold is not None:
......
@@ -236,7 +236,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
-            self.check_model_right(model_dir)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
shutil.rmtree(model_dir)
def do_dataset_training_queuedataset(self, fleet):
@@ -277,7 +278,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
-            self.check_model_right(model_dir)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
shutil.rmtree(model_dir)
dirname = os.getenv("SAVE_DIRNAME", None)
@@ -327,16 +329,35 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
-            self.check_model_right(model_dir)
+            fleet.load_inference_model(model_dir, mode=0)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
shutil.rmtree(model_dir)
dirname = os.getenv("SAVE_DIRNAME", None)
if dirname:
fleet.save_persistables(exe, dirname=dirname)
+            fleet.load_model(dirname, mode=0)
cache_dirname = os.getenv("SAVE_CACHE_DIRNAME", None)
if cache_dirname:
fleet.save_cache_model(cache_dirname)
+        dense_param_dirname = os.getenv("SAVE_DENSE_PARAM_DIRNAME", None)
+        if dense_param_dirname:
+            fleet.save_dense_params(exe, dense_param_dirname,
+                                    fluid.global_scope(),
+                                    fluid.default_main_program())
+        save_one_table_dirname = os.getenv("SAVE_ONE_TABLE_DIRNAME", None)
+        if save_one_table_dirname:
+            fleet.save_one_table(0, save_one_table_dirname, 0)
+            fleet.load_one_table(0, save_one_table_dirname, 0)
+        patch_dirname = os.getenv("SAVE_PATCH_DIRNAME", None)
+        if patch_dirname:
+            fleet.save_persistables(exe, patch_dirname, None, 5)
+            fleet.check_save_pre_patch_done()
if __name__ == "__main__":
runtime_main(TestDistCTR2x2)
@@ -91,7 +91,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
-            self.check_model_right(model_dir)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
if fleet.is_first_worker():
fleet.save_persistables(executor=exe, dirname=model_dir)
shutil.rmtree(model_dir)
@@ -139,7 +140,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
-            self.check_model_right(model_dir)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
if fleet.is_first_worker():
fleet.save_persistables(executor=exe, dirname=model_dir)
shutil.rmtree(model_dir)
......
@@ -40,8 +40,15 @@ class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase):
"http_proxy": "",
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"SAVE_DIRNAME": "/tmp/TestDistMnistAsyncInMemoryDataset2x2/model",
"SAVE_CACHE_DIRNAME":
"/tmp/TestDistMnistAsyncInMemoryDataset2x2/cache_model",
"SAVE_DENSE_PARAM_DIRNAME":
"/tmp/TestDistMnistAsyncInMemoryDataset2x2/dense_param",
"SAVE_ONE_TABLE_DIRNAME":
"/tmp/TestDistMnistAsyncInMemoryDataset2x2/table_0",
"SAVE_PATCH_DIRNAME":
"/tmp/TestDistMnistAsyncInMemoryDataset2x2/patch_model",
"LOG_PREFIX": self.__class__.__name__,
}
......