diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 5aef43263575ec9901b76a8236634e8e42b4bdab..ae5c9504ecb6ee25a2b5fea19dab34588ec8fe82 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -181,7 +181,7 @@ enum TableType { message TableParameter { optional uint64 table_id = 1; optional string table_class = 2; - optional uint64 shard_num = 3; + optional uint64 shard_num = 3 [ default = 1000 ]; optional TableType type = 4; optional TableAccessorParameter accessor = 5; } @@ -190,42 +190,73 @@ message TableAccessorParameter { optional string accessor_class = 1; optional SGDParameter embed_sgd_param = 2; optional SGDParameter embedx_sgd_param = 3; - optional uint32 fea_dim = 4; // for sparse table, this means field size of one - // value; for dense table, this means total value - // num - optional uint32 embedx_dim = 5; // embedx feature size - optional uint32 embedx_threshold = 6; // embedx feature create threshold + optional uint32 fea_dim = 4 [ default = 11 ]; // field size of one value + optional uint32 embedx_dim = 5 [ default = 8 ]; // embedx feature size + optional uint32 embedx_threshold = 6 + [ default = 10 ]; // embedx feature create threshold optional CtrAccessorParameter ctr_accessor_param = 7; + repeated TableAccessorSaveParameter table_accessor_save_param = 8; } // TODO(guanqun): add NaiveSGD/Adam... message SGDParameter { optional string name = 1; - optional SGDRuleParameter adagrad = 2; + optional SparseNaiveSGDRuleParameter naive = 2; + optional SparseAdagradSGDRuleParameter adagrad = 3; + optional SparseAdamSGDParameter adam = 4; } -message SGDRuleParameter { - optional double learning_rate = 1; - optional double initial_g2sum = 2; - optional double initial_range = 3 [ default = 0 ]; +message SparseNaiveSGDRuleParameter { // SparseNaiveSGDRule + optional double learning_rate = 1 [ default = 0.05 ]; + optional double initial_range = 2 [ default = 0.0001 ]; + repeated float weight_bounds = 3; +} + +message + SparseAdagradSGDRuleParameter { // SparseAdaGradSGDRule|StdAdaGradSGDRule + optional double learning_rate = 1 [ default = 0.05 ]; + optional double initial_g2sum = 2 [ default = 3.0 ]; + optional double initial_range = 3 [ default = 0.0001 ]; repeated float weight_bounds = 4; } +message SparseAdamSGDParameter { // SparseAdamSGDRule + optional double learning_rate = 1 [ default = 0.001 ]; + optional double initial_range = 2 [ default = 0.0001 ]; + optional double beta1_decay_rate = 3 [ default = 0.9 ]; + optional double beta2_decay_rate = 4 [ default = 0.999 ]; + optional double ada_epsilon = 5 [ default = 1e-08 ]; + repeated float weight_bounds = 6; +} + message CtrAccessorParameter { - optional float nonclk_coeff = 1; // to calculate show_click_score - optional float click_coeff = 2; // to calculate show_click_score - optional float base_threshold = - 3; // show_click_score > base_threshold, this feature can be saved - optional float delta_threshold = - 4; // delta_score > delta_threshold, this feature can be saved - optional float delta_keep_days = - 5; // unseen_day < delta_keep_days, this feature can be saved - optional float show_click_decay_rate = 6; // show/click will update to - // show/click * - // show_click_decay_rate after a day - optional float delete_threshold = 7; // threshold to shrink a feasign - optional float delete_after_unseen_days = 8; - optional int32 ssd_unseenday_threshold = 9; + optional float nonclk_coeff = 1 + [ default = 0.1 ]; // to calculate show_click_score + optional float click_coeff = 2 + [ default = 1 ]; // to calculate show_click_score + optional float base_threshold = 3 [ + default = 1.5 + ]; // show_click_score > base_threshold, this feature can be saved + optional float delta_threshold = 4 + [ default = + 0.25 ]; // delta_score > delta_threshold, this feature can be saved + optional float delta_keep_days = 5 + [ default = + 16 ]; // unseen_day < delta_keep_days, this feature can be saved + optional float show_click_decay_rate = 6 + [ default = 0.98 ]; // show/click will update to + // show/click * + // show_click_decay_rate after a day + optional float delete_threshold = 7 + [ default = 0.8 ]; // threshold to shrink a feasign + optional float delete_after_unseen_days = 8 [ default = 30 ]; + optional int32 ssd_unseenday_threshold = 9 [ default = 1 ]; +} + +message TableAccessorSaveParameter { + optional uint32 param = 1; + optional string converter = 2; + optional string deconverter = 3; } message FsClientParameter { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index cdbc7bd0cd7440be395b319ab06ee9e18c0a7bec..cc0a5de233c382a20266866146f4c85050e921c5 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -470,12 +470,22 @@ class DistributedStrategy(object): from google.protobuf.descriptor import FieldDescriptor table_param = self.strategy.downpour_table_param - def set_table_config(msg, config_name, configs): + def set_table_config(msg, config_name, configs, index=0): for field in msg.DESCRIPTOR.fields: name = config_name + "." + field.name if field.type == FieldDescriptor.TYPE_MESSAGE: print("message:", name) - set_table_config(getattr(msg, field.name), name, configs) + if field.label == FieldDescriptor.LABEL_REPEATED: + if name + ".num" not in configs: + continue + num = configs[name + ".num"] + print("message num:", name, num) + for i in range(num): + data = getattr(msg, field.name).add() + set_table_config(data, name, configs, i) + else: + set_table_config( + getattr(msg, field.name), name, configs) else: print("not message:", name) if name not in configs: @@ -483,9 +493,15 @@ class DistributedStrategy(object): if field.label == FieldDescriptor.LABEL_REPEATED: getattr(msg, field.name).extend(configs[name]) else: - setattr(msg, field.name, configs[name]) + if type(configs[name]) == list: + setattr(msg, field.name, configs[name][index]) + else: + setattr(msg, field.name, configs[name]) - set_table_config(table_param, "table_parameters", configs) + if not configs: + print("table configs is empty") + else: + set_table_config(table_param, "table_parameters", configs) @property def amp(self): diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 57199b8a1e8cc45afbc60652898e2a147547fc4f..a1e5ef2ba799fce61dada7b280ecbcdbcd4a13ca 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -823,7 +823,7 @@ class Fleet(object): self._runtime_handle._save_persistables(executor, dirname, main_program, mode) - def shrink(self, threshold): + def shrink(self, threshold=None): self._runtime_handle._shrink(threshold) def distributed_optimizer(self, optimizer, strategy=None): diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 81613cc1efdfb05d8d576818b94fae7c6eab7652..1c51e833f53f6028745ebc4a5ccbf56cfcea21f9 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -24,7 +24,6 @@ from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.framework import Variable, Parameter from .runtime_base import RuntimeBase from ..base.private_helper_function import wait_server_ready -import paddle.distributed.fleet as fleet __all__ = [] @@ -53,6 +52,70 @@ def parse_table_class(varname, o_main_program): return "MemorySparseTable" +def get_default_accessor_proto(accessor, varname, o_main_program): + embedding_dim = 0 + for var in o_main_program.list_vars(): + if var.name == varname: + print("var:", var) + print("var.shape:", var.shape) + embedding_dim = var.shape[1] + print("sparse dim:", embedding_dim) + break + + accessor.accessor_class = "CtrCommonAccessor" + accessor.fea_dim = embedding_dim + 2 + accessor.embedx_dim = embedding_dim - 1 + accessor.embedx_threshold = 0 + + ctr_accessor_param = accessor.ctr_accessor_param + ctr_accessor_param.nonclk_coeff = 0.1 + ctr_accessor_param.click_coeff = 1.0 + ctr_accessor_param.base_threshold = 0 + ctr_accessor_param.delta_threshold = 0 + ctr_accessor_param.delta_keep_days = 16 + ctr_accessor_param.show_click_decay_rate = 1 + ctr_accessor_param.delete_threshold = 0 + ctr_accessor_param.delete_after_unseen_days = 30 + ctr_accessor_param.ssd_unseenday_threshold = 1 + + embed_sgd_param = accessor.embed_sgd_param + embed_sgd_param.name = "SparseAdaGradSGDRule" + embed_sgd_param.adagrad.learning_rate = 0.05 + embed_sgd_param.adagrad.initial_g2sum = 3.0 + embed_sgd_param.adagrad.initial_range = 0.0001 + embed_sgd_param.adagrad.weight_bounds.append(-10.0) + embed_sgd_param.adagrad.weight_bounds.append(10.0) + + embedx_sgd_param = accessor.embedx_sgd_param + embedx_sgd_param.name = "SparseAdaGradSGDRule" + embedx_sgd_param.adagrad.learning_rate = 0.05 + embedx_sgd_param.adagrad.initial_g2sum = 3.0 + embedx_sgd_param.adagrad.initial_range = 0.0001 + embedx_sgd_param.adagrad.weight_bounds.append(-10.0) + embedx_sgd_param.adagrad.weight_bounds.append(10.0) + + +def check_embedding_dim(accessor, varname, o_main_program): + embedding_dim = 0 + for var in o_main_program.list_vars(): + if var.name == varname: + print("var:", var) + print("var.shape:", var.shape) + embedding_dim = var.shape[1] + print("sparse dim:", embedding_dim) + break + fea_dim = accessor.fea_dim + if fea_dim != embedding_dim + 2: + raise ValueError( + "The fea_dim is wrong, it will be sparse_embedding_dim + 2: {}, but got {}". + format(embedding_dim + 2, fea_dim)) + embedx_dim = accessor.embedx_dim + if embedx_dim != embedding_dim - 1: + raise ValueError( + "The embedx_dim is wrong, it will be sparse_embedding_dim - 1: {}, but got {}". + format(embedding_dim - 1, embedx_dim)) + + class Accessor: def __init__(self): self.accessor_class = "" @@ -344,6 +407,11 @@ class Table: self.accessor_proto = None def to_string(self, indent): + # if self.id == 1: + # proto_txt = '' + # with open('./sparse_table.prototxt') as f: + # proto_txt = f.read() + # return proto_txt table_str = "{}downpour_table_param {{{}\n{}}}" attrs = "" @@ -586,6 +654,8 @@ class TheOnePSRuntime(RuntimeBase): return kwargs proto_txt = str(worker) + "\n" + str(server) + with open('proto_txt', 'w') as f: + f.write(proto_txt) debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) @@ -847,54 +917,54 @@ class TheOnePSRuntime(RuntimeBase): if self.compiled_strategy.is_geo_mode(): table.table_class = "SparseGeoTable" else: - table.table_class = parse_table_class( - common.table_name, self.origin_main_program) - table_proto = self.context[ - "user_defined_strategy"].sparse_table_configs - table.shard_num = table_proto.shard_num + import copy + table_proto = copy.deepcopy(self.context[ + "user_defined_strategy"].sparse_table_configs) + print('table proto:', table_proto) + print('table_class:', table_proto.table_class) + print('shard_num:', table_proto.shard_num) + print('table_proto.accessor:', table_proto.accessor) + print('accessor.IsInitialized', + table_proto.accessor.IsInitialized()) + print('accessor.ByteSize', + table_proto.accessor.ByteSize()) + if table_proto.table_class: + print('table_proto.table_class is true') + table.table_class = table_proto.table_class + else: + table.table_class = parse_table_class( + common.table_name, self.origin_main_program) + if table.table_class != 'MemorySparseTable': + table.table_class = 'MemorySparseTable' + warnings.warn( + "The PS mode must use MemorySparseTable.") + + if table_proto.shard_num: + print('table_proto.shard_num is true') + table.shard_num = table_proto.shard_num + else: + table.shard_num = 1000 + warnings.warn( + "The shard_num of sparse table is not set, use default value 1000." + ) + + if table_proto.accessor.ByteSize() == 0: + print('table_proto.accessor is false') + get_default_accessor_proto(table_proto.accessor, + common.table_name, + self.origin_main_program) + warnings.warn( + "The accessor of sparse table is not set, use default value." + ) + check_embedding_dim(table_proto.accessor, + common.table_name, + self.origin_main_program) + print('accessor.ByteSize', + table_proto.accessor.ByteSize()) from google.protobuf import text_format table.accessor_proto = text_format.MessageToString( table_proto.accessor) - - print('table proto:', table_proto) - if table.table_class == 'MemorySparseTable' and table.accessor_proto == '': - emb_dim = ctx.sections()[1] - table.shard_num = 1950 - table.accessor_proto = 'accessor_class: "CtrCommonAccessor"\n' \ - 'embed_sgd_param {\n' \ - ' name: "SparseAdaGradSGDRule"\n' \ - ' adagrad {\n' \ - ' learning_rate: 0.05\n' \ - ' initial_g2sum: 3.0\n' \ - ' initial_range: 0.0001\n' \ - ' weight_bounds: -10.0\n' \ - ' weight_bounds: 10.0\n' \ - ' }\n' \ - '}\n' \ - 'embedx_sgd_param {\n' \ - ' name: "SparseAdaGradSGDRule"\n' \ - ' adagrad {\n' \ - ' learning_rate: 0.05\n' \ - ' initial_g2sum: 3.0\n' \ - ' initial_range: 0.0001\n' \ - ' weight_bounds: -10.0\n' \ - ' weight_bounds: 10.0\n' \ - ' }\n' \ - '}\n' \ - 'fea_dim: ' + str(emb_dim+2) + '\n' \ - 'embedx_dim: ' + str(emb_dim-1) + '\n' \ - 'embedx_threshold: 10\n' \ - 'ctr_accessor_param {\n' \ - ' nonclk_coeff: 0.1\n' \ - ' click_coeff: 1.0\n' \ - ' base_threshold: 1.5\n' \ - ' delta_threshold: 0.25\n' \ - ' delta_keep_days: 16.0\n' \ - ' show_click_decay_rate: 0.98\n' \ - ' delete_threshold: 0.8\n' \ - ' delete_after_unseen_days: 30.0\n' \ - ' ssd_unseenday_threshold: 1\n' \ - '}' + print("the_one_ps table_proto:", table.accessor_proto) else: table.type = "PS_DENSE_TABLE" table.table_class = "CommonDenseTable" @@ -916,7 +986,6 @@ class TheOnePSRuntime(RuntimeBase): common.sync = "true" else: common.sync = "false" - table.common = common if table.table_class != 'MemorySparseTable': @@ -1108,8 +1177,6 @@ class TheOnePSRuntime(RuntimeBase): TheOnePSRuntime.__exclude_vars(saved_varnames), main_program.list_vars())) - self._communicator.pull_dense(denses) - import paddle for var in remaining_vars: # if var.name not in recv_dense_varnames: @@ -1209,9 +1276,8 @@ class TheOnePSRuntime(RuntimeBase): split_dense_table=self.role_maker._is_heter_parameter_server_mode, use_origin_program=True) print("the one ps sparses:", sparses) - sparse_names = [] - for id, name in sparses.items(): - sparse_names.extend(name) + sparse_names = self._save_sparse_params(executor, dirname, sparses, + main_program, mode) print("the one ps sparse names:", sparse_names) denses = self.compiled_strategy.get_the_one_recv_context( @@ -1225,7 +1291,7 @@ class TheOnePSRuntime(RuntimeBase): generate_vars = [var for var in generate_vars] remaining_vars = list( filter( - TheOnePSRuntime.__exclude_vars(generate_vars + sparse_names), + TheOnePSRuntime.__exclude_vars(sparse_names), infer_program.list_vars())) print("remain_vars:", [var.name for var in remaining_vars]) for var in remaining_vars: @@ -1235,9 +1301,6 @@ class TheOnePSRuntime(RuntimeBase): os.path.join(model_path, var.name), use_binary_format=True) - self._ps_inference_save_persistables(executor, dirname, infer_program, - mode) - def _save_inference_model(self, *args, **kwargs): self._ps_inference_save_inference_model(*args, **kwargs) @@ -1314,8 +1377,15 @@ class TheOnePSRuntime(RuntimeBase): self._load_distributed_persistables(path, mode) else: self._ps_inference_load_inference_model(path, mode) + # self._load_distributed_persistables(path, mode=mode) - def _shrink(self, threshold): + def _shrink(self, threshold=None): + if threshold is not None: + warnings.warn( + "The param threshold is not used in MemorySparseTable, if you need to shrink, please set the config of accessor" + ) + else: + threshold = 0 import paddle.distributed.fleet as fleet fleet.util.barrier() if self.role_maker._is_first_worker(): diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 9f698134eee17408222fb010122b50a39f42bc25..0e291648b37544c3f3bb8cb29364fb41cfeb5afc 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -862,8 +862,12 @@ class InMemoryDataset(DatasetBase): thread_num(int): shuffle thread num. Default is 12. """ + from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib if fleet is not None: - fleet._role_maker.barrier_worker() + if not isinstance(fleet, PSLib): + fleet.barrier_worker() + else: + fleet._role_maker.barrier_worker() if self.trainer_num == -1: self.trainer_num = fleet.worker_num() if self.fleet_send_batch_size is None: @@ -875,14 +879,23 @@ class InMemoryDataset(DatasetBase): self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size) self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds) if fleet is not None: - fleet._role_maker.barrier_worker() + if not isinstance(fleet, PSLib): + fleet.barrier_worker() + else: + fleet._role_maker.barrier_worker() self.dataset.global_shuffle(thread_num) if fleet is not None: - fleet._role_maker.barrier_worker() + if not isinstance(fleet, PSLib): + fleet.barrier_worker() + else: + fleet._role_maker.barrier_worker() if self.merge_by_lineid: self.dataset.merge_by_lineid() if fleet is not None: - fleet._role_maker.barrier_worker() + if not isinstance(fleet, PSLib): + fleet.barrier_worker() + else: + fleet._role_maker.barrier_worker() @deprecated( since="2.0.0", @@ -1011,10 +1024,15 @@ class InMemoryDataset(DatasetBase): import numpy as np local_data_size = self.dataset.get_shuffle_data_size() local_data_size = np.array([local_data_size]) + print('global shuffle local_data_size: ', local_data_size) if fleet is not None: + from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib global_data_size = local_data_size * 0 - fleet._role_maker.all_reduce_worker(local_data_size, - global_data_size) + if not isinstance(fleet, PSLib): + global_data_size = fleet.util.all_reduce(local_data_size) + else: + fleet._role_maker.all_reduce_worker(local_data_size, + global_data_size) return global_data_size[0] return local_data_size[0] diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 65c8a7500f246cbd3d7a48d94d5425fe4ba1ad78..2bd397b0ef3f531a30ac45288689d0897a310b23 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -241,7 +241,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): self.check_model_right(model_dir) shutil.rmtree(model_dir) - def do_dataset_training(self, fleet): + def do_dataset_training_queuedataset(self, fleet): train_file_list = ctr_dataset_reader.prepare_fake_data() exe = self.get_executor() @@ -288,5 +288,56 @@ class TestDistCTR2x2(FleetDistRunnerBase): if dirname: fleet.save_persistables(exe, dirname=dirname) + def do_dataset_training(self, fleet): + train_file_list = ctr_dataset_reader.prepare_fake_data() + + exe = self.get_executor() + exe.run(fluid.default_startup_program()) + fleet.init_worker() + + thread_num = 2 + batch_size = 128 + filelist = train_file_list + + # config dataset + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_use_var(self.feeds) + dataset.set_batch_size(128) + dataset.set_thread(2) + dataset.set_filelist(filelist) + dataset.set_pipe_command('python ctr_dataset_reader.py') + dataset.load_into_memory() + + dataset.global_shuffle(fleet, 12) ##TODO: thread configure + shuffle_data_size = dataset.get_shuffle_data_size(fleet) + local_data_size = dataset.get_shuffle_data_size() + data_size_list = fleet.util.all_gather(local_data_size) + print('after global_shuffle data_size_list: ', data_size_list) + print('after global_shuffle data_size: ', shuffle_data_size) + + for epoch_id in range(1): + pass_start = time.time() + exe.train_from_dataset( + program=fluid.default_main_program(), + dataset=dataset, + fetch_list=[self.avg_cost], + fetch_info=["cost"], + print_period=2, + debug=int(os.getenv("Debug", "0"))) + pass_time = time.time() - pass_start + dataset.release_memory() + + if os.getenv("SAVE_MODEL") == "1": + model_dir = tempfile.mkdtemp() + fleet.save_inference_model(exe, model_dir, + [feed.name for feed in self.feeds], + self.avg_cost) + self.check_model_right(model_dir) + shutil.rmtree(model_dir) + + dirname = os.getenv("SAVE_DIRNAME", None) + if dirname: + fleet.save_persistables(exe, dirname=dirname) + if __name__ == "__main__": runtime_main(TestDistCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 3beb1d3dfe0331d09961da7c64ee95987fe025a7..59d196fdf55e57b3175b3deb6036f4b88b565d34 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -20,6 +20,41 @@ import tempfile from test_dist_fleet_base import TestFleetBase +class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase): + def _setup_config(self): + self._mode = "async" + #self._reader = "pyreader" + self._reader = "dataset" + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_rpc_deadline": "5000", # 5sec to fail fast + "http_proxy": "", + "CPU_NUM": "2", + "LOG_DIRNAME": "/tmp", + "LOG_PREFIX": self.__class__.__name__, + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) + + def test_dist_train(self): + self.check_with_place( + "dist_fleet_ctr.py", delta=1e-5, check_error_log=False) + + class TestDistMnistAsync2x2(TestFleetBase): def _setup_config(self): self._mode = "async" diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index 9cf3eb251b3962ef76d2cc0af65e1c2fb5429372..a9193c0abdfc18785d5fa4fad5e21cbbe95a10a7 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -259,11 +259,18 @@ class TestStrategyConfig(unittest.TestCase): strategy = paddle.distributed.fleet.DistributedStrategy() configs = { "table_parameters.accessor.embed_sgd_param.adagrad.learning_rate": - 0.05 + 0.05, + "table_parameters.accessor.table_accessor_save_param.num": 2, + "table_parameters.accessor.table_accessor_save_param.param": + [1, 2] } strategy.sparse_table_configs = configs self.assertEqual(strategy.sparse_table_configs.accessor.embed_sgd_param. adagrad.learning_rate, 0.05) + self.assertEqual( + strategy.sparse_table_configs.accessor.table_accessor_save_param[ + 0].param, 1) + strategy.adam_d2sum = True self.assertEqual(strategy.adam_d2sum, True) strategy.fs_client_param = {