From 7b29c89b268185ef410760e65fa1b5304dbd028b Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Tue, 9 Aug 2022 11:18:09 +0800 Subject: [PATCH] refine save/load interface for distributed cpups (#44862) * save load * save load * add unittest * first commit * second commit * third commit * remove SaveLocalFS in memory sparse table * save dense param * update * push slot * fix push show clk: int -> float * add unittest * fix sample * unittest * add AsExtra for op * unittest * modify fs.py * modify fs.py * fix some bugs * add dataset hdfs config * local change * dataset use different hadoop ugi/fs_name * add * fix conflict * fix * remove logs * code style * fix * code style * code style * fix * code style * save_dense_param * fix * fix * fix * fix * change momentum in dense optimizer * fix * fix * change fluid => paddle.static * remove some useless code Co-authored-by: esythan --- .../distributed/ps/table/depends/dense.h | 6 +- .../distributed/ps/table/sparse_accessor.cc | 7 +- paddle/fluid/distributed/the_one_ps.proto | 2 +- python/paddle/distributed/fleet/__init__.py | 5 + .../distributed/fleet/base/fleet_base.py | 118 +++++++++++++- python/paddle/distributed/fleet/utils/fs.py | 28 ++-- python/paddle/distributed/ps/the_one_ps.py | 148 ++++++++++-------- .../fluid/tests/unittests/dist_fleet_ctr.py | 27 +++- .../tests/unittests/dist_fleet_ctr_ps_gpu.py | 6 +- .../tests/unittests/test_dist_fleet_ctr.py | 7 + 10 files changed, 261 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/depends/dense.h b/paddle/fluid/distributed/ps/table/depends/dense.h index d98a91750f4..3f09376b42d 100644 --- a/paddle/fluid/distributed/ps/table/depends/dense.h +++ b/paddle/fluid/distributed/ps/table/depends/dense.h @@ -254,9 +254,9 @@ class DAdamD2Sum : public DenseOptimizer { scale = (mat_ada_d2sum + scale).cwiseQuotient(mat_ada_g2sum + scale); scale = scale.cwiseSqrt(); mat_mom_velocity = - (mat_mom_velocity - 
mat_grad) * mom_decay_rate[0] + mat_grad; + (mat_mom_velocity + mat_grad) * mom_decay_rate[0] - mat_grad; - mat_w -= learning_rate[0] * mat_mom_velocity.cwiseProduct(scale); + mat_w += learning_rate[0] * mat_mom_velocity.cwiseProduct(scale); } float* learning_rate; @@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer { } float* summary_decay_rate; - double summary_decay_rate_d = 0.9999999; + double summary_decay_rate_d = 0.999999; float* param; }; diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.cc b/paddle/fluid/distributed/ps/table/sparse_accessor.cc index 2fbb58c469c..1591e340b9e 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.cc +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.cc @@ -47,7 +47,6 @@ void SparseAccessor::InitAccessorInfo() { auto embedx_dim = _config.embedx_dim(); _accessor_info.select_dim = 1 + embedx_dim; _accessor_info.select_size = _accessor_info.select_dim * sizeof(float); - ; _accessor_info.update_dim = 4 + embedx_dim; _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); _accessor_info.mf_size = @@ -231,11 +230,13 @@ int32_t SparseAccessor::Update(float** update_values, _embed_sgd_rule->UpdateValue( update_value + sparse_feature_value.EmbedWIndex(), update_value + sparse_feature_value.EmbedG2SumIndex(), - push_value + SparsePushValue::EmbedGIndex()); + push_value + SparsePushValue::EmbedGIndex(), + push_show); _embedx_sgd_rule->UpdateValue( update_value + sparse_feature_value.EmbedxWIndex(), update_value + sparse_feature_value.EmbedxG2SumIndex(), - push_value + SparsePushValue::EmbedxGIndex()); + push_value + SparsePushValue::EmbedxGIndex(), + push_show); } return 0; } diff --git a/paddle/fluid/distributed/the_one_ps.proto b/paddle/fluid/distributed/the_one_ps.proto index e74502d7351..2241655465f 100755 --- a/paddle/fluid/distributed/the_one_ps.proto +++ b/paddle/fluid/distributed/the_one_ps.proto @@ -120,7 +120,7 @@ message TableParameter { optional bool enable_sparse_table_cache 
= 10 [ default = true ]; optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; - optional bool enable_revert = 13 [ default = true ]; + optional bool enable_revert = 13 [ default = false ]; optional float shard_merge_rate = 14 [ default = 1.0 ]; } diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 0cfb946d3d8..8ac5b93ef67 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -76,7 +76,12 @@ distributed_optimizer = fleet.distributed_optimizer save_inference_model = fleet.save_inference_model save_persistables = fleet.save_persistables save_cache_model = fleet.save_cache_model +check_save_pre_patch_done = fleet.check_save_pre_patch_done +save_one_table = fleet.save_one_table +save_dense_params = fleet.save_dense_params load_model = fleet.load_model +load_inference_model = fleet.load_inference_model +load_one_table = fleet.load_one_table minimize = fleet.minimize distributed_model = fleet.distributed_model step = fleet.step diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 51e1c5281a8..52f3812d8a5 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -710,10 +710,60 @@ class Fleet(object): # build net # fleet.distributed_optimizer(...) - fleet.load_model("path", "mode") + fleet.load_model("path", mode=0) """ - self._runtime_handle.load_model(path, mode) + self._runtime_handle._load_persistables(path, mode) + + @is_non_distributed_check + @inited_runtime_handler + def load_one_table(self, table_id, path, mode): + """ + load fleet one table from path + + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + fleet.init() + + # build net + # fleet.distributed_optimizer(...) 
+ + fleet.load_one_table(0, "path", mode=0) + + """ + self._runtime_handle._load_one_table(table_id, path, mode) + + @is_non_distributed_check + @inited_runtime_handler + def load_inference_model(self, path, mode): + """ + load fleet inference model from path + + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + fleet.init() + + # build net + # fleet.distributed_optimizer(...) + + fleet.load_inference_model("path", mode=1) + + """ + self._runtime_handle._load_inference_model(path, mode) @is_non_distributed_check @inited_runtime_handler @@ -906,6 +956,70 @@ class Fleet(object): def save_cache_model(self, dirname, **configs): return self._runtime_handle._save_cache_model(dirname, **configs) + @is_non_distributed_check + @inited_runtime_handler + def check_save_pre_patch_done(self): + return self._runtime_handle._check_save_pre_patch_done() + + @is_non_distributed_check + @inited_runtime_handler + def save_one_table(self, table_id, path, mode): + """ + save fleet one table from path + + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + fleet.init() + + # build net + # fleet.distributed_optimizer(...) + + fleet.save_one_table(0, "path", mode=0) + + """ + self._runtime_handle._save_one_table(table_id, path, mode) + + @is_non_distributed_check + @inited_runtime_handler + def save_dense_params(self, + executor, + dirname, + scope, + program, + var_names=None): + """ + save fleet one table from path + + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + fleet.init() + import paddle + place = paddle.fluid.CPUPlace() + exe = paddle.fluid.Executor(place) + + # build net + # fleet.distributed_optimizer(...) 
+ + fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program()) + + """ + self._runtime_handle._save_dense_params(executor, dirname, scope, + program, var_names) + def shrink(self, threshold=None): self._runtime_handle._shrink(threshold) diff --git a/python/paddle/distributed/fleet/utils/fs.py b/python/paddle/distributed/fleet/utils/fs.py index af38d9f5138..95635154c33 100644 --- a/python/paddle/distributed/fleet/utils/fs.py +++ b/python/paddle/distributed/fleet/utils/fs.py @@ -486,6 +486,7 @@ class HDFSClient(FS): time.sleep(retry_sleep_second) if ret == 134: raise FSShellCmdAborted(cmd) + return ret, output.splitlines() @_handle_errors() @@ -615,10 +616,12 @@ class HDFSClient(FS): def _is_dir(self, fs_path): cmd = "test -d {}".format(fs_path, redirect_stderr=True) - ret, lines = self._run_cmd(cmd) + ret, lines = self._run_cmd(cmd, retry_times=1) if ret: # other error if self._test_match(lines): + print('raise exception: ') + print('\n'.join(lines)) raise ExecuteError(cmd) return False @@ -682,13 +685,10 @@ class HDFSClient(FS): client = HDFSClient(hadoop_home, configs) ret = client.is_exist("hdfs:/test_hdfs_client") """ - cmd = "ls {} ".format(fs_path) - ret, out = self._run_cmd(cmd, redirect_stderr=True) + cmd = "test -e {} ".format(fs_path) + ret, out = self._run_cmd(cmd, redirect_stderr=True, retry_times=1) if ret != 0: - for l in out: - if "No such file or directory" in l: - return False - raise ExecuteError(cmd) + return False return True @@ -712,7 +712,7 @@ class HDFSClient(FS): self._try_upload(local_dir, dest_dir) # can't retry - def upload(self, local_path, fs_path, multi_processes=1, overwrite=False): + def upload(self, local_path, fs_path, multi_processes=5, overwrite=False): """ Upload the local path to remote HDFS. 
@@ -766,11 +766,7 @@ class HDFSClient(FS): local = LocalFS() if not local.is_exist(local_path): raise FSFileNotExistsError("{} not exists".format(local_path)) - # upload_dir - if local.is_dir(local_path): - self.upload_dir(local_path, fs_path, overwrite=overwrite) - return - # upload files + all_files = get_local_files(local_path) if not all_files: print("there are nothing need to upload, function exit") @@ -805,7 +801,7 @@ class HDFSClient(FS): raise e # can't retry - def download(self, fs_path, local_path, multi_processes=1, overwrite=False): + def download(self, fs_path, local_path, multi_processes=5, overwrite=False): """ Download remote HDFS path to the local. @@ -962,7 +958,7 @@ class HDFSClient(FS): cmd = "mv {} {}".format(fs_src_path, fs_dst_path) ret = 0 try: - ret, _ = self._run_cmd(cmd) + ret, _ = self._run_cmd(cmd, retry_times=1) if ret != 0: raise ExecuteError(cmd) except Exception as e: @@ -1090,7 +1086,7 @@ class HDFSClient(FS): @_handle_errors() def _try_cat(self, fs_path): cmd = "cat {}".format(fs_path) - ret, output = self._run_cmd(cmd) + ret, output = self._run_cmd(cmd, retry_times=1) if ret != 0: raise ExecuteError(cmd) return output diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index a99bd6649f0..af56556db44 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -591,6 +591,10 @@ class SparseTable(Table): table_proto.table_class = self.table_class table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE table_proto.shard_num = self.shard_num + if table_proto.sparse_table_cache_file_num > len( + get_ps_endpoints(self.context['role_maker'])): + table_proto.sparse_table_cache_file_num = len( + get_ps_endpoints(self.context['role_maker'])) self.common.table_name = self.context['grad_name_to_param_name'][ ctx.origin_varnames()[0]] @@ -914,6 +918,7 @@ class TheOnePSRuntime(RuntimeBase): self.ps_desc_builder = PsDescBuilder(self.context) def 
_init_all_params(self, scopes, send_ctx, recv_map): + all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue @@ -923,8 +928,11 @@ class TheOnePSRuntime(RuntimeBase): var_names = recv_map[table_id] #print("init params:", idx, table_id, var_names) self._worker.push_dense_params(scope, table_id, var_names) + all_var_names.extend(var_names) + return all_var_names def _pull_all_dense(self, scopes, send_ctx, recv_map): + all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue @@ -934,8 +942,11 @@ class TheOnePSRuntime(RuntimeBase): var_names = recv_map[table_id] #print("pull all dense:", idx, table_id, var_names) self._worker.pull_dense_params(scope, table_id, var_names) + all_var_names.extend(var_names) + return all_var_names def _init_params(self, program, scope, send_ctx, recv_map): + all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue @@ -945,8 +956,11 @@ class TheOnePSRuntime(RuntimeBase): var_names = recv_map[table_id] # print("init params:", table_id, var_names) self._worker.push_dense_params(scope, table_id, var_names) + all_var_names.extend(var_names) + return all_var_names def _pull_dense(self, program, scope, send_ctx, recv_map): + all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue @@ -956,6 +970,8 @@ class TheOnePSRuntime(RuntimeBase): var_names = recv_map[table_id] # print("pull dense:", table_id, var_names) self._worker.pull_dense_params(scope, table_id, var_names) + all_var_names.extend(var_names) + return all_var_names def _init_worker(self, scopes=None): worker_desc = self.ps_desc_builder.build_worker_desc() @@ -1208,6 +1224,32 @@ class TheOnePSRuntime(RuntimeBase): model_path = os.path.join(dirname, "dnn_plugin") return model_path + def _ps_save_dense_params(self, + executor, + dirname, + scope, + program, + var_names=None): + dense_map = get_the_one_recv_context( + self.context, split_dense_table=self.is_heter_ps_mode) + send_ctx = 
get_the_one_send_context( + self.context, + split_dense_table=self.is_heter_ps_mode, + use_origin_program=self.is_heter_ps_mode, + ep_list=self.endpoints) + if program is None or len(self.origin_main_programs) == 1: + program = self.origin_main_programs[0] + dense_var_names = self._pull_dense(program, scope, send_ctx, dense_map) + save_var_names = dense_var_names if var_names is None else var_names + vars = [program.global_block().var(i) for i in save_var_names] + import paddle + with paddle.static.scope_guard(scope): + paddle.static.save_vars(executor, + "./", + program, + vars=vars, + filename=dirname) + def _save_sparse_params(self, executor, dirname, context, main_program, mode): distributed_varnames = get_sparse_tablenames(self.origin_main_programs, @@ -1230,49 +1272,9 @@ class TheOnePSRuntime(RuntimeBase): def _save_distributed_persistables(self, executor, dirname, - main_program, - mode=0): - - denses = get_the_one_recv_context( - self.context, - is_dense=True, - split_dense_table=self.is_heter_ps_mode, - use_origin_program=True) - sparses = get_the_one_recv_context( - self.context, - is_dense=False, - split_dense_table=self.is_heter_ps_mode, - use_origin_program=True) - - sparse_varnames = self._save_sparse_params(executor, dirname, sparses, - main_program, mode) - - recv_dense_varnames = [] - for id, names in denses.items(): - recv_dense_varnames.extend(names) - self._communicator.pull_dense(denses) - - saved_varnames = sparse_varnames - - remaining_vars = list( - filter(TheOnePSRuntime.__exclude_vars(saved_varnames), - main_program.list_vars())) - - import paddle - for var in remaining_vars: - # if var.name not in recv_dense_varnames: - # continue - tensor = var.get_value() - paddle.save(tensor, - os.path.join(dirname, var.name), - use_binary_format=True) - - def _ps_inference_save_persistables(self, - executor, - dirname, - main_program=None, - mode=0, - **kwargs): + main_program=None, + mode=0, + **kwargs): """ This function filters out all variables 
with `persistable==True` from the give `main_program` and then saves these variables to the folder `dirname` @@ -1301,9 +1303,6 @@ class TheOnePSRuntime(RuntimeBase): "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed" ) - # Todo(MrChengmo): Save optimizer status - # self._save_distributed_persistables(executor, dirname, main_program, - # mode) self._worker.save_all_model(dirname, mode) def _ps_inference_save_inference_model(self, @@ -1384,14 +1383,8 @@ class TheOnePSRuntime(RuntimeBase): os.path.join(model_path, var.name), use_binary_format=True) - def _save_inference_model(self, *args, **kwargs): - self._ps_inference_save_inference_model(*args, **kwargs) - - def _save_persistables(self, *args, **kwargs): - self._ps_inference_save_persistables(*args, **kwargs) - def _save_cache_model(self, dirname, **kwargs): - mode = kwargs.get("mode", 0) + mode = kwargs.get("mode", 1) table_id = kwargs.get("table_id", 0) self._worker.client_flush() fleet.util.barrier() @@ -1414,6 +1407,12 @@ class TheOnePSRuntime(RuntimeBase): fleet.util.barrier() return feasign_num + def _check_save_pre_patch_done(self): + fleet.util.barrier() + if self.role_maker._is_first_worker(): + self._worker.check_save_pre_patch_done() + fleet.util.barrier() + def _load_sparse_params(self, dirname, context, main_program, mode): distributed_varnames = get_sparse_tablenames(self.origin_main_programs, True) @@ -1469,10 +1468,7 @@ class TheOnePSRuntime(RuntimeBase): filter(TheOnePSRuntime.__exclude_vars(loaded_varnames), main_program.list_vars())) - if dirname.startswith("afs:") or dirname.startswith("hdfs:"): - model_path = "./dnn_plugin" - else: - model_path = os.path.join(dirname, "dnn_plugin") + model_path = self._get_inference_model_path(dirname) import paddle for var in remaining_vars: if var.name not in recv_dense_varnames: @@ -1482,14 +1478,40 @@ class TheOnePSRuntime(RuntimeBase): self._init_params(main_program, scope, send_ctx, dense_map) - def 
_load_distributed_persistables(self, path, mode): - self._worker.load_model(path, mode) + def _save_one_table(self, table_id, path, mode): + if self.role_maker._is_first_worker(): + self._worker.save_one_model(table_id, path, mode) + fleet.util.barrier() - def load_model(self, path, mode): - if mode == 0 or mode == 3: - self._load_distributed_persistables(path, mode) - else: + def _save_dense_params(self, *args, **kwargs): + if self.role_maker._is_first_worker(): + self._ps_save_dense_params(*args, **kwargs) + fleet.util.barrier() + + def _save_persistables(self, *args, **kwargs): + if self.role_maker._is_first_worker(): + self._save_distributed_persistables(*args, **kwargs) + fleet.util.barrier() + + def _save_inference_model(self, *args, **kwargs): + if self.role_maker._is_first_worker(): + self._ps_inference_save_inference_model(*args, **kwargs) + fleet.util.barrier() + + def _load_one_table(self, table_id, path, mode): + if self.role_maker._is_first_worker(): + self._worker.load_one_table(table_id, path, mode) + fleet.util.barrier() + + def _load_persistables(self, path, mode): + if self.role_maker._is_first_worker(): + self._worker.load_model(path, mode) + fleet.util.barrier() + + def _load_inference_model(self, path, mode): + if self.role_maker._is_first_worker(): self._ps_inference_load_inference_model(path, mode) + fleet.util.barrier() def _shrink(self, threshold=None): if threshold is not None: diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index a33624ee5ee..6702606ae98 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -236,7 +236,8 @@ class TestDistCTR2x2(FleetDistRunnerBase): fleet.save_inference_model(exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) - self.check_model_right(model_dir) + if fleet.is_first_worker(): + self.check_model_right(model_dir) shutil.rmtree(model_dir) def 
do_dataset_training_queuedataset(self, fleet): @@ -277,7 +278,8 @@ class TestDistCTR2x2(FleetDistRunnerBase): fleet.save_inference_model(exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) - self.check_model_right(model_dir) + if fleet.is_first_worker(): + self.check_model_right(model_dir) shutil.rmtree(model_dir) dirname = os.getenv("SAVE_DIRNAME", None) @@ -327,16 +329,35 @@ class TestDistCTR2x2(FleetDistRunnerBase): fleet.save_inference_model(exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) - self.check_model_right(model_dir) + fleet.load_inference_model(model_dir, mode=0) + if fleet.is_first_worker(): + self.check_model_right(model_dir) shutil.rmtree(model_dir) dirname = os.getenv("SAVE_DIRNAME", None) if dirname: fleet.save_persistables(exe, dirname=dirname) + fleet.load_model(dirname, mode=0) cache_dirname = os.getenv("SAVE_CACHE_DIRNAME", None) if cache_dirname: fleet.save_cache_model(cache_dirname) + dense_param_dirname = os.getenv("SAVE_DENSE_PARAM_DIRNAME", None) + if dense_param_dirname: + fleet.save_dense_params(exe, dense_param_dirname, + fluid.global_scope(), + fluid.default_main_program()) + + save_one_table_dirname = os.getenv("SAVE_ONE_TABLE_DIRNAME", None) + if save_one_table_dirname: + fleet.save_one_table(0, save_one_table_dirname, 0) + fleet.load_one_table(0, save_one_table_dirname, 0) + + patch_dirname = os.getenv("SAVE_PATCH_DIRNAME", None) + if patch_dirname: + fleet.save_persistables(exe, patch_dirname, None, 5) + fleet.check_save_pre_patch_done() + if __name__ == "__main__": runtime_main(TestDistCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py index 4ecad3e97c6..eee2ac9e1ab 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py @@ -91,7 +91,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2): fleet.save_inference_model(exe, 
model_dir, [feed.name for feed in self.feeds], self.avg_cost) - self.check_model_right(model_dir) + if fleet.is_first_worker(): + self.check_model_right(model_dir) if fleet.is_first_worker(): fleet.save_persistables(executor=exe, dirname=model_dir) shutil.rmtree(model_dir) @@ -139,7 +140,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2): fleet.save_inference_model(exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) - self.check_model_right(model_dir) + if fleet.is_first_worker(): + self.check_model_right(model_dir) if fleet.is_first_worker(): fleet.save_persistables(executor=exe, dirname=model_dir) shutil.rmtree(model_dir) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 59d6ce70ddc..426af39ca90 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -40,8 +40,15 @@ class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase): "http_proxy": "", "CPU_NUM": "2", "LOG_DIRNAME": "/tmp", + "SAVE_DIRNAME": "/tmp/TestDistMnistAsyncInMemoryDataset2x2/model", "SAVE_CACHE_DIRNAME": "/tmp/TestDistMnistAsyncInMemoryDataset2x2/cache_model", + "SAVE_DENSE_PARAM_DIRNAME": + "/tmp/TestDistMnistAsyncInMemoryDataset2x2/dense_param", + "SAVE_ONE_TABLE_DIRNAME": + "/tmp/TestDistMnistAsyncInMemoryDataset2x2/table_0", + "SAVE_PATCH_DIRNAME": + "/tmp/TestDistMnistAsyncInMemoryDataset2x2/patch_model", "LOG_PREFIX": self.__class__.__name__, } -- GitLab