未验证 提交 a9e0d28c 编写于 作者: W wangguanqun 提交者: GitHub

default accessor and multi table config (#37714)

* default accessor and multi table config

* add unittest

* add unittest

* delete print
上级 f695dc97
......@@ -180,10 +180,11 @@ enum TableType {
message TableParameter {
optional uint64 table_id = 1;
optional string table_class = 2;
optional uint64 shard_num = 3 [ default = 1000 ];
optional TableType type = 4;
optional TableAccessorParameter accessor = 5;
optional string table_name = 2;
optional string table_class = 3;
optional uint64 shard_num = 4 [ default = 1000 ];
optional TableType type = 5;
optional TableAccessorParameter accessor = 6;
}
message TableAccessorParameter {
......@@ -198,7 +199,6 @@ message TableAccessorParameter {
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
}
// TODO(guanqun): add NaiveSGD/Adam...
message SGDParameter {
optional string name = 1;
optional SparseNaiveSGDRuleParameter naive = 2;
......@@ -321,7 +321,7 @@ message DistributedStrategy {
optional HybridConfig hybrid_configs = 112;
optional TensorParallelConfig tensor_parallel_configs = 113;
optional TrainerDescConfig trainer_desc_configs = 114;
optional TableParameter downpour_table_param = 115;
repeated TableParameter downpour_table_param = 115;
optional FsClientParameter fs_client_param = 116;
optional BuildStrategy build_strategy = 201;
......
......@@ -474,12 +474,12 @@ class DistributedStrategy(object):
for field in msg.DESCRIPTOR.fields:
name = config_name + "." + field.name
if field.type == FieldDescriptor.TYPE_MESSAGE:
print("message:", name)
# print("message:", name)
if field.label == FieldDescriptor.LABEL_REPEATED:
if name + ".num" not in configs:
continue
num = configs[name + ".num"]
print("message num:", name, num)
# print("message num:", name, num)
for i in range(num):
data = getattr(msg, field.name).add()
set_table_config(data, name, configs, i)
......@@ -487,7 +487,7 @@ class DistributedStrategy(object):
set_table_config(
getattr(msg, field.name), name, configs)
else:
print("not message:", name)
# print("not message:", name)
if name not in configs:
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
......@@ -501,7 +501,11 @@ class DistributedStrategy(object):
if not configs:
print("table configs is empty")
else:
set_table_config(table_param, "table_parameters", configs)
for table_name in configs:
table_data = table_param.add()
table_data.table_name = table_name
set_table_config(table_data, "table_parameters." + table_name,
configs[table_name])
@property
def amp(self):
......
......@@ -56,53 +56,77 @@ def get_default_accessor_proto(accessor, varname, o_main_program):
embedding_dim = 0
for var in o_main_program.list_vars():
if var.name == varname:
print("var:", var)
print("var.shape:", var.shape)
embedding_dim = var.shape[1]
print("sparse dim:", embedding_dim)
break
if not accessor.HasField("accessor_class"):
accessor.accessor_class = "CtrCommonAccessor"
if not accessor.HasField("fea_dim"):
accessor.fea_dim = embedding_dim + 2
if not accessor.HasField("embedx_dim"):
accessor.embedx_dim = embedding_dim - 1
if not accessor.HasField("embedx_threshold"):
accessor.embedx_threshold = 0
ctr_accessor_param = accessor.ctr_accessor_param
if not ctr_accessor_param.HasField("nonclk_coeff"):
ctr_accessor_param.nonclk_coeff = 0.1
if not ctr_accessor_param.HasField("click_coeff"):
ctr_accessor_param.click_coeff = 1.0
if not ctr_accessor_param.HasField("base_threshold"):
ctr_accessor_param.base_threshold = 0
if not ctr_accessor_param.HasField("delta_threshold"):
ctr_accessor_param.delta_threshold = 0
if not ctr_accessor_param.HasField("delta_keep_days"):
ctr_accessor_param.delta_keep_days = 16
if not ctr_accessor_param.HasField("show_click_decay_rate"):
ctr_accessor_param.show_click_decay_rate = 1
if not ctr_accessor_param.HasField("delete_threshold"):
ctr_accessor_param.delete_threshold = 0
if not ctr_accessor_param.HasField("delete_after_unseen_days"):
ctr_accessor_param.delete_after_unseen_days = 30
if not ctr_accessor_param.HasField("ssd_unseenday_threshold"):
ctr_accessor_param.ssd_unseenday_threshold = 1
embed_sgd_param = accessor.embed_sgd_param
embed_sgd_param.name = "SparseAdaGradSGDRule"
embed_sgd_param.adagrad.learning_rate = 0.05
embed_sgd_param.adagrad.initial_g2sum = 3.0
embed_sgd_param.adagrad.initial_range = 0.0001
embed_sgd_param.adagrad.weight_bounds.append(-10.0)
embed_sgd_param.adagrad.weight_bounds.append(10.0)
embedx_sgd_param = accessor.embedx_sgd_param
embedx_sgd_param.name = "SparseAdaGradSGDRule"
embedx_sgd_param.adagrad.learning_rate = 0.05
embedx_sgd_param.adagrad.initial_g2sum = 3.0
embedx_sgd_param.adagrad.initial_range = 0.0001
embedx_sgd_param.adagrad.weight_bounds.append(-10.0)
embedx_sgd_param.adagrad.weight_bounds.append(10.0)
for sgd_param in [accessor.embed_sgd_param, accessor.embedx_sgd_param]:
if not sgd_param.HasField("name"):
sgd_param.name = "SparseAdaGradSGDRule"
if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule":
if not sgd_param.adagrad.HasField("learning_rate"):
sgd_param.adagrad.learning_rate = 0.05
if not sgd_param.adagrad.HasField("initial_g2sum"):
sgd_param.adagrad.initial_g2sum = 3.0
if not sgd_param.adagrad.HasField("initial_range"):
sgd_param.adagrad.initial_range = 0.0001
if len(sgd_param.adagrad.weight_bounds) == 0:
sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseNaiveSGDRule":
if not sgd_param.naive.HasField("learning_rate"):
sgd_param.naive.learning_rate = 0.05
if not sgd_param.naive.HasField("initial_range"):
sgd_param.naive.initial_range = 0.0001
if len(sgd_param.naive.weight_bounds) == 0:
sgd_param.naive.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseAdamSGDRule":
if not sgd_param.adam.HasField("learning_rate"):
sgd_param.adam.learning_rate = 0.001
if not sgd_param.adam.HasField("initial_range"):
sgd_param.adam.initial_range = 0.0001
if not sgd_param.adam.HasField("beta1_decay_rate"):
sgd_param.adam.beta1_decay_rate = 0.9
if not sgd_param.adam.HasField("beta2_decay_rate"):
sgd_param.adam.beta2_decay_rate = 0.999
if not sgd_param.adam.HasField("ada_epsilon"):
sgd_param.adam.ada_epsilon = 1e-08
if len(sgd_param.adam.weight_bounds) == 0:
sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program):
embedding_dim = 0
for var in o_main_program.list_vars():
if var.name == varname:
print("var:", var)
print("var.shape:", var.shape)
embedding_dim = var.shape[1]
print("sparse dim:", embedding_dim)
break
fea_dim = accessor.fea_dim
if fea_dim != embedding_dim + 2:
......@@ -917,19 +941,14 @@ class TheOnePSRuntime(RuntimeBase):
if self.compiled_strategy.is_geo_mode():
table.table_class = "SparseGeoTable"
else:
import copy
table_proto = copy.deepcopy(self.context[
"user_defined_strategy"].sparse_table_configs)
print('table proto:', table_proto)
print('table_class:', table_proto.table_class)
print('shard_num:', table_proto.shard_num)
print('table_proto.accessor:', table_proto.accessor)
print('accessor.IsInitialized',
table_proto.accessor.IsInitialized())
print('accessor.ByteSize',
table_proto.accessor.ByteSize())
if table_proto.table_class:
print('table_proto.table_class is true')
all_table_proto = self.context[
"user_defined_strategy"].sparse_table_configs
table_proto = all_table_proto.add()
for proto in all_table_proto:
if proto.table_name == common.table_name:
table_proto = proto
break
if table_proto.HasField("table_class"):
table.table_class = table_proto.table_class
else:
table.table_class = parse_table_class(
......@@ -939,8 +958,7 @@ class TheOnePSRuntime(RuntimeBase):
warnings.warn(
"The PS mode must use MemorySparseTable.")
if table_proto.shard_num:
print('table_proto.shard_num is true')
if table_proto.HasField("shard_num"):
table.shard_num = table_proto.shard_num
else:
table.shard_num = 1000
......@@ -949,22 +967,18 @@ class TheOnePSRuntime(RuntimeBase):
)
if table_proto.accessor.ByteSize() == 0:
print('table_proto.accessor is false')
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
warnings.warn(
"The accessor of sparse table is not set, use default value."
)
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
check_embedding_dim(table_proto.accessor,
common.table_name,
self.origin_main_program)
print('accessor.ByteSize',
table_proto.accessor.ByteSize())
from google.protobuf import text_format
table.accessor_proto = text_format.MessageToString(
table_proto.accessor)
print("the_one_ps table_proto:", table.accessor_proto)
else:
table.type = "PS_DENSE_TABLE"
table.table_class = "CommonDenseTable"
......@@ -1275,10 +1289,8 @@ class TheOnePSRuntime(RuntimeBase):
is_dense=False,
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)
print("the one ps sparses:", sparses)
sparse_names = self._save_sparse_params(executor, dirname, sparses,
main_program, mode)
print("the one ps sparse names:", sparse_names)
denses = self.compiled_strategy.get_the_one_recv_context(
is_dense=True,
......@@ -1293,7 +1305,7 @@ class TheOnePSRuntime(RuntimeBase):
filter(
TheOnePSRuntime.__exclude_vars(sparse_names),
infer_program.list_vars()))
print("remain_vars:", [var.name for var in remaining_vars])
for var in remaining_vars:
tensor = var.get_value()
paddle.save(
......
......@@ -167,6 +167,15 @@ class TestPSPassWithBow(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
configs = {}
configs['__emb__'] = {
"table_parameters.__emb__.accessor.embed_sgd_param.name":
"SparseNaiveSGDRule",
"table_parameters.__emb__.accessor.embedx_sgd_param.name":
"SparseAdamSGDRule",
}
strategy.sparse_table_configs = configs
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss)
......
......@@ -257,19 +257,19 @@ class TestStrategyConfig(unittest.TestCase):
def test_sparse_table_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {
"table_parameters.accessor.embed_sgd_param.adagrad.learning_rate":
configs = {}
configs['emb'] = {
"table_parameters.emb.accessor.embed_sgd_param.adagrad.learning_rate":
0.05,
"table_parameters.accessor.table_accessor_save_param.num": 2,
"table_parameters.accessor.table_accessor_save_param.param":
"table_parameters.emb.accessor.table_accessor_save_param.num": 2,
"table_parameters.emb.accessor.table_accessor_save_param.param":
[1, 2]
}
strategy.sparse_table_configs = configs
self.assertEqual(strategy.sparse_table_configs.accessor.embed_sgd_param.
adagrad.learning_rate, 0.05)
self.assertEqual(
strategy.sparse_table_configs.accessor.table_accessor_save_param[
0].param, 1)
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.adagrad.learning_rate, 0.05)
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.table_accessor_save_param[0].param, 1)
strategy.adam_d2sum = True
self.assertEqual(strategy.adam_d2sum, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册