未验证 提交 ae217373 编写于 作者: W wangguanqun 提交者: GitHub

ps optimizer default config (#45563)

* config

* fix unittest

* zero init & cache & patch config

* add barrier to save and load

* add unittest
上级 ae542dc7
...@@ -177,8 +177,10 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) { ...@@ -177,8 +177,10 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) {
value[common_feature_value.ShowIndex()] = 0; value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.ClickIndex()] = 0; value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1; value[common_feature_value.SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(), _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex()); value + common_feature_value.EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(), value + common_feature_value.EmbedxG2SumIndex(),
false); false);
......
...@@ -176,9 +176,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) { ...@@ -176,9 +176,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
*reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0; *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0; *(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
value[CtrDoubleFeatureValue::SlotIndex()] = -1; value[CtrDoubleFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->InitValue( bool zero_init = _config.ctr_accessor_param().zero_init();
value + CtrDoubleFeatureValue::EmbedWIndex(), _embed_sgd_rule->InitValue(value + CtrDoubleFeatureValue::EmbedWIndex(),
value + CtrDoubleFeatureValue::EmbedG2SumIndex()); value + CtrDoubleFeatureValue::EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue( _embedx_sgd_rule->InitValue(
value + CtrDoubleFeatureValue::EmbedxWIndex(), value + CtrDoubleFeatureValue::EmbedxWIndex(),
value + CtrDoubleFeatureValue::EmbedxG2SumIndex(), value + CtrDoubleFeatureValue::EmbedxG2SumIndex(),
......
...@@ -150,8 +150,10 @@ int32_t SparseAccessor::Create(float** values, size_t num) { ...@@ -150,8 +150,10 @@ int32_t SparseAccessor::Create(float** values, size_t num) {
value[sparse_feature_value.ShowIndex()] = 0; value[sparse_feature_value.ShowIndex()] = 0;
value[sparse_feature_value.ClickIndex()] = 0; value[sparse_feature_value.ClickIndex()] = 0;
value[sparse_feature_value.SlotIndex()] = -1; value[sparse_feature_value.SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(), _embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(),
value + sparse_feature_value.EmbedG2SumIndex()); value + sparse_feature_value.EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(), _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
value + sparse_feature_value.EmbedxG2SumIndex(), value + sparse_feature_value.EmbedxG2SumIndex(),
false); false);
......
...@@ -120,6 +120,7 @@ message TableParameter { ...@@ -120,6 +120,7 @@ message TableParameter {
optional bool enable_sparse_table_cache = 10 [ default = true ]; optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
// for patch model
optional bool enable_revert = 13 [ default = false ]; optional bool enable_revert = 13 [ default = false ];
optional float shard_merge_rate = 14 [ default = 1.0 ]; optional float shard_merge_rate = 14 [ default = 1.0 ];
} }
...@@ -167,6 +168,7 @@ message CtrAccessorParameter { ...@@ -167,6 +168,7 @@ message CtrAccessorParameter {
optional int32 ssd_unseenday_threshold = 9 optional int32 ssd_unseenday_threshold = 9
[ default = 1 ]; // threshold to save ssd [ default = 1 ]; // threshold to save ssd
optional bool show_scale = 10 [ default = true ]; optional bool show_scale = 10 [ default = true ];
optional bool zero_init = 11 [ default = true ];
} }
message TensorAccessorParameter { message TensorAccessorParameter {
...@@ -189,6 +191,7 @@ message CommonAccessorParameter { ...@@ -189,6 +191,7 @@ message CommonAccessorParameter {
optional bool sync = 9; optional bool sync = 9;
optional uint32 table_num = 10; optional uint32 table_num = 10;
optional uint32 table_dim = 11; optional uint32 table_dim = 11;
optional string attr = 12;
} }
message TableAccessorSaveParameter { message TableAccessorSaveParameter {
......
...@@ -205,6 +205,13 @@ message TableParameter { ...@@ -205,6 +205,13 @@ message TableParameter {
optional TableType type = 5; optional TableType type = 5;
optional TableAccessorParameter accessor = 6; optional TableAccessorParameter accessor = 6;
optional bool compress_in_save = 7 [ default = false ]; optional bool compress_in_save = 7 [ default = false ];
// for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
// for patch model
optional bool enable_revert = 13 [ default = false ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
} }
message TableAccessorParameter { message TableAccessorParameter {
......
...@@ -627,6 +627,12 @@ class DistributedStrategy(object): ...@@ -627,6 +627,12 @@ class DistributedStrategy(object):
% (table_class)) % (table_class))
table_data.table_class = 'MemorySparseTable' table_data.table_class = 'MemorySparseTable'
table_data.shard_num = config.get('sparse_shard_num', 1000) table_data.shard_num = config.get('sparse_shard_num', 1000)
table_data.enable_sparse_table_cache = config.get(
'sparse_enable_cache', True)
table_data.sparse_table_cache_rate = config.get(
'sparse_cache_rate', 0.00055)
table_data.sparse_table_cache_file_num = config.get(
'sparse_cache_file_num', 16)
accessor_class = config.get("sparse_accessor_class", accessor_class = config.get("sparse_accessor_class",
"DownpourCtrAccessor") "DownpourCtrAccessor")
......
...@@ -127,7 +127,8 @@ class Accessor: ...@@ -127,7 +127,8 @@ class Accessor:
self.embedding_dim = 0 self.embedding_dim = 0
# TableAccessorParameter accessor # TableAccessorParameter accessor
def _set(self, accessor_proto, varname, program_id, context): def _set(self, accessor_proto, varname, program_id, context,
common_accessor):
main_program, startup_program, idx = get_program_by_id( main_program, startup_program, idx = get_program_by_id(
context, program_id) context, program_id)
embedding_dim = 0 embedding_dim = 0
...@@ -162,6 +163,8 @@ class Accessor: ...@@ -162,6 +163,8 @@ class Accessor:
graph_sgd_param.feature_learning_rate = 0.05 graph_sgd_param.feature_learning_rate = 0.05
ctr_accessor_param = accessor_proto.ctr_accessor_param ctr_accessor_param = accessor_proto.ctr_accessor_param
if accessor_proto.embedx_dim == 0:
ctr_accessor_param.zero_init = False
if not ctr_accessor_param.HasField("nonclk_coeff"): if not ctr_accessor_param.HasField("nonclk_coeff"):
ctr_accessor_param.nonclk_coeff = 0.1 ctr_accessor_param.nonclk_coeff = 0.1
if not ctr_accessor_param.HasField("click_coeff"): if not ctr_accessor_param.HasField("click_coeff"):
...@@ -185,7 +188,11 @@ class Accessor: ...@@ -185,7 +188,11 @@ class Accessor:
accessor_proto.embed_sgd_param, accessor_proto.embedx_sgd_param accessor_proto.embed_sgd_param, accessor_proto.embedx_sgd_param
]: ]:
if not sgd_param.HasField("name"): if not sgd_param.HasField("name"):
sgd_param.name = "SparseAdaGradSGDRule" if common_accessor.accessor_class == "sgd":
sgd_param.name = "SparseNaiveSGDRule"
if common_accessor.accessor_class == "adam":
sgd_param.name = "SparseAdamSGDRule"
if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule": if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule":
if not sgd_param.adagrad.HasField("learning_rate"): if not sgd_param.adagrad.HasField("learning_rate"):
sgd_param.adagrad.learning_rate = 0.05 sgd_param.adagrad.learning_rate = 0.05
...@@ -195,23 +202,47 @@ class Accessor: ...@@ -195,23 +202,47 @@ class Accessor:
sgd_param.adagrad.initial_range = 0.0001 sgd_param.adagrad.initial_range = 0.0001
if len(sgd_param.adagrad.weight_bounds) == 0: if len(sgd_param.adagrad.weight_bounds) == 0:
sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0]) sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseNaiveSGDRule": if sgd_param.name == "SparseNaiveSGDRule":
if not sgd_param.naive.HasField("learning_rate"): if not sgd_param.naive.HasField("learning_rate"):
sgd_param.naive.learning_rate = 0.05 learning_rate = common_accessor.initializers[-1].split(
"&")[1]
sgd_param.naive.learning_rate = float(learning_rate)
if not sgd_param.naive.HasField("initial_range"): if not sgd_param.naive.HasField("initial_range"):
sgd_param.naive.initial_range = 0.0001 initial_range = common_accessor.initializers[0].split(
"&")[-1]
sgd_param.naive.initial_range = float(initial_range)
if len(sgd_param.naive.weight_bounds) == 0: if len(sgd_param.naive.weight_bounds) == 0:
sgd_param.naive.weight_bounds.extend([-10.0, 10.0]) sgd_param.naive.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule": if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule":
if not sgd_param.adam.HasField("learning_rate"): if not sgd_param.adam.HasField("learning_rate"):
sgd_param.adam.learning_rate = 0.001 learning_rate = common_accessor.initializers[-1].split(
"&")[1]
sgd_param.adam.learning_rate = float(learning_rate)
if not sgd_param.adam.HasField("initial_range"): if not sgd_param.adam.HasField("initial_range"):
sgd_param.adam.initial_range = 0.0001 initial_range = common_accessor.initializers[0].split(
if not sgd_param.adam.HasField("beta1_decay_rate"): "&")[-1]
sgd_param.adam.initial_range = float(initial_range)
attr_list = [x.split("&") for x in common_accessor.attrs]
if not sgd_param.adam.HasField(
"beta1_decay_rate"
) and common_accessor.accessor_class == "adam":
sgd_param.adam.beta1_decay_rate = float(attr_list[0][1])
else:
sgd_param.adam.beta1_decay_rate = 0.9 sgd_param.adam.beta1_decay_rate = 0.9
if not sgd_param.adam.HasField("beta2_decay_rate"): if not sgd_param.adam.HasField(
"beta2_decay_rate"
) and common_accessor.accessor_class == "adam":
sgd_param.adam.beta2_decay_rate = float(attr_list[1][1])
else:
sgd_param.adam.beta2_decay_rate = 0.999 sgd_param.adam.beta2_decay_rate = 0.999
if not sgd_param.adam.HasField("ada_epsilon"): if not sgd_param.adam.HasField(
"ada_epsilon"
) and common_accessor.accessor_class == "adam":
sgd_param.adam.ada_epsilon = float(attr_list[2][1])
else:
sgd_param.adam.ada_epsilon = 1e-08 sgd_param.adam.ada_epsilon = 1e-08
if len(sgd_param.adam.weight_bounds) == 0: if len(sgd_param.adam.weight_bounds) == 0:
sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
...@@ -258,7 +289,7 @@ class CommonAccessor(Accessor): ...@@ -258,7 +289,7 @@ class CommonAccessor(Accessor):
("epsilon", "f")] ("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")] ("epsilon", "f")]
opt_attr_map["summary"] = [] opt_attr_map["summary"] = [("summary_decay_rate", "f")]
opt_init_map = {} opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"] opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
...@@ -375,6 +406,9 @@ class CommonAccessor(Accessor): ...@@ -375,6 +406,9 @@ class CommonAccessor(Accessor):
attr_varnames = self.opt_attr_map["adam_d2sum"] attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum" self.accessor_class = "adam_d2sum"
else: else:
if oop.type != 'sgd' and oop.type != 'adam':
raise ValueError(
"The dense optimizer in PS is only supported SGD or Adam!")
param_varnames = self.opt_input_map[oop.type] param_varnames = self.opt_input_map[oop.type]
attr_varnames = self.opt_attr_map[oop.type] attr_varnames = self.opt_attr_map[oop.type]
self.accessor_class = oop.type self.accessor_class = oop.type
...@@ -459,9 +493,18 @@ class CommonAccessor(Accessor): ...@@ -459,9 +493,18 @@ class CommonAccessor(Accessor):
param.name, startup_program) param.name, startup_program)
initializers.append(initializer) initializers.append(initializer)
if self.accessor_class == 'summary':
datanorm_ops = get_datanorm_ops(main_program)
for op in datanorm_ops:
if ("BatchSize" in op.input_names) and (
op.input("BatchSize")[0]
== context['grad_name_to_param_name'][grad_name]):
oop = op
break
for (attr_varname, type_) in attr_varnames: for (attr_varname, type_) in attr_varnames:
value = oop.attr(attr_varname) value = oop.attr(attr_varname)
attrs.append("&".join([attr_varname, type_, str(value)])) attrs.append("&".join([attr_varname, str(value)]))
self.params = params self.params = params
self.dims = dims self.dims = dims
...@@ -480,6 +523,7 @@ class CommonAccessor(Accessor): ...@@ -480,6 +523,7 @@ class CommonAccessor(Accessor):
proto.sync = self.sync proto.sync = self.sync
proto.table_num = self.table_num proto.table_num = self.table_num
proto.table_dim = self.table_dim proto.table_dim = self.table_dim
proto.attr = "#".join(self.attrs)
class Tensor: class Tensor:
...@@ -599,6 +643,13 @@ class SparseTable(Table): ...@@ -599,6 +643,13 @@ class SparseTable(Table):
self.common.table_name = self.context['grad_name_to_param_name'][ self.common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]] ctx.origin_varnames()[0]]
self.common.parse_by_optimizer(ctx, self.context)
self.common.parse_entry(self.common.table_name, ctx.program_id(),
self.context)
self.common.sync = True if self.context['is_sync'] else False
self.common._set(table_proto.common)
print('new table_name: {}'.format(self.common.table_name)) print('new table_name: {}'.format(self.common.table_name))
all_table_proto = self.context[ all_table_proto = self.context[
"user_defined_strategy"].sparse_table_configs "user_defined_strategy"].sparse_table_configs
...@@ -626,6 +677,17 @@ class SparseTable(Table): ...@@ -626,6 +677,17 @@ class SparseTable(Table):
"The shard_num of sparse table is not set, use default value 1000 in cpups." "The shard_num of sparse table is not set, use default value 1000 in cpups."
) )
if usr_table_proto.HasField("enable_sparse_table_cache"):
table_proto.enable_sparse_table_cache = usr_table_proto.enable_sparse_table_cache
if usr_table_proto.HasField("sparse_table_cache_rate"):
table_proto.sparse_table_cache_rate = usr_table_proto.sparse_table_cache_rate
if usr_table_proto.HasField("sparse_table_cache_file_num"):
table_proto.sparse_table_cache_file_num = usr_table_proto.sparse_table_cache_file_num
if usr_table_proto.HasField("enable_revert"):
table_proto.enable_revert = usr_table_proto.enable_revert
if usr_table_proto.HasField("shard_merge_rate"):
table_proto.shard_merge_rate = usr_table_proto.shard_merge_rate
if usr_table_proto.accessor.ByteSize() == 0: if usr_table_proto.accessor.ByteSize() == 0:
warnings.warn( warnings.warn(
"The accessor of sparse table is not set, use default value.") "The accessor of sparse table is not set, use default value.")
...@@ -633,16 +695,10 @@ class SparseTable(Table): ...@@ -633,16 +695,10 @@ class SparseTable(Table):
table_proto.accessor.ParseFromString( table_proto.accessor.ParseFromString(
usr_table_proto.accessor.SerializeToString()) usr_table_proto.accessor.SerializeToString())
self.accessor._set(table_proto.accessor, self.common.table_name, self.accessor._set(table_proto.accessor, self.common.table_name,
ctx.program_id(), self.context) ctx.program_id(), self.context, self.common)
check_embedding_dim(table_proto.accessor, self.common.table_name, check_embedding_dim(table_proto.accessor, self.common.table_name,
ctx.program_id(), self.context) ctx.program_id(), self.context)
self.common.parse_by_optimizer(ctx, self.context)
self.common.parse_entry(self.common.table_name, ctx.program_id(),
self.context)
self.common.sync = True if self.context['is_sync'] else False
self.common._set(table_proto.common)
class GeoSparseTable(SparseTable): class GeoSparseTable(SparseTable):
...@@ -1474,36 +1530,43 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1474,36 +1530,43 @@ class TheOnePSRuntime(RuntimeBase):
self._init_params(main_program, scope, send_ctx, dense_map) self._init_params(main_program, scope, send_ctx, dense_map)
def _save_one_table(self, table_id, path, mode): def _save_one_table(self, table_id, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._worker.save_one_model(table_id, path, mode) self._worker.save_one_model(table_id, path, mode)
fleet.util.barrier() fleet.util.barrier()
def _save_dense_params(self, *args, **kwargs): def _save_dense_params(self, *args, **kwargs):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._ps_save_dense_params(*args, **kwargs) self._ps_save_dense_params(*args, **kwargs)
fleet.util.barrier() fleet.util.barrier()
def _save_persistables(self, *args, **kwargs): def _save_persistables(self, *args, **kwargs):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._save_distributed_persistables(*args, **kwargs) self._save_distributed_persistables(*args, **kwargs)
fleet.util.barrier() fleet.util.barrier()
def _save_inference_model(self, *args, **kwargs): def _save_inference_model(self, *args, **kwargs):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._ps_inference_save_inference_model(*args, **kwargs) self._ps_inference_save_inference_model(*args, **kwargs)
fleet.util.barrier() fleet.util.barrier()
def _load_one_table(self, table_id, path, mode): def _load_one_table(self, table_id, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._worker.load_one_table(table_id, path, mode) self._worker.load_one_table(table_id, path, mode)
fleet.util.barrier() fleet.util.barrier()
def _load_persistables(self, path, mode): def _load_persistables(self, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._worker.load_model(path, mode) self._worker.load_model(path, mode)
fleet.util.barrier() fleet.util.barrier()
def _load_inference_model(self, path, mode): def _load_inference_model(self, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker(): if self.role_maker._is_first_worker():
self._ps_inference_load_inference_model(path, mode) self._ps_inference_load_inference_model(path, mode)
fleet.util.barrier() fleet.util.barrier()
......
...@@ -208,6 +208,15 @@ def get_optimize_ops(_program, remote_sparse=[]): ...@@ -208,6 +208,15 @@ def get_optimize_ops(_program, remote_sparse=[]):
return opt_ops return opt_ops
def get_datanorm_ops(_program):
    """Collect the ``data_norm`` operators of a program's global block.

    Args:
        _program: object exposing ``global_block()``; the returned block
            must provide an iterable ``ops`` attribute whose items each
            expose a ``type`` attribute.

    Returns:
        list: the ops whose ``type`` equals ``'data_norm'``, in the order
        they appear in the block (an empty list when there are none).
    """
    block = _program.global_block()
    # Keep only data_norm ops, preserving their order in the block.
    return [op for op in block.ops if op.type == 'data_norm']
def get_dist_env(): def get_dist_env():
trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0')) trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0'))
trainer_endpoints = '' trainer_endpoints = ''
......
...@@ -93,6 +93,7 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -93,6 +93,7 @@ class TestPSPassWithBow(unittest.TestCase):
# vsum # vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum) q_ss = fluid.layers.softsign(q_sum)
q_ss = fluid.layers.data_norm(input=q_ss)
# fc layer after conv # fc layer after conv
q_fc = fluid.layers.fc( q_fc = fluid.layers.fc(
input=q_ss, input=q_ss,
...@@ -183,6 +184,10 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -183,6 +184,10 @@ class TestPSPassWithBow(unittest.TestCase):
configs = {} configs = {}
configs['__emb__'] = { configs['__emb__'] = {
"table_parameters.__emb__.enable_sparse_table_cache":
True,
"table_parameters.__emb__.shard_merge_rate":
1,
"table_parameters.__emb__.accessor.embed_sgd_param.name": "table_parameters.__emb__.accessor.embed_sgd_param.name":
"SparseNaiveSGDRule", "SparseNaiveSGDRule",
"table_parameters.__emb__.accessor.embedx_sgd_param.name": "table_parameters.__emb__.accessor.embedx_sgd_param.name":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册