Unverified commit 7b29c89b, authored by zhaocaibei123, committed by GitHub

refine save/load interface for distributed cpups (#44862)

* save load

* save load

* add unittest

* first commit

* second commit

* third commit

* remove SaveLocalFS in memory sparse table

* save dense param

* update

* push slot

* fix push show clk: int -> float

* add unittest

* fix sample

* unittest

* add AsExtra for op

* unittest

* modify fs.py

* modify fs.py

* fix some bugs

* add dataset hdfs config

* local change

* dataset use different hadoop ugi/fs_name

* add

* fix conflict

* fix

* remove logs

* code style

* fix

* code style

* code style

* fix

* code style

* save_dense_param

* fix

* fix

* fix

* fix

* change momentum in dense optimizer

* fix

* fix

* change fluid => paddle.static

* remove some unused code

Co-authored-by: esythan <esythan@126.com>
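Taken together, these commits refine the save/load surface of distributed CPUPS. Below is a minimal usage sketch of the resulting interface (the paths and table id 0 are illustrative only; it assumes a PS-mode fleet job whose network has already been built):

import paddle
import paddle.distributed.fleet as fleet

fleet.init()
# build net, then fleet.distributed_optimizer(...).minimize(...)
exe = paddle.static.Executor(paddle.CPUPlace())

# save side
fleet.save_persistables(exe, "/tmp/model")          # all distributed persistables
fleet.save_one_table(0, "/tmp/table_0", 0)          # a single table, mode 0
fleet.save_dense_params(exe, "/tmp/dense_param",
                        paddle.static.global_scope(),
                        paddle.static.default_main_program())
fleet.check_save_pre_patch_done()                   # wait for the patch-model save

# load side
fleet.load_model("/tmp/model", mode=0)
fleet.load_one_table(0, "/tmp/table_0", 0)
fleet.load_inference_model("/tmp/model", mode=1)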
Parent: 8185cecd
@@ -254,9 +254,9 @@ class DAdamD2Sum : public DenseOptimizer {
     scale = (mat_ada_d2sum + scale).cwiseQuotient(mat_ada_g2sum + scale);
     scale = scale.cwiseSqrt();
     mat_mom_velocity =
-        (mat_mom_velocity - mat_grad) * mom_decay_rate[0] + mat_grad;
-    mat_w -= learning_rate[0] * mat_mom_velocity.cwiseProduct(scale);
+        (mat_mom_velocity + mat_grad) * mom_decay_rate[0] - mat_grad;
+    mat_w += learning_rate[0] * mat_mom_velocity.cwiseProduct(scale);
   }

   float* learning_rate;
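In equation form, the changed DAdamD2Sum momentum step reads as follows (β is mom_decay_rate[0], η is learning_rate[0], s the adaptive scale computed above, ⊙ the elementwise product). Note that for β < 1 the gradient coefficient β − 1 is negative, so the `+=` on the weights still performs descent:

\begin{aligned}
v_{t+1} &= \beta\,(v_t + g_t) - g_t \;=\; \beta\, v_t + (\beta - 1)\, g_t \\
w_{t+1} &= w_t + \eta\,\bigl(v_{t+1} \odot s\bigr)
\end{aligned}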
@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer {
   }

   float* summary_decay_rate;
-  double summary_decay_rate_d = 0.9999999;
+  double summary_decay_rate_d = 0.999999;
   float* param;
 };
......
@@ -47,7 +47,6 @@ void SparseAccessor::InitAccessorInfo() {
   auto embedx_dim = _config.embedx_dim();
   _accessor_info.select_dim = 1 + embedx_dim;
   _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
-  ;
   _accessor_info.update_dim = 4 + embedx_dim;
   _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
   _accessor_info.mf_size =

@@ -231,11 +230,13 @@ int32_t SparseAccessor::Update(float** update_values,
     _embed_sgd_rule->UpdateValue(
         update_value + sparse_feature_value.EmbedWIndex(),
         update_value + sparse_feature_value.EmbedG2SumIndex(),
-        push_value + SparsePushValue::EmbedGIndex());
+        push_value + SparsePushValue::EmbedGIndex(),
+        push_show);
     _embedx_sgd_rule->UpdateValue(
         update_value + sparse_feature_value.EmbedxWIndex(),
         update_value + sparse_feature_value.EmbedxG2SumIndex(),
-        push_value + SparsePushValue::EmbedxGIndex());
+        push_value + SparsePushValue::EmbedxGIndex(),
+        push_show);
   }
   return 0;
 }
......
@@ -120,7 +120,7 @@ message TableParameter {
   optional bool enable_sparse_table_cache = 10 [ default = true ];
   optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
   optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
-  optional bool enable_revert = 13 [ default = true ];
+  optional bool enable_revert = 13 [ default = false ];
   optional float shard_merge_rate = 14 [ default = 1.0 ];
 }
......
@@ -76,7 +76,12 @@ distributed_optimizer = fleet.distributed_optimizer
 save_inference_model = fleet.save_inference_model
 save_persistables = fleet.save_persistables
 save_cache_model = fleet.save_cache_model
+check_save_pre_patch_done = fleet.check_save_pre_patch_done
+save_one_table = fleet.save_one_table
+save_dense_params = fleet.save_dense_params
 load_model = fleet.load_model
+load_inference_model = fleet.load_inference_model
+load_one_table = fleet.load_one_table
 minimize = fleet.minimize
 distributed_model = fleet.distributed_model
 step = fleet.step
......
@@ -710,10 +710,60 @@ class Fleet(object):
                 # build net
                 # fleet.distributed_optimizer(...)

-                fleet.load_model("path", "mode")
+                fleet.load_model("path", mode=0)

         """
-        self._runtime_handle.load_model(path, mode)
+        self._runtime_handle._load_persistables(path, mode)
+
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def load_one_table(self, table_id, path, mode):
+        """
+        load one fleet table from path
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+
+                # build net
+                # fleet.distributed_optimizer(...)
+
+                fleet.load_one_table(0, "path", mode=0)
+
+        """
+        self._runtime_handle._load_one_table(table_id, path, mode)
+
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def load_inference_model(self, path, mode):
+        """
+        load fleet inference model from path
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+
+                # build net
+                # fleet.distributed_optimizer(...)
+
+                fleet.load_inference_model("path", mode=1)
+
+        """
+        self._runtime_handle._load_inference_model(path, mode)

     @is_non_distributed_check
     @inited_runtime_handler

@@ -906,6 +956,70 @@ class Fleet(object):
     def save_cache_model(self, dirname, **configs):
         return self._runtime_handle._save_cache_model(dirname, **configs)

+    @is_non_distributed_check
+    @inited_runtime_handler
+    def check_save_pre_patch_done(self):
+        return self._runtime_handle._check_save_pre_patch_done()
+
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def save_one_table(self, table_id, path, mode):
+        """
+        save one fleet table to path
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+
+                # build net
+                # fleet.distributed_optimizer(...)
+
+                fleet.save_one_table(0, "path", mode=0)
+
+        """
+        self._runtime_handle._save_one_table(table_id, path, mode)
+
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def save_dense_params(self,
+                          executor,
+                          dirname,
+                          scope,
+                          program,
+                          var_names=None):
+        """
+        save fleet dense parameters to dirname
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+                import paddle
+                import paddle.distributed.fleet as fleet
+                fleet.init()
+
+                place = paddle.fluid.CPUPlace()
+                exe = paddle.fluid.Executor(place)
+
+                # build net
+                # fleet.distributed_optimizer(...)
+
+                fleet.save_dense_params(exe, "path",
+                                        scope=paddle.static.global_scope(),
+                                        program=paddle.static.default_main_program())
+
+        """
+        self._runtime_handle._save_dense_params(executor, dirname, scope,
+                                                program, var_names)
+
     def shrink(self, threshold=None):
         self._runtime_handle._shrink(threshold)
......
@@ -486,6 +486,7 @@ class HDFSClient(FS):
                 time.sleep(retry_sleep_second)
             if ret == 134:
                 raise FSShellCmdAborted(cmd)
+
         return ret, output.splitlines()

     @_handle_errors()

@@ -615,10 +616,12 @@ class HDFSClient(FS):
     def _is_dir(self, fs_path):
         cmd = "test -d {}".format(fs_path, redirect_stderr=True)
-        ret, lines = self._run_cmd(cmd)
+        ret, lines = self._run_cmd(cmd, retry_times=1)
         if ret:
             # other error
             if self._test_match(lines):
+                print('raise exception: ')
+                print('\n'.join(lines))
                 raise ExecuteError(cmd)

             return False

@@ -682,13 +685,10 @@ class HDFSClient(FS):
             client = HDFSClient(hadoop_home, configs)
             ret = client.is_exist("hdfs:/test_hdfs_client")
         """
-        cmd = "ls {} ".format(fs_path)
-        ret, out = self._run_cmd(cmd, redirect_stderr=True)
+        cmd = "test -e {} ".format(fs_path)
+        ret, out = self._run_cmd(cmd, redirect_stderr=True, retry_times=1)
         if ret != 0:
-            for l in out:
-                if "No such file or directory" in l:
-                    return False
-            raise ExecuteError(cmd)
+            return False

         return True

@@ -712,7 +712,7 @@ class HDFSClient(FS):
         self._try_upload(local_dir, dest_dir)

     # can't retry
-    def upload(self, local_path, fs_path, multi_processes=1, overwrite=False):
+    def upload(self, local_path, fs_path, multi_processes=5, overwrite=False):
         """
         Upload the local path to remote HDFS.

@@ -766,11 +766,7 @@ class HDFSClient(FS):
         local = LocalFS()
         if not local.is_exist(local_path):
             raise FSFileNotExistsError("{} not exists".format(local_path))
-        # upload_dir
-        if local.is_dir(local_path):
-            self.upload_dir(local_path, fs_path, overwrite=overwrite)
-            return
-        # upload files
+
         all_files = get_local_files(local_path)
         if not all_files:
             print("there are nothing need to upload, function exit")

@@ -805,7 +801,7 @@ class HDFSClient(FS):
             raise e

     # can't retry
-    def download(self, fs_path, local_path, multi_processes=1, overwrite=False):
+    def download(self, fs_path, local_path, multi_processes=5, overwrite=False):
         """
         Download remote HDFS path to the local.

@@ -962,7 +958,7 @@ class HDFSClient(FS):
         cmd = "mv {} {}".format(fs_src_path, fs_dst_path)
         ret = 0
         try:
-            ret, _ = self._run_cmd(cmd)
+            ret, _ = self._run_cmd(cmd, retry_times=1)
             if ret != 0:
                 raise ExecuteError(cmd)
         except Exception as e:

@@ -1090,7 +1086,7 @@ class HDFSClient(FS):
     @_handle_errors()
     def _try_cat(self, fs_path):
         cmd = "cat {}".format(fs_path)
-        ret, output = self._run_cmd(cmd)
+        ret, output = self._run_cmd(cmd, retry_times=1)
         if ret != 0:
             raise ExecuteError(cmd)
         return output
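A note on the recurring retry_times=1 in these hunks: for test -d, test -e, mv, and cat, a nonzero exit code is typically a real answer (path missing, move rejected) rather than a transient failure, so the caller opts out of repeated retries. A minimal sketch of the intended call shape, assuming _run_cmd returns an (exit_code, output_lines) pair as elsewhere in this class:

# Nonzero exit from "hadoop fs -test -e" means "does not exist";
# retrying would only add latency, so cap retries with retry_times=1.
ret, _ = client._run_cmd("test -e {}".format(fs_path),
                         redirect_stderr=True,
                         retry_times=1)
exists = (ret == 0)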
......
@@ -591,6 +591,10 @@ class SparseTable(Table):
         table_proto.table_class = self.table_class
         table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE
         table_proto.shard_num = self.shard_num
+        if table_proto.sparse_table_cache_file_num > len(
+                get_ps_endpoints(self.context['role_maker'])):
+            table_proto.sparse_table_cache_file_num = len(
+                get_ps_endpoints(self.context['role_maker']))

         self.common.table_name = self.context['grad_name_to_param_name'][
             ctx.origin_varnames()[0]]
@@ -914,6 +918,7 @@ class TheOnePSRuntime(RuntimeBase):
         self.ps_desc_builder = PsDescBuilder(self.context)

     def _init_all_params(self, scopes, send_ctx, recv_map):
+        all_var_names = []
         for name, ctx in send_ctx.items():
             if ctx.is_sparse():
                 continue

@@ -923,8 +928,11 @@ class TheOnePSRuntime(RuntimeBase):
             var_names = recv_map[table_id]
             #print("init params:", idx, table_id, var_names)
             self._worker.push_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names

     def _pull_all_dense(self, scopes, send_ctx, recv_map):
+        all_var_names = []
         for name, ctx in send_ctx.items():
             if ctx.is_sparse():
                 continue

@@ -934,8 +942,11 @@ class TheOnePSRuntime(RuntimeBase):
             var_names = recv_map[table_id]
             #print("pull all dense:", idx, table_id, var_names)
             self._worker.pull_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names

     def _init_params(self, program, scope, send_ctx, recv_map):
+        all_var_names = []
         for name, ctx in send_ctx.items():
             if ctx.is_sparse():
                 continue

@@ -945,8 +956,11 @@ class TheOnePSRuntime(RuntimeBase):
             var_names = recv_map[table_id]
             # print("init params:", table_id, var_names)
             self._worker.push_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names

     def _pull_dense(self, program, scope, send_ctx, recv_map):
+        all_var_names = []
         for name, ctx in send_ctx.items():
             if ctx.is_sparse():
                 continue

@@ -956,6 +970,8 @@ class TheOnePSRuntime(RuntimeBase):
             var_names = recv_map[table_id]
             # print("pull dense:", table_id, var_names)
             self._worker.pull_dense_params(scope, table_id, var_names)
+            all_var_names.extend(var_names)
+        return all_var_names

     def _init_worker(self, scopes=None):
         worker_desc = self.ps_desc_builder.build_worker_desc()
@@ -1208,6 +1224,32 @@ class TheOnePSRuntime(RuntimeBase):
             model_path = os.path.join(dirname, "dnn_plugin")
         return model_path

+    def _ps_save_dense_params(self,
+                              executor,
+                              dirname,
+                              scope,
+                              program,
+                              var_names=None):
+        dense_map = get_the_one_recv_context(
+            self.context, split_dense_table=self.is_heter_ps_mode)
+        send_ctx = get_the_one_send_context(
+            self.context,
+            split_dense_table=self.is_heter_ps_mode,
+            use_origin_program=self.is_heter_ps_mode,
+            ep_list=self.endpoints)
+        if program is None or len(self.origin_main_programs) == 1:
+            program = self.origin_main_programs[0]
+        dense_var_names = self._pull_dense(program, scope, send_ctx, dense_map)
+        save_var_names = dense_var_names if var_names is None else var_names
+        vars = [program.global_block().var(i) for i in save_var_names]
+        import paddle
+        with paddle.static.scope_guard(scope):
+            paddle.static.save_vars(executor,
+                                    "./",
+                                    program,
+                                    vars=vars,
+                                    filename=dirname)
+
     def _save_sparse_params(self, executor, dirname, context, main_program,
                             mode):
         distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
@@ -1230,49 +1272,9 @@ class TheOnePSRuntime(RuntimeBase):
     def _save_distributed_persistables(self,
                                        executor,
                                        dirname,
-                                       main_program,
-                                       mode=0):
+                                       main_program=None,
+                                       mode=0,
+                                       **kwargs):
-
-        denses = get_the_one_recv_context(
-            self.context,
-            is_dense=True,
-            split_dense_table=self.is_heter_ps_mode,
-            use_origin_program=True)
-        sparses = get_the_one_recv_context(
-            self.context,
-            is_dense=False,
-            split_dense_table=self.is_heter_ps_mode,
-            use_origin_program=True)
-
-        sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
-                                                   main_program, mode)
-
-        recv_dense_varnames = []
-        for id, names in denses.items():
-            recv_dense_varnames.extend(names)
-        self._communicator.pull_dense(denses)
-
-        saved_varnames = sparse_varnames
-
-        remaining_vars = list(
-            filter(TheOnePSRuntime.__exclude_vars(saved_varnames),
-                   main_program.list_vars()))
-
-        import paddle
-        for var in remaining_vars:
-            # if var.name not in recv_dense_varnames:
-            #     continue
-            tensor = var.get_value()
-            paddle.save(tensor,
-                        os.path.join(dirname, var.name),
-                        use_binary_format=True)
-
-    def _ps_inference_save_persistables(self,
-                                        executor,
-                                        dirname,
-                                        main_program=None,
-                                        mode=0,
-                                        **kwargs):
         """
         This function filters out all variables with `persistable==True` from the
         given `main_program` and then saves these variables to the folder `dirname`
@@ -1301,9 +1303,6 @@ class TheOnePSRuntime(RuntimeBase):
                 "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
             )

-        # Todo(MrChengmo): Save optimizer status
-        # self._save_distributed_persistables(executor, dirname, main_program,
-        #                                     mode)
         self._worker.save_all_model(dirname, mode)

     def _ps_inference_save_inference_model(self,
@@ -1384,14 +1383,8 @@ class TheOnePSRuntime(RuntimeBase):
                             os.path.join(model_path, var.name),
                             use_binary_format=True)

-    def _save_inference_model(self, *args, **kwargs):
-        self._ps_inference_save_inference_model(*args, **kwargs)
-
-    def _save_persistables(self, *args, **kwargs):
-        self._ps_inference_save_persistables(*args, **kwargs)
-
     def _save_cache_model(self, dirname, **kwargs):
-        mode = kwargs.get("mode", 0)
+        mode = kwargs.get("mode", 1)
         table_id = kwargs.get("table_id", 0)
         self._worker.client_flush()
         fleet.util.barrier()
@@ -1414,6 +1407,12 @@ class TheOnePSRuntime(RuntimeBase):
         fleet.util.barrier()
         return feasign_num

+    def _check_save_pre_patch_done(self):
+        fleet.util.barrier()
+        if self.role_maker._is_first_worker():
+            self._worker.check_save_pre_patch_done()
+        fleet.util.barrier()
+
     def _load_sparse_params(self, dirname, context, main_program, mode):
         distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
                                                      True)
@@ -1469,10 +1468,7 @@ class TheOnePSRuntime(RuntimeBase):
             filter(TheOnePSRuntime.__exclude_vars(loaded_varnames),
                    main_program.list_vars()))

-        if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
-            model_path = "./dnn_plugin"
-        else:
-            model_path = os.path.join(dirname, "dnn_plugin")
+        model_path = self._get_inference_model_path(dirname)
         import paddle
         for var in remaining_vars:
             if var.name not in recv_dense_varnames:
@@ -1482,14 +1478,40 @@ class TheOnePSRuntime(RuntimeBase):
         self._init_params(main_program, scope, send_ctx, dense_map)

-    def _load_distributed_persistables(self, path, mode):
-        self._worker.load_model(path, mode)
+    def _save_one_table(self, table_id, path, mode):
+        if self.role_maker._is_first_worker():
+            self._worker.save_one_model(table_id, path, mode)
+        fleet.util.barrier()
+
+    def _save_dense_params(self, *args, **kwargs):
+        if self.role_maker._is_first_worker():
+            self._ps_save_dense_params(*args, **kwargs)
+        fleet.util.barrier()
+
+    def _save_persistables(self, *args, **kwargs):
+        if self.role_maker._is_first_worker():
+            self._save_distributed_persistables(*args, **kwargs)
+        fleet.util.barrier()
+
+    def _save_inference_model(self, *args, **kwargs):
+        if self.role_maker._is_first_worker():
+            self._ps_inference_save_inference_model(*args, **kwargs)
+        fleet.util.barrier()

-    def load_model(self, path, mode):
-        if mode == 0 or mode == 3:
-            self._load_distributed_persistables(path, mode)
-        else:
+    def _load_one_table(self, table_id, path, mode):
+        if self.role_maker._is_first_worker():
+            self._worker.load_one_table(table_id, path, mode)
+        fleet.util.barrier()
+
+    def _load_persistables(self, path, mode):
+        if self.role_maker._is_first_worker():
+            self._worker.load_model(path, mode)
+        fleet.util.barrier()
+
+    def _load_inference_model(self, path, mode):
+        if self.role_maker._is_first_worker():
             self._ps_inference_load_inference_model(path, mode)
+        fleet.util.barrier()

     def _shrink(self, threshold=None):
         if threshold is not None:
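Every one of the new save/load entry points above follows the same collective shape: only the first worker issues the RPC work, but all workers join the barrier so the call stays collective across ranks. A sketch of that pattern, with the hypothetical helper do_io standing in for the per-call body:

def _collective_io(self, do_io):
    # Only the first worker actually talks to the parameter servers;
    # doing this on every rank would duplicate the save/load work.
    if self.role_maker._is_first_worker():
        do_io()
    # Every rank must still reach the barrier, so callers on all
    # workers invoke this method, not just rank 0.
    fleet.util.barrier()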
......
@@ -236,7 +236,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             fleet.save_inference_model(exe, model_dir,
                                        [feed.name for feed in self.feeds],
                                        self.avg_cost)
-            self.check_model_right(model_dir)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
             shutil.rmtree(model_dir)

     def do_dataset_training_queuedataset(self, fleet):
@@ -277,7 +278,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             fleet.save_inference_model(exe, model_dir,
                                        [feed.name for feed in self.feeds],
                                        self.avg_cost)
-            self.check_model_right(model_dir)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
             shutil.rmtree(model_dir)

         dirname = os.getenv("SAVE_DIRNAME", None)
@@ -327,16 +329,35 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             fleet.save_inference_model(exe, model_dir,
                                        [feed.name for feed in self.feeds],
                                        self.avg_cost)
-            self.check_model_right(model_dir)
+            fleet.load_inference_model(model_dir, mode=0)
+            if fleet.is_first_worker():
+                self.check_model_right(model_dir)
             shutil.rmtree(model_dir)

         dirname = os.getenv("SAVE_DIRNAME", None)
         if dirname:
             fleet.save_persistables(exe, dirname=dirname)
+            fleet.load_model(dirname, mode=0)

         cache_dirname = os.getenv("SAVE_CACHE_DIRNAME", None)
         if cache_dirname:
             fleet.save_cache_model(cache_dirname)

+        dense_param_dirname = os.getenv("SAVE_DENSE_PARAM_DIRNAME", None)
+        if dense_param_dirname:
+            fleet.save_dense_params(exe, dense_param_dirname,
+                                    fluid.global_scope(),
+                                    fluid.default_main_program())
+
+        save_one_table_dirname = os.getenv("SAVE_ONE_TABLE_DIRNAME", None)
+        if save_one_table_dirname:
+            fleet.save_one_table(0, save_one_table_dirname, 0)
+            fleet.load_one_table(0, save_one_table_dirname, 0)
+
+        patch_dirname = os.getenv("SAVE_PATCH_DIRNAME", None)
+        if patch_dirname:
+            fleet.save_persistables(exe, patch_dirname, None, 5)
+            fleet.check_save_pre_patch_done()
+

 if __name__ == "__main__":
     runtime_main(TestDistCTR2x2)
@@ -91,7 +91,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
         fleet.save_inference_model(exe, model_dir,
                                    [feed.name for feed in self.feeds],
                                    self.avg_cost)
-        self.check_model_right(model_dir)
+        if fleet.is_first_worker():
+            self.check_model_right(model_dir)
         if fleet.is_first_worker():
             fleet.save_persistables(executor=exe, dirname=model_dir)
         shutil.rmtree(model_dir)

@@ -139,7 +140,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
         fleet.save_inference_model(exe, model_dir,
                                    [feed.name for feed in self.feeds],
                                    self.avg_cost)
-        self.check_model_right(model_dir)
+        if fleet.is_first_worker():
+            self.check_model_right(model_dir)
         if fleet.is_first_worker():
             fleet.save_persistables(executor=exe, dirname=model_dir)
         shutil.rmtree(model_dir)
......
@@ -40,8 +40,15 @@ class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase):
             "http_proxy": "",
             "CPU_NUM": "2",
             "LOG_DIRNAME": "/tmp",
+            "SAVE_DIRNAME": "/tmp/TestDistMnistAsyncInMemoryDataset2x2/model",
             "SAVE_CACHE_DIRNAME":
             "/tmp/TestDistMnistAsyncInMemoryDataset2x2/cache_model",
+            "SAVE_DENSE_PARAM_DIRNAME":
+            "/tmp/TestDistMnistAsyncInMemoryDataset2x2/dense_param",
+            "SAVE_ONE_TABLE_DIRNAME":
+            "/tmp/TestDistMnistAsyncInMemoryDataset2x2/table_0",
+            "SAVE_PATCH_DIRNAME":
+            "/tmp/TestDistMnistAsyncInMemoryDataset2x2/patch_model",
             "LOG_PREFIX": self.__class__.__name__,
         }
......