diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 739e05e1d79712a6551c4b97f5e034b0f93ee1b8..7380e0f129cf46bc00f35a9eed72e6dd492a6eab 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -180,10 +180,11 @@ enum TableType { message TableParameter { optional uint64 table_id = 1; - optional string table_class = 2; - optional uint64 shard_num = 3 [ default = 1000 ]; - optional TableType type = 4; - optional TableAccessorParameter accessor = 5; + optional string table_name = 2; + optional string table_class = 3; + optional uint64 shard_num = 4 [ default = 1000 ]; + optional TableType type = 5; + optional TableAccessorParameter accessor = 6; } message TableAccessorParameter { @@ -198,7 +199,6 @@ message TableAccessorParameter { repeated TableAccessorSaveParameter table_accessor_save_param = 8; } -// TODO(guanqun): add NaiveSGD/Adam... message SGDParameter { optional string name = 1; optional SparseNaiveSGDRuleParameter naive = 2; @@ -321,7 +321,7 @@ message DistributedStrategy { optional HybridConfig hybrid_configs = 112; optional TensorParallelConfig tensor_parallel_configs = 113; optional TrainerDescConfig trainer_desc_configs = 114; - optional TableParameter downpour_table_param = 115; + repeated TableParameter downpour_table_param = 115; optional FsClientParameter fs_client_param = 116; optional BuildStrategy build_strategy = 201; diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index e58b6c312fa1fe28b1565cc4421177c2c139cf21..3b8b36a61e2fb73de67d40ea981d9eb1f4455b00 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -474,12 +474,12 @@ class DistributedStrategy(object): for field in msg.DESCRIPTOR.fields: name = config_name + "." + field.name if field.type == FieldDescriptor.TYPE_MESSAGE: - print("message:", name) + # print("message:", name) if field.label == FieldDescriptor.LABEL_REPEATED: if name + ".num" not in configs: continue num = configs[name + ".num"] - print("message num:", name, num) + # print("message num:", name, num) for i in range(num): data = getattr(msg, field.name).add() set_table_config(data, name, configs, i) @@ -487,7 +487,7 @@ class DistributedStrategy(object): set_table_config( getattr(msg, field.name), name, configs) else: - print("not message:", name) + # print("not message:", name) if name not in configs: continue if field.label == FieldDescriptor.LABEL_REPEATED: @@ -501,7 +501,11 @@ class DistributedStrategy(object): if not configs: print("table configs is empty") else: - set_table_config(table_param, "table_parameters", configs) + for table_name in configs: + table_data = table_param.add() + table_data.table_name = table_name + set_table_config(table_data, "table_parameters." + table_name, + configs[table_name]) @property def amp(self): diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 1c51e833f53f6028745ebc4a5ccbf56cfcea21f9..1240e1492a7840bd5e66ee2899325b700c658693 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -56,53 +56,77 @@ def get_default_accessor_proto(accessor, varname, o_main_program): embedding_dim = 0 for var in o_main_program.list_vars(): if var.name == varname: - print("var:", var) - print("var.shape:", var.shape) embedding_dim = var.shape[1] - print("sparse dim:", embedding_dim) break - accessor.accessor_class = "CtrCommonAccessor" - accessor.fea_dim = embedding_dim + 2 - accessor.embedx_dim = embedding_dim - 1 - accessor.embedx_threshold = 0 + if not accessor.HasField("accessor_class"): + accessor.accessor_class = "CtrCommonAccessor" + if not accessor.HasField("fea_dim"): + accessor.fea_dim = embedding_dim + 2 + if not accessor.HasField("embedx_dim"): + accessor.embedx_dim = embedding_dim - 1 + if not accessor.HasField("embedx_threshold"): + accessor.embedx_threshold = 0 ctr_accessor_param = accessor.ctr_accessor_param - ctr_accessor_param.nonclk_coeff = 0.1 - ctr_accessor_param.click_coeff = 1.0 - ctr_accessor_param.base_threshold = 0 - ctr_accessor_param.delta_threshold = 0 - ctr_accessor_param.delta_keep_days = 16 - ctr_accessor_param.show_click_decay_rate = 1 - ctr_accessor_param.delete_threshold = 0 - ctr_accessor_param.delete_after_unseen_days = 30 - ctr_accessor_param.ssd_unseenday_threshold = 1 - - embed_sgd_param = accessor.embed_sgd_param - embed_sgd_param.name = "SparseAdaGradSGDRule" - embed_sgd_param.adagrad.learning_rate = 0.05 - embed_sgd_param.adagrad.initial_g2sum = 3.0 - embed_sgd_param.adagrad.initial_range = 0.0001 - embed_sgd_param.adagrad.weight_bounds.append(-10.0) - embed_sgd_param.adagrad.weight_bounds.append(10.0) - - embedx_sgd_param = accessor.embedx_sgd_param - embedx_sgd_param.name = "SparseAdaGradSGDRule" - embedx_sgd_param.adagrad.learning_rate = 0.05 - embedx_sgd_param.adagrad.initial_g2sum = 3.0 - embedx_sgd_param.adagrad.initial_range = 0.0001 - embedx_sgd_param.adagrad.weight_bounds.append(-10.0) - embedx_sgd_param.adagrad.weight_bounds.append(10.0) + if not ctr_accessor_param.HasField("nonclk_coeff"): + ctr_accessor_param.nonclk_coeff = 0.1 + if not ctr_accessor_param.HasField("click_coeff"): + ctr_accessor_param.click_coeff = 1.0 + if not ctr_accessor_param.HasField("base_threshold"): + ctr_accessor_param.base_threshold = 0 + if not ctr_accessor_param.HasField("delta_threshold"): + ctr_accessor_param.delta_threshold = 0 + if not ctr_accessor_param.HasField("delta_keep_days"): + ctr_accessor_param.delta_keep_days = 16 + if not ctr_accessor_param.HasField("show_click_decay_rate"): + ctr_accessor_param.show_click_decay_rate = 1 + if not ctr_accessor_param.HasField("delete_threshold"): + ctr_accessor_param.delete_threshold = 0 + if not ctr_accessor_param.HasField("delete_after_unseen_days"): + ctr_accessor_param.delete_after_unseen_days = 30 + if not ctr_accessor_param.HasField("ssd_unseenday_threshold"): + ctr_accessor_param.ssd_unseenday_threshold = 1 + + for sgd_param in [accessor.embed_sgd_param, accessor.embedx_sgd_param]: + if not sgd_param.HasField("name"): + sgd_param.name = "SparseAdaGradSGDRule" + if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule": + if not sgd_param.adagrad.HasField("learning_rate"): + sgd_param.adagrad.learning_rate = 0.05 + if not sgd_param.adagrad.HasField("initial_g2sum"): + sgd_param.adagrad.initial_g2sum = 3.0 + if not sgd_param.adagrad.HasField("initial_range"): + sgd_param.adagrad.initial_range = 0.0001 + if len(sgd_param.adagrad.weight_bounds) == 0: + sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0]) + if sgd_param.name == "SparseNaiveSGDRule": + if not sgd_param.naive.HasField("learning_rate"): + sgd_param.naive.learning_rate = 0.05 + if not sgd_param.naive.HasField("initial_range"): + sgd_param.naive.initial_range = 0.0001 + if len(sgd_param.naive.weight_bounds) == 0: + sgd_param.naive.weight_bounds.extend([-10.0, 10.0]) + if sgd_param.name == "SparseAdamSGDRule": + if not sgd_param.adam.HasField("learning_rate"): + sgd_param.adam.learning_rate = 0.001 + if not sgd_param.adam.HasField("initial_range"): + sgd_param.adam.initial_range = 0.0001 + if not sgd_param.adam.HasField("beta1_decay_rate"): + sgd_param.adam.beta1_decay_rate = 0.9 + if not sgd_param.adam.HasField("beta2_decay_rate"): + sgd_param.adam.beta2_decay_rate = 0.999 + if not sgd_param.adam.HasField("ada_epsilon"): + sgd_param.adam.ada_epsilon = 1e-08 + if len(sgd_param.adam.weight_bounds) == 0: + sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) def check_embedding_dim(accessor, varname, o_main_program): embedding_dim = 0 for var in o_main_program.list_vars(): if var.name == varname: - print("var:", var) - print("var.shape:", var.shape) embedding_dim = var.shape[1] - print("sparse dim:", embedding_dim) break fea_dim = accessor.fea_dim if fea_dim != embedding_dim + 2: @@ -917,19 +941,14 @@ class TheOnePSRuntime(RuntimeBase): if self.compiled_strategy.is_geo_mode(): table.table_class = "SparseGeoTable" else: - import copy - table_proto = copy.deepcopy(self.context[ - "user_defined_strategy"].sparse_table_configs) - print('table proto:', table_proto) - print('table_class:', table_proto.table_class) - print('shard_num:', table_proto.shard_num) - print('table_proto.accessor:', table_proto.accessor) - print('accessor.IsInitialized', - table_proto.accessor.IsInitialized()) - print('accessor.ByteSize', - table_proto.accessor.ByteSize()) - if table_proto.table_class: - print('table_proto.table_class is true') + all_table_proto = self.context[ + "user_defined_strategy"].sparse_table_configs + table_proto = all_table_proto.add() + for proto in all_table_proto: + if proto.table_name == common.table_name: + table_proto = proto + break + if table_proto.HasField("table_class"): table.table_class = table_proto.table_class else: table.table_class = parse_table_class( @@ -939,8 +958,7 @@ class TheOnePSRuntime(RuntimeBase): warnings.warn( "The PS mode must use MemorySparseTable.") - if table_proto.shard_num: - print('table_proto.shard_num is true') + if table_proto.HasField("shard_num"): table.shard_num = table_proto.shard_num else: table.shard_num = 1000 @@ -949,22 +967,18 @@ class TheOnePSRuntime(RuntimeBase): ) if table_proto.accessor.ByteSize() == 0: - print('table_proto.accessor is false') - get_default_accessor_proto(table_proto.accessor, - common.table_name, - self.origin_main_program) warnings.warn( "The accessor of sparse table is not set, use default value." ) + get_default_accessor_proto(table_proto.accessor, + common.table_name, + self.origin_main_program) check_embedding_dim(table_proto.accessor, common.table_name, self.origin_main_program) - print('accessor.ByteSize', - table_proto.accessor.ByteSize()) from google.protobuf import text_format table.accessor_proto = text_format.MessageToString( table_proto.accessor) - print("the_one_ps table_proto:", table.accessor_proto) else: table.type = "PS_DENSE_TABLE" table.table_class = "CommonDenseTable" @@ -1275,10 +1289,8 @@ class TheOnePSRuntime(RuntimeBase): is_dense=False, split_dense_table=self.role_maker._is_heter_parameter_server_mode, use_origin_program=True) - print("the one ps sparses:", sparses) sparse_names = self._save_sparse_params(executor, dirname, sparses, main_program, mode) - print("the one ps sparse names:", sparse_names) denses = self.compiled_strategy.get_the_one_recv_context( is_dense=True, @@ -1293,7 +1305,7 @@ class TheOnePSRuntime(RuntimeBase): filter( TheOnePSRuntime.__exclude_vars(sparse_names), infer_program.list_vars())) - print("remain_vars:", [var.name for var in remaining_vars]) + for var in remaining_vars: tensor = var.get_value() paddle.save( diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py index ccbe154a4875334f98443256b43c027bcaf3a5a4..4e3dfccee28a2895bf1c0f4f83220dfd0349be5a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py @@ -167,6 +167,15 @@ class TestPSPassWithBow(unittest.TestCase): strategy = paddle.distributed.fleet.DistributedStrategy() strategy.a_sync = True + + configs = {} + configs['__emb__'] = { + "table_parameters.__emb__.accessor.embed_sgd_param.name": + "SparseNaiveSGDRule", + "table_parameters.__emb__.accessor.embedx_sgd_param.name": + "SparseAdamSGDRule", + } + strategy.sparse_table_configs = configs optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer.minimize(loss) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index a9193c0abdfc18785d5fa4fad5e21cbbe95a10a7..7d611ed6e06d4569ab62db74e1adcf12422a9682 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -257,19 +257,19 @@ class TestStrategyConfig(unittest.TestCase): def test_sparse_table_configs(self): strategy = paddle.distributed.fleet.DistributedStrategy() - configs = { - "table_parameters.accessor.embed_sgd_param.adagrad.learning_rate": + configs = {} + configs['emb'] = { + "table_parameters.emb.accessor.embed_sgd_param.adagrad.learning_rate": 0.05, - "table_parameters.accessor.table_accessor_save_param.num": 2, - "table_parameters.accessor.table_accessor_save_param.param": + "table_parameters.emb.accessor.table_accessor_save_param.num": 2, + "table_parameters.emb.accessor.table_accessor_save_param.param": [1, 2] } strategy.sparse_table_configs = configs - self.assertEqual(strategy.sparse_table_configs.accessor.embed_sgd_param. - adagrad.learning_rate, 0.05) - self.assertEqual( - strategy.sparse_table_configs.accessor.table_accessor_save_param[ - 0].param, 1) + self.assertEqual(strategy.sparse_table_configs[0] + .accessor.embed_sgd_param.adagrad.learning_rate, 0.05) + self.assertEqual(strategy.sparse_table_configs[0] + .accessor.table_accessor_save_param[0].param, 1) strategy.adam_d2sum = True self.assertEqual(strategy.adam_d2sum, True)