From ae21737343115aec3c48246e1950c345f26b955d Mon Sep 17 00:00:00 2001 From: wangguanqun Date: Thu, 1 Sep 2022 11:33:13 +0800 Subject: [PATCH] ps optimizer default config (#45563) * config * fix unittest * zero init & cache & patch config * add barrier to save and load * add unittest --- .../distributed/ps/table/ctr_accessor.cc | 4 +- .../ps/table/ctr_double_accessor.cc | 7 +- .../distributed/ps/table/sparse_accessor.cc | 4 +- paddle/fluid/distributed/the_one_ps.proto | 3 + .../framework/distributed_strategy.proto | 7 ++ .../fleet/base/distributed_strategy.py | 6 ++ python/paddle/distributed/ps/the_one_ps.py | 99 +++++++++++++++---- python/paddle/distributed/ps/utils/public.py | 9 ++ .../tests/unittests/test_dist_fleet_ps2.py | 5 + 9 files changed, 121 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_accessor.cc index cde42fc0e61..61e748a5413 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.cc @@ -177,8 +177,10 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) { value[common_feature_value.ShowIndex()] = 0; value[common_feature_value.ClickIndex()] = 0; value[common_feature_value.SlotIndex()] = -1; + bool zero_init = _config.ctr_accessor_param().zero_init(); _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(), - value + common_feature_value.EmbedG2SumIndex()); + value + common_feature_value.EmbedG2SumIndex(), + zero_init); _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), value + common_feature_value.EmbedxG2SumIndex(), false); diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc index 9dfcafa8a9f..2573b9db06a 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc @@ -176,9 +176,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) { *reinterpret_cast(value + CtrDoubleFeatureValue::ShowIndex()) = 0; *(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0; value[CtrDoubleFeatureValue::SlotIndex()] = -1; - _embed_sgd_rule->InitValue( - value + CtrDoubleFeatureValue::EmbedWIndex(), - value + CtrDoubleFeatureValue::EmbedG2SumIndex()); + bool zero_init = _config.ctr_accessor_param().zero_init(); + _embed_sgd_rule->InitValue(value + CtrDoubleFeatureValue::EmbedWIndex(), + value + CtrDoubleFeatureValue::EmbedG2SumIndex(), + zero_init); _embedx_sgd_rule->InitValue( value + CtrDoubleFeatureValue::EmbedxWIndex(), value + CtrDoubleFeatureValue::EmbedxG2SumIndex(), diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.cc b/paddle/fluid/distributed/ps/table/sparse_accessor.cc index 1591e340b9e..afa94703233 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.cc +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.cc @@ -150,8 +150,10 @@ int32_t SparseAccessor::Create(float** values, size_t num) { value[sparse_feature_value.ShowIndex()] = 0; value[sparse_feature_value.ClickIndex()] = 0; value[sparse_feature_value.SlotIndex()] = -1; + bool zero_init = _config.ctr_accessor_param().zero_init(); _embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(), - value + sparse_feature_value.EmbedG2SumIndex()); + value + sparse_feature_value.EmbedG2SumIndex(), + zero_init); _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(), value + sparse_feature_value.EmbedxG2SumIndex(), false); diff --git a/paddle/fluid/distributed/the_one_ps.proto b/paddle/fluid/distributed/the_one_ps.proto index 2241655465f..5eeba703360 100755 --- a/paddle/fluid/distributed/the_one_ps.proto +++ b/paddle/fluid/distributed/the_one_ps.proto @@ -120,6 +120,7 @@ message TableParameter { optional bool enable_sparse_table_cache = 10 [ default = true ]; optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; + // for patch model optional bool enable_revert = 13 [ default = false ]; optional float shard_merge_rate = 14 [ default = 1.0 ]; } @@ -167,6 +168,7 @@ message CtrAccessorParameter { optional int32 ssd_unseenday_threshold = 9 [ default = 1 ]; // threshold to save ssd optional bool show_scale = 10 [ default = true ]; + optional bool zero_init = 11 [ default = true ]; } message TensorAccessorParameter { @@ -189,6 +191,7 @@ message CommonAccessorParameter { optional bool sync = 9; optional uint32 table_num = 10; optional uint32 table_dim = 11; + optional string attr = 12; } message TableAccessorSaveParameter { diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 474630946b2..7c02c9bab73 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -205,6 +205,13 @@ message TableParameter { optional TableType type = 5; optional TableAccessorParameter accessor = 6; optional bool compress_in_save = 7 [ default = false ]; + // for cache model + optional bool enable_sparse_table_cache = 10 [ default = true ]; + optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; + optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; + // for patch model + optional bool enable_revert = 13 [ default = false ]; + optional float shard_merge_rate = 14 [ default = 1.0 ]; } message TableAccessorParameter { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 765fca275d7..b83d97d1d35 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -627,6 +627,12 @@ class DistributedStrategy(object): % (table_class)) table_data.table_class = 'MemorySparseTable' table_data.shard_num = config.get('sparse_shard_num', 1000) + table_data.enable_sparse_table_cache = config.get( + 'sparse_enable_cache', True) + table_data.sparse_table_cache_rate = config.get( + 'sparse_cache_rate', 0.00055) + table_data.sparse_table_cache_file_num = config.get( + 'sparse_cache_file_num', 16) accessor_class = config.get("sparse_accessor_class", "DownpourCtrAccessor") diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index ea82f30cf8e..77a0ab0a659 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -127,7 +127,8 @@ class Accessor: self.embedding_dim = 0 # TableAccessorParameter accessor - def _set(self, accessor_proto, varname, program_id, context): + def _set(self, accessor_proto, varname, program_id, context, + common_accessor): main_program, startup_program, idx = get_program_by_id( context, program_id) embedding_dim = 0 @@ -162,6 +163,8 @@ class Accessor: graph_sgd_param.feature_learning_rate = 0.05 ctr_accessor_param = accessor_proto.ctr_accessor_param + if accessor_proto.embedx_dim == 0: + ctr_accessor_param.zero_init = False if not ctr_accessor_param.HasField("nonclk_coeff"): ctr_accessor_param.nonclk_coeff = 0.1 if not ctr_accessor_param.HasField("click_coeff"): @@ -185,7 +188,11 @@ class Accessor: accessor_proto.embed_sgd_param, accessor_proto.embedx_sgd_param ]: if not sgd_param.HasField("name"): - sgd_param.name = "SparseAdaGradSGDRule" + if common_accessor.accessor_class == "sgd": + sgd_param.name = "SparseNaiveSGDRule" + if common_accessor.accessor_class == "adam": + sgd_param.name = "SparseAdamSGDRule" + if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule": if not sgd_param.adagrad.HasField("learning_rate"): sgd_param.adagrad.learning_rate = 0.05 @@ -195,23 +202,47 @@ class Accessor: sgd_param.adagrad.initial_range = 0.0001 if len(sgd_param.adagrad.weight_bounds) == 0: sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0]) + if sgd_param.name == "SparseNaiveSGDRule": if not sgd_param.naive.HasField("learning_rate"): - sgd_param.naive.learning_rate = 0.05 + learning_rate = common_accessor.initializers[-1].split( + "&")[1] + sgd_param.naive.learning_rate = float(learning_rate) if not sgd_param.naive.HasField("initial_range"): - sgd_param.naive.initial_range = 0.0001 + initial_range = common_accessor.initializers[0].split( + "&")[-1] + sgd_param.naive.initial_range = float(initial_range) if len(sgd_param.naive.weight_bounds) == 0: sgd_param.naive.weight_bounds.extend([-10.0, 10.0]) + if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule": if not sgd_param.adam.HasField("learning_rate"): - sgd_param.adam.learning_rate = 0.001 + learning_rate = common_accessor.initializers[-1].split( + "&")[1] + sgd_param.adam.learning_rate = float(learning_rate) if not sgd_param.adam.HasField("initial_range"): - sgd_param.adam.initial_range = 0.0001 - if not sgd_param.adam.HasField("beta1_decay_rate"): + initial_range = common_accessor.initializers[0].split( + "&")[-1] + sgd_param.adam.initial_range = float(initial_range) + + attr_list = [x.split("&") for x in common_accessor.attrs] + if not sgd_param.adam.HasField( + "beta1_decay_rate" + ) and common_accessor.accessor_class == "adam": + sgd_param.adam.beta1_decay_rate = float(attr_list[0][1]) + else: sgd_param.adam.beta1_decay_rate = 0.9 - if not sgd_param.adam.HasField("beta2_decay_rate"): + if not sgd_param.adam.HasField( + "beta2_decay_rate" + ) and common_accessor.accessor_class == "adam": + sgd_param.adam.beta2_decay_rate = float(attr_list[1][1]) + else: sgd_param.adam.beta2_decay_rate = 0.999 - if not sgd_param.adam.HasField("ada_epsilon"): + if not sgd_param.adam.HasField( + "ada_epsilon" + ) and common_accessor.accessor_class == "adam": + sgd_param.adam.ada_epsilon = float(attr_list[2][1]) + else: sgd_param.adam.ada_epsilon = 1e-08 if len(sgd_param.adam.weight_bounds) == 0: sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) @@ -258,7 +289,7 @@ class CommonAccessor(Accessor): ("epsilon", "f")] opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), ("epsilon", "f")] - opt_attr_map["summary"] = [] + opt_attr_map["summary"] = [("summary_decay_rate", "f")] opt_init_map = {} opt_init_map["gaussian_random"] = ["seed", "mean", "std"] @@ -375,6 +406,9 @@ class CommonAccessor(Accessor): attr_varnames = self.opt_attr_map["adam_d2sum"] self.accessor_class = "adam_d2sum" else: + if oop.type != 'sgd' and oop.type != 'adam': + raise ValueError( + "The dense optimizer in PS is only supported SGD or Adam!") param_varnames = self.opt_input_map[oop.type] attr_varnames = self.opt_attr_map[oop.type] self.accessor_class = oop.type @@ -459,9 +493,18 @@ class CommonAccessor(Accessor): param.name, startup_program) initializers.append(initializer) + if self.accessor_class == 'summary': + datanorm_ops = get_datanorm_ops(main_program) + for op in datanorm_ops: + if ("BatchSize" in op.input_names) and ( + op.input("BatchSize")[0] + == context['grad_name_to_param_name'][grad_name]): + oop = op + break + for (attr_varname, type_) in attr_varnames: value = oop.attr(attr_varname) - attrs.append("&".join([attr_varname, type_, str(value)])) + attrs.append("&".join([attr_varname, str(value)])) self.params = params self.dims = dims @@ -480,6 +523,7 @@ class CommonAccessor(Accessor): proto.sync = self.sync proto.table_num = self.table_num proto.table_dim = self.table_dim + proto.attr = "#".join(self.attrs) class Tensor: @@ -599,6 +643,13 @@ class SparseTable(Table): self.common.table_name = self.context['grad_name_to_param_name'][ ctx.origin_varnames()[0]] + self.common.parse_by_optimizer(ctx, self.context) + self.common.parse_entry(self.common.table_name, ctx.program_id(), + self.context) + self.common.sync = True if self.context['is_sync'] else False + + self.common._set(table_proto.common) + print('new table_name: {}'.format(self.common.table_name)) all_table_proto = self.context[ "user_defined_strategy"].sparse_table_configs @@ -626,6 +677,17 @@ class SparseTable(Table): "The shard_num of sparse table is not set, use default value 1000 in cpups." ) + if usr_table_proto.HasField("enable_sparse_table_cache"): + table_proto.enable_sparse_table_cache = usr_table_proto.enable_sparse_table_cache + if usr_table_proto.HasField("sparse_table_cache_rate"): + table_proto.sparse_table_cache_rate = usr_table_proto.sparse_table_cache_rate + if usr_table_proto.HasField("sparse_table_cache_file_num"): + table_proto.sparse_table_cache_file_num = usr_table_proto.sparse_table_cache_file_num + if usr_table_proto.HasField("enable_revert"): + table_proto.enable_revert = usr_table_proto.enable_revert + if usr_table_proto.HasField("shard_merge_rate"): + table_proto.shard_merge_rate = usr_table_proto.shard_merge_rate + if usr_table_proto.accessor.ByteSize() == 0: warnings.warn( "The accessor of sparse table is not set, use default value.") @@ -633,16 +695,10 @@ class SparseTable(Table): table_proto.accessor.ParseFromString( usr_table_proto.accessor.SerializeToString()) self.accessor._set(table_proto.accessor, self.common.table_name, - ctx.program_id(), self.context) + ctx.program_id(), self.context, self.common) check_embedding_dim(table_proto.accessor, self.common.table_name, ctx.program_id(), self.context) - self.common.parse_by_optimizer(ctx, self.context) - self.common.parse_entry(self.common.table_name, ctx.program_id(), - self.context) - self.common.sync = True if self.context['is_sync'] else False - - self.common._set(table_proto.common) class GeoSparseTable(SparseTable): @@ -1474,36 +1530,43 @@ class TheOnePSRuntime(RuntimeBase): self._init_params(main_program, scope, send_ctx, dense_map) def _save_one_table(self, table_id, path, mode): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.save_one_model(table_id, path, mode) fleet.util.barrier() def _save_dense_params(self, *args, **kwargs): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._ps_save_dense_params(*args, **kwargs) fleet.util.barrier() def _save_persistables(self, *args, **kwargs): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._save_distributed_persistables(*args, **kwargs) fleet.util.barrier() def _save_inference_model(self, *args, **kwargs): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._ps_inference_save_inference_model(*args, **kwargs) fleet.util.barrier() def _load_one_table(self, table_id, path, mode): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.load_one_table(table_id, path, mode) fleet.util.barrier() def _load_persistables(self, path, mode): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.load_model(path, mode) fleet.util.barrier() def _load_inference_model(self, path, mode): + fleet.util.barrier() if self.role_maker._is_first_worker(): self._ps_inference_load_inference_model(path, mode) fleet.util.barrier() diff --git a/python/paddle/distributed/ps/utils/public.py b/python/paddle/distributed/ps/utils/public.py index 2e3cb1388f8..3b2310f1143 100755 --- a/python/paddle/distributed/ps/utils/public.py +++ b/python/paddle/distributed/ps/utils/public.py @@ -208,6 +208,15 @@ def get_optimize_ops(_program, remote_sparse=[]): return opt_ops +def get_datanorm_ops(_program): + block = _program.global_block() + opt_ops = [] + for op in block.ops: + if op.type == 'data_norm': + opt_ops.append(op) + return opt_ops + + def get_dist_env(): trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0')) trainer_endpoints = '' diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py index 216ea4c2926..b2d9a996b8a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py @@ -93,6 +93,7 @@ class TestPSPassWithBow(unittest.TestCase): # vsum q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') q_ss = fluid.layers.softsign(q_sum) + q_ss = fluid.layers.data_norm(input=q_ss) # fc layer after conv q_fc = fluid.layers.fc( input=q_ss, @@ -183,6 +184,10 @@ class TestPSPassWithBow(unittest.TestCase): configs = {} configs['__emb__'] = { + "table_parameters.__emb__.enable_sparse_table_cache": + True, + "table_parameters.__emb__.shard_merge_rate": + 1, "table_parameters.__emb__.accessor.embed_sgd_param.name": "SparseNaiveSGDRule", "table_parameters.__emb__.accessor.embedx_sgd_param.name": -- GitLab