Unverified · Commit ae217373 authored by: W wangguanqun, committed by: GitHub

ps optimizer default config (#45563)

* config

* fix unittest

* zero init & cache & patch config

* add barrier to save and load

* add unittest
Parent ae542dc7
......@@ -177,8 +177,10 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) {
value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex());
value + common_feature_value.EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(),
false);
......
......@@ -176,9 +176,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
*reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
value[CtrDoubleFeatureValue::SlotIndex()] = -1;
_embed_sgd_rule->InitValue(
value + CtrDoubleFeatureValue::EmbedWIndex(),
value + CtrDoubleFeatureValue::EmbedG2SumIndex());
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + CtrDoubleFeatureValue::EmbedWIndex(),
value + CtrDoubleFeatureValue::EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue(
value + CtrDoubleFeatureValue::EmbedxWIndex(),
value + CtrDoubleFeatureValue::EmbedxG2SumIndex(),
......
......@@ -150,8 +150,10 @@ int32_t SparseAccessor::Create(float** values, size_t num) {
value[sparse_feature_value.ShowIndex()] = 0;
value[sparse_feature_value.ClickIndex()] = 0;
value[sparse_feature_value.SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(),
value + sparse_feature_value.EmbedG2SumIndex());
value + sparse_feature_value.EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(),
value + sparse_feature_value.EmbedxG2SumIndex(),
false);
......
......@@ -120,6 +120,7 @@ message TableParameter {
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
// for patch model
optional bool enable_revert = 13 [ default = false ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
}
......@@ -167,6 +168,7 @@ message CtrAccessorParameter {
optional int32 ssd_unseenday_threshold = 9
[ default = 1 ]; // threshold to save ssd
optional bool show_scale = 10 [ default = true ];
optional bool zero_init = 11 [ default = true ];
}
message TensorAccessorParameter {
......@@ -189,6 +191,7 @@ message CommonAccessorParameter {
optional bool sync = 9;
optional uint32 table_num = 10;
optional uint32 table_dim = 11;
optional string attr = 12;
}
message TableAccessorSaveParameter {
......
......@@ -205,6 +205,13 @@ message TableParameter {
optional TableType type = 5;
optional TableAccessorParameter accessor = 6;
optional bool compress_in_save = 7 [ default = false ];
// for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
// for patch model
optional bool enable_revert = 13 [ default = false ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
}
message TableAccessorParameter {
......
......@@ -627,6 +627,12 @@ class DistributedStrategy(object):
% (table_class))
table_data.table_class = 'MemorySparseTable'
table_data.shard_num = config.get('sparse_shard_num', 1000)
table_data.enable_sparse_table_cache = config.get(
'sparse_enable_cache', True)
table_data.sparse_table_cache_rate = config.get(
'sparse_cache_rate', 0.00055)
table_data.sparse_table_cache_file_num = config.get(
'sparse_cache_file_num', 16)
accessor_class = config.get("sparse_accessor_class",
"DownpourCtrAccessor")
......
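For reference, the three `sparse_*cache*` keys read in the hunk above carry the new `TableParameter` cache fields added in the protobuf hunks earlier in this diff. Below is a minimal sketch of a user-side config dict; only the key names and defaults come from this diff, while the `fleet_desc_configs` attribute used to hand the dict to the strategy is an assumption about where these keys are consumed.

```python
import paddle.distributed.fleet as fleet

# Minimal sketch: per-table sparse options, including the cache keys added in
# this commit. The defaults mirror the proto defaults (True, 0.00055, 16).
strategy = fleet.DistributedStrategy()
strategy.fleet_desc_configs = {            # assumed attribute name, see lead-in
    "sparse_table_class": "MemorySparseTable",
    "sparse_shard_num": 1000,
    "sparse_accessor_class": "DownpourCtrAccessor",
    "sparse_enable_cache": True,           # -> TableParameter.enable_sparse_table_cache
    "sparse_cache_rate": 0.00055,          # -> TableParameter.sparse_table_cache_rate
    "sparse_cache_file_num": 16,           # -> TableParameter.sparse_table_cache_file_num
}
```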
......@@ -127,7 +127,8 @@ class Accessor:
self.embedding_dim = 0
# TableAccessorParameter accessor
def _set(self, accessor_proto, varname, program_id, context):
def _set(self, accessor_proto, varname, program_id, context,
common_accessor):
main_program, startup_program, idx = get_program_by_id(
context, program_id)
embedding_dim = 0
......@@ -162,6 +163,8 @@ class Accessor:
graph_sgd_param.feature_learning_rate = 0.05
ctr_accessor_param = accessor_proto.ctr_accessor_param
if accessor_proto.embedx_dim == 0:
ctr_accessor_param.zero_init = False
if not ctr_accessor_param.HasField("nonclk_coeff"):
ctr_accessor_param.nonclk_coeff = 0.1
if not ctr_accessor_param.HasField("click_coeff"):
......@@ -185,7 +188,11 @@ class Accessor:
accessor_proto.embed_sgd_param, accessor_proto.embedx_sgd_param
]:
if not sgd_param.HasField("name"):
sgd_param.name = "SparseAdaGradSGDRule"
if common_accessor.accessor_class == "sgd":
sgd_param.name = "SparseNaiveSGDRule"
if common_accessor.accessor_class == "adam":
sgd_param.name = "SparseAdamSGDRule"
if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule":
if not sgd_param.adagrad.HasField("learning_rate"):
sgd_param.adagrad.learning_rate = 0.05
......@@ -195,23 +202,47 @@ class Accessor:
sgd_param.adagrad.initial_range = 0.0001
if len(sgd_param.adagrad.weight_bounds) == 0:
sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseNaiveSGDRule":
if not sgd_param.naive.HasField("learning_rate"):
sgd_param.naive.learning_rate = 0.05
learning_rate = common_accessor.initializers[-1].split(
"&")[1]
sgd_param.naive.learning_rate = float(learning_rate)
if not sgd_param.naive.HasField("initial_range"):
sgd_param.naive.initial_range = 0.0001
initial_range = common_accessor.initializers[0].split(
"&")[-1]
sgd_param.naive.initial_range = float(initial_range)
if len(sgd_param.naive.weight_bounds) == 0:
sgd_param.naive.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule":
if not sgd_param.adam.HasField("learning_rate"):
sgd_param.adam.learning_rate = 0.001
learning_rate = common_accessor.initializers[-1].split(
"&")[1]
sgd_param.adam.learning_rate = float(learning_rate)
if not sgd_param.adam.HasField("initial_range"):
sgd_param.adam.initial_range = 0.0001
if not sgd_param.adam.HasField("beta1_decay_rate"):
initial_range = common_accessor.initializers[0].split(
"&")[-1]
sgd_param.adam.initial_range = float(initial_range)
attr_list = [x.split("&") for x in common_accessor.attrs]
if not sgd_param.adam.HasField(
"beta1_decay_rate"
) and common_accessor.accessor_class == "adam":
sgd_param.adam.beta1_decay_rate = float(attr_list[0][1])
else:
sgd_param.adam.beta1_decay_rate = 0.9
if not sgd_param.adam.HasField("beta2_decay_rate"):
if not sgd_param.adam.HasField(
"beta2_decay_rate"
) and common_accessor.accessor_class == "adam":
sgd_param.adam.beta2_decay_rate = float(attr_list[1][1])
else:
sgd_param.adam.beta2_decay_rate = 0.999
if not sgd_param.adam.HasField("ada_epsilon"):
if not sgd_param.adam.HasField(
"ada_epsilon"
) and common_accessor.accessor_class == "adam":
sgd_param.adam.ada_epsilon = float(attr_list[2][1])
else:
sgd_param.adam.ada_epsilon = 1e-08
if len(sgd_param.adam.weight_bounds) == 0:
sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
......@@ -258,7 +289,7 @@ class CommonAccessor(Accessor):
("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_attr_map["summary"] = []
opt_attr_map["summary"] = [("summary_decay_rate", "f")]
opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
......@@ -375,6 +406,9 @@ class CommonAccessor(Accessor):
attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum"
else:
if oop.type != 'sgd' and oop.type != 'adam':
raise ValueError(
"The dense optimizer in PS is only supported SGD or Adam!")
param_varnames = self.opt_input_map[oop.type]
attr_varnames = self.opt_attr_map[oop.type]
self.accessor_class = oop.type
......@@ -459,9 +493,18 @@ class CommonAccessor(Accessor):
param.name, startup_program)
initializers.append(initializer)
if self.accessor_class == 'summary':
datanorm_ops = get_datanorm_ops(main_program)
for op in datanorm_ops:
if ("BatchSize" in op.input_names) and (
op.input("BatchSize")[0]
== context['grad_name_to_param_name'][grad_name]):
oop = op
break
for (attr_varname, type_) in attr_varnames:
value = oop.attr(attr_varname)
attrs.append("&".join([attr_varname, type_, str(value)]))
attrs.append("&".join([attr_varname, str(value)]))
self.params = params
self.dims = dims
......@@ -480,6 +523,7 @@ class CommonAccessor(Accessor):
proto.sync = self.sync
proto.table_num = self.table_num
proto.table_dim = self.table_dim
proto.attr = "#".join(self.attrs)
class Tensor:
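The new `attr` field added to `CommonAccessorParameter` is a small wire format: each optimizer attribute is serialized as `name&value`, and the entries are joined with `#` (see `proto.attr = "#".join(self.attrs)` just above and `attr_list = [x.split("&") for x in common_accessor.attrs]` in the Adam default-filling hunk). A minimal round-trip sketch, with illustrative values taken from the Adam defaults in this diff:

```python
# Encoding side (CommonAccessor): one "name&value" entry per optimizer
# attribute, joined with "#" when written into the proto.
attrs = ["beta1&0.9", "beta2&0.999", "epsilon&1e-08"]
proto_attr = "#".join(attrs)   # "beta1&0.9#beta2&0.999#epsilon&1e-08"

# Decoding side (Accessor._set): split back into [name, value] pairs and use
# the values to fill the SparseAdamSGDRule defaults.
attr_list = [entry.split("&") for entry in proto_attr.split("#")]
beta1_decay_rate = float(attr_list[0][1])   # 0.9
beta2_decay_rate = float(attr_list[1][1])   # 0.999
ada_epsilon = float(attr_list[2][1])        # 1e-08
```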
......@@ -599,6 +643,13 @@ class SparseTable(Table):
self.common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]]
self.common.parse_by_optimizer(ctx, self.context)
self.common.parse_entry(self.common.table_name, ctx.program_id(),
self.context)
self.common.sync = True if self.context['is_sync'] else False
self.common._set(table_proto.common)
print('new table_name: {}'.format(self.common.table_name))
all_table_proto = self.context[
"user_defined_strategy"].sparse_table_configs
......@@ -626,6 +677,17 @@ class SparseTable(Table):
"The shard_num of sparse table is not set, use default value 1000 in cpups."
)
if usr_table_proto.HasField("enable_sparse_table_cache"):
table_proto.enable_sparse_table_cache = usr_table_proto.enable_sparse_table_cache
if usr_table_proto.HasField("sparse_table_cache_rate"):
table_proto.sparse_table_cache_rate = usr_table_proto.sparse_table_cache_rate
if usr_table_proto.HasField("sparse_table_cache_file_num"):
table_proto.sparse_table_cache_file_num = usr_table_proto.sparse_table_cache_file_num
if usr_table_proto.HasField("enable_revert"):
table_proto.enable_revert = usr_table_proto.enable_revert
if usr_table_proto.HasField("shard_merge_rate"):
table_proto.shard_merge_rate = usr_table_proto.shard_merge_rate
if usr_table_proto.accessor.ByteSize() == 0:
warnings.warn(
"The accessor of sparse table is not set, use default value.")
......@@ -633,16 +695,10 @@ class SparseTable(Table):
table_proto.accessor.ParseFromString(
usr_table_proto.accessor.SerializeToString())
self.accessor._set(table_proto.accessor, self.common.table_name,
ctx.program_id(), self.context)
ctx.program_id(), self.context, self.common)
check_embedding_dim(table_proto.accessor, self.common.table_name,
ctx.program_id(), self.context)
self.common.parse_by_optimizer(ctx, self.context)
self.common.parse_entry(self.common.table_name, ctx.program_id(),
self.context)
self.common.sync = True if self.context['is_sync'] else False
self.common._set(table_proto.common)
class GeoSparseTable(SparseTable):
......@@ -1474,36 +1530,43 @@ class TheOnePSRuntime(RuntimeBase):
self._init_params(main_program, scope, send_ctx, dense_map)
def _save_one_table(self, table_id, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._worker.save_one_model(table_id, path, mode)
fleet.util.barrier()
def _save_dense_params(self, *args, **kwargs):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._ps_save_dense_params(*args, **kwargs)
fleet.util.barrier()
def _save_persistables(self, *args, **kwargs):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._save_distributed_persistables(*args, **kwargs)
fleet.util.barrier()
def _save_inference_model(self, *args, **kwargs):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._ps_inference_save_inference_model(*args, **kwargs)
fleet.util.barrier()
def _load_one_table(self, table_id, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._worker.load_one_table(table_id, path, mode)
fleet.util.barrier()
def _load_persistables(self, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._worker.load_model(path, mode)
fleet.util.barrier()
def _load_inference_model(self, path, mode):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._ps_inference_load_inference_model(path, mode)
fleet.util.barrier()
......
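Every save/load entry point above now follows the same collective pattern: a barrier so all trainers reach the checkpoint, a first-worker-only body so exactly one worker performs the actual I/O against the PS tables, and a trailing barrier so the other trainers cannot run ahead while the save or load is still in progress. A minimal sketch of the pattern, assuming the same `fleet.util.barrier()` and role-maker helpers used in the hunk (the wrapper name is hypothetical):

```python
import paddle.distributed.fleet as fleet

def run_on_first_worker(role_maker, fn, *args, **kwargs):
    """Hypothetical wrapper illustrating the barrier pattern added above;
    `fn` stands in for save_one_model, load_model, and friends."""
    fleet.util.barrier()                  # every trainer reaches the checkpoint
    if role_maker._is_first_worker():     # only trainer 0 performs the I/O
        fn(*args, **kwargs)
    fleet.util.barrier()                  # others resume only after it finishes
```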
......@@ -208,6 +208,15 @@ def get_optimize_ops(_program, remote_sparse=[]):
return opt_ops
def get_datanorm_ops(_program):
block = _program.global_block()
opt_ops = []
for op in block.ops:
if op.type == 'data_norm':
opt_ops.append(op)
return opt_ops
def get_dist_env():
trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0'))
trainer_endpoints = ''
......
......@@ -93,6 +93,7 @@ class TestPSPassWithBow(unittest.TestCase):
# vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum)
q_ss = fluid.layers.data_norm(input=q_ss)
# fc layer after conv
q_fc = fluid.layers.fc(
input=q_ss,
......@@ -183,6 +184,10 @@ class TestPSPassWithBow(unittest.TestCase):
configs = {}
configs['__emb__'] = {
"table_parameters.__emb__.enable_sparse_table_cache":
True,
"table_parameters.__emb__.shard_merge_rate":
1,
"table_parameters.__emb__.accessor.embed_sgd_param.name":
"SparseNaiveSGDRule",
"table_parameters.__emb__.accessor.embedx_sgd_param.name":
......
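The same per-table proto fields can also be driven through the dotted `table_parameters.<table>.<field>` keys shown in this test. A hedged sketch of such a dict follows; the first two keys appear verbatim in the hunk above, while the remaining key paths and the `sparse_table_configs` attribute are assumptions derived from the proto fields and the `user_defined_strategy.sparse_table_configs` read earlier in this diff:

```python
import paddle.distributed.fleet as fleet

# Hedged sketch: dotted per-table keys in the style of the test above.
configs = {
    "__emb__": {
        "table_parameters.__emb__.enable_sparse_table_cache": True,
        "table_parameters.__emb__.shard_merge_rate": 1,
        # assumed paths for the other fields added in this commit:
        "table_parameters.__emb__.sparse_table_cache_rate": 0.00055,
        "table_parameters.__emb__.accessor.ctr_accessor_param.zero_init": True,
    }
}
strategy = fleet.DistributedStrategy()
strategy.sparse_table_configs = configs   # assumed attribute, see lead-in
```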