未验证 提交 a9e0d28c 编写于 作者: W wangguanqun 提交者: GitHub

default accessor and multi table config (#37714)

* default accessor and multi table config

* add unittest

* add unittest

* delete print
上级 f695dc97
...@@ -180,10 +180,11 @@ enum TableType { ...@@ -180,10 +180,11 @@ enum TableType {
message TableParameter { message TableParameter {
optional uint64 table_id = 1; optional uint64 table_id = 1;
optional string table_class = 2; optional string table_name = 2;
optional uint64 shard_num = 3 [ default = 1000 ]; optional string table_class = 3;
optional TableType type = 4; optional uint64 shard_num = 4 [ default = 1000 ];
optional TableAccessorParameter accessor = 5; optional TableType type = 5;
optional TableAccessorParameter accessor = 6;
} }
message TableAccessorParameter { message TableAccessorParameter {
...@@ -198,7 +199,6 @@ message TableAccessorParameter { ...@@ -198,7 +199,6 @@ message TableAccessorParameter {
repeated TableAccessorSaveParameter table_accessor_save_param = 8; repeated TableAccessorSaveParameter table_accessor_save_param = 8;
} }
// TODO(guanqun): add NaiveSGD/Adam...
message SGDParameter { message SGDParameter {
optional string name = 1; optional string name = 1;
optional SparseNaiveSGDRuleParameter naive = 2; optional SparseNaiveSGDRuleParameter naive = 2;
...@@ -321,7 +321,7 @@ message DistributedStrategy { ...@@ -321,7 +321,7 @@ message DistributedStrategy {
optional HybridConfig hybrid_configs = 112; optional HybridConfig hybrid_configs = 112;
optional TensorParallelConfig tensor_parallel_configs = 113; optional TensorParallelConfig tensor_parallel_configs = 113;
optional TrainerDescConfig trainer_desc_configs = 114; optional TrainerDescConfig trainer_desc_configs = 114;
optional TableParameter downpour_table_param = 115; repeated TableParameter downpour_table_param = 115;
optional FsClientParameter fs_client_param = 116; optional FsClientParameter fs_client_param = 116;
optional BuildStrategy build_strategy = 201; optional BuildStrategy build_strategy = 201;
......
...@@ -474,12 +474,12 @@ class DistributedStrategy(object): ...@@ -474,12 +474,12 @@ class DistributedStrategy(object):
for field in msg.DESCRIPTOR.fields: for field in msg.DESCRIPTOR.fields:
name = config_name + "." + field.name name = config_name + "." + field.name
if field.type == FieldDescriptor.TYPE_MESSAGE: if field.type == FieldDescriptor.TYPE_MESSAGE:
print("message:", name) # print("message:", name)
if field.label == FieldDescriptor.LABEL_REPEATED: if field.label == FieldDescriptor.LABEL_REPEATED:
if name + ".num" not in configs: if name + ".num" not in configs:
continue continue
num = configs[name + ".num"] num = configs[name + ".num"]
print("message num:", name, num) # print("message num:", name, num)
for i in range(num): for i in range(num):
data = getattr(msg, field.name).add() data = getattr(msg, field.name).add()
set_table_config(data, name, configs, i) set_table_config(data, name, configs, i)
...@@ -487,7 +487,7 @@ class DistributedStrategy(object): ...@@ -487,7 +487,7 @@ class DistributedStrategy(object):
set_table_config( set_table_config(
getattr(msg, field.name), name, configs) getattr(msg, field.name), name, configs)
else: else:
print("not message:", name) # print("not message:", name)
if name not in configs: if name not in configs:
continue continue
if field.label == FieldDescriptor.LABEL_REPEATED: if field.label == FieldDescriptor.LABEL_REPEATED:
...@@ -501,7 +501,11 @@ class DistributedStrategy(object): ...@@ -501,7 +501,11 @@ class DistributedStrategy(object):
if not configs: if not configs:
print("table configs is empty") print("table configs is empty")
else: else:
set_table_config(table_param, "table_parameters", configs) for table_name in configs:
table_data = table_param.add()
table_data.table_name = table_name
set_table_config(table_data, "table_parameters." + table_name,
configs[table_name])
@property @property
def amp(self): def amp(self):
......
...@@ -56,53 +56,77 @@ def get_default_accessor_proto(accessor, varname, o_main_program): ...@@ -56,53 +56,77 @@ def get_default_accessor_proto(accessor, varname, o_main_program):
embedding_dim = 0 embedding_dim = 0
for var in o_main_program.list_vars(): for var in o_main_program.list_vars():
if var.name == varname: if var.name == varname:
print("var:", var)
print("var.shape:", var.shape)
embedding_dim = var.shape[1] embedding_dim = var.shape[1]
print("sparse dim:", embedding_dim)
break break
if not accessor.HasField("accessor_class"):
accessor.accessor_class = "CtrCommonAccessor" accessor.accessor_class = "CtrCommonAccessor"
if not accessor.HasField("fea_dim"):
accessor.fea_dim = embedding_dim + 2 accessor.fea_dim = embedding_dim + 2
if not accessor.HasField("embedx_dim"):
accessor.embedx_dim = embedding_dim - 1 accessor.embedx_dim = embedding_dim - 1
if not accessor.HasField("embedx_threshold"):
accessor.embedx_threshold = 0 accessor.embedx_threshold = 0
ctr_accessor_param = accessor.ctr_accessor_param ctr_accessor_param = accessor.ctr_accessor_param
if not ctr_accessor_param.HasField("nonclk_coeff"):
ctr_accessor_param.nonclk_coeff = 0.1 ctr_accessor_param.nonclk_coeff = 0.1
if not ctr_accessor_param.HasField("click_coeff"):
ctr_accessor_param.click_coeff = 1.0 ctr_accessor_param.click_coeff = 1.0
if not ctr_accessor_param.HasField("base_threshold"):
ctr_accessor_param.base_threshold = 0 ctr_accessor_param.base_threshold = 0
if not ctr_accessor_param.HasField("delta_threshold"):
ctr_accessor_param.delta_threshold = 0 ctr_accessor_param.delta_threshold = 0
if not ctr_accessor_param.HasField("delta_keep_days"):
ctr_accessor_param.delta_keep_days = 16 ctr_accessor_param.delta_keep_days = 16
if not ctr_accessor_param.HasField("show_click_decay_rate"):
ctr_accessor_param.show_click_decay_rate = 1 ctr_accessor_param.show_click_decay_rate = 1
if not ctr_accessor_param.HasField("delete_threshold"):
ctr_accessor_param.delete_threshold = 0 ctr_accessor_param.delete_threshold = 0
if not ctr_accessor_param.HasField("delete_after_unseen_days"):
ctr_accessor_param.delete_after_unseen_days = 30 ctr_accessor_param.delete_after_unseen_days = 30
if not ctr_accessor_param.HasField("ssd_unseenday_threshold"):
ctr_accessor_param.ssd_unseenday_threshold = 1 ctr_accessor_param.ssd_unseenday_threshold = 1
embed_sgd_param = accessor.embed_sgd_param for sgd_param in [accessor.embed_sgd_param, accessor.embedx_sgd_param]:
embed_sgd_param.name = "SparseAdaGradSGDRule" if not sgd_param.HasField("name"):
embed_sgd_param.adagrad.learning_rate = 0.05 sgd_param.name = "SparseAdaGradSGDRule"
embed_sgd_param.adagrad.initial_g2sum = 3.0 if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule":
embed_sgd_param.adagrad.initial_range = 0.0001 if not sgd_param.adagrad.HasField("learning_rate"):
embed_sgd_param.adagrad.weight_bounds.append(-10.0) sgd_param.adagrad.learning_rate = 0.05
embed_sgd_param.adagrad.weight_bounds.append(10.0) if not sgd_param.adagrad.HasField("initial_g2sum"):
sgd_param.adagrad.initial_g2sum = 3.0
embedx_sgd_param = accessor.embedx_sgd_param if not sgd_param.adagrad.HasField("initial_range"):
embedx_sgd_param.name = "SparseAdaGradSGDRule" sgd_param.adagrad.initial_range = 0.0001
embedx_sgd_param.adagrad.learning_rate = 0.05 if len(sgd_param.adagrad.weight_bounds) == 0:
embedx_sgd_param.adagrad.initial_g2sum = 3.0 sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0])
embedx_sgd_param.adagrad.initial_range = 0.0001 if sgd_param.name == "SparseNaiveSGDRule":
embedx_sgd_param.adagrad.weight_bounds.append(-10.0) if not sgd_param.naive.HasField("learning_rate"):
embedx_sgd_param.adagrad.weight_bounds.append(10.0) sgd_param.naive.learning_rate = 0.05
if not sgd_param.naive.HasField("initial_range"):
sgd_param.naive.initial_range = 0.0001
if len(sgd_param.naive.weight_bounds) == 0:
sgd_param.naive.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseAdamSGDRule":
if not sgd_param.adam.HasField("learning_rate"):
sgd_param.adam.learning_rate = 0.001
if not sgd_param.adam.HasField("initial_range"):
sgd_param.adam.initial_range = 0.0001
if not sgd_param.adam.HasField("beta1_decay_rate"):
sgd_param.adam.beta1_decay_rate = 0.9
if not sgd_param.adam.HasField("beta2_decay_rate"):
sgd_param.adam.beta2_decay_rate = 0.999
if not sgd_param.adam.HasField("ada_epsilon"):
sgd_param.adam.ada_epsilon = 1e-08
if len(sgd_param.adam.weight_bounds) == 0:
sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program): def check_embedding_dim(accessor, varname, o_main_program):
embedding_dim = 0 embedding_dim = 0
for var in o_main_program.list_vars(): for var in o_main_program.list_vars():
if var.name == varname: if var.name == varname:
print("var:", var)
print("var.shape:", var.shape)
embedding_dim = var.shape[1] embedding_dim = var.shape[1]
print("sparse dim:", embedding_dim)
break break
fea_dim = accessor.fea_dim fea_dim = accessor.fea_dim
if fea_dim != embedding_dim + 2: if fea_dim != embedding_dim + 2:
...@@ -917,19 +941,14 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -917,19 +941,14 @@ class TheOnePSRuntime(RuntimeBase):
if self.compiled_strategy.is_geo_mode(): if self.compiled_strategy.is_geo_mode():
table.table_class = "SparseGeoTable" table.table_class = "SparseGeoTable"
else: else:
import copy all_table_proto = self.context[
table_proto = copy.deepcopy(self.context[ "user_defined_strategy"].sparse_table_configs
"user_defined_strategy"].sparse_table_configs) table_proto = all_table_proto.add()
print('table proto:', table_proto) for proto in all_table_proto:
print('table_class:', table_proto.table_class) if proto.table_name == common.table_name:
print('shard_num:', table_proto.shard_num) table_proto = proto
print('table_proto.accessor:', table_proto.accessor) break
print('accessor.IsInitialized', if table_proto.HasField("table_class"):
table_proto.accessor.IsInitialized())
print('accessor.ByteSize',
table_proto.accessor.ByteSize())
if table_proto.table_class:
print('table_proto.table_class is true')
table.table_class = table_proto.table_class table.table_class = table_proto.table_class
else: else:
table.table_class = parse_table_class( table.table_class = parse_table_class(
...@@ -939,8 +958,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -939,8 +958,7 @@ class TheOnePSRuntime(RuntimeBase):
warnings.warn( warnings.warn(
"The PS mode must use MemorySparseTable.") "The PS mode must use MemorySparseTable.")
if table_proto.shard_num: if table_proto.HasField("shard_num"):
print('table_proto.shard_num is true')
table.shard_num = table_proto.shard_num table.shard_num = table_proto.shard_num
else: else:
table.shard_num = 1000 table.shard_num = 1000
...@@ -949,22 +967,18 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -949,22 +967,18 @@ class TheOnePSRuntime(RuntimeBase):
) )
if table_proto.accessor.ByteSize() == 0: if table_proto.accessor.ByteSize() == 0:
print('table_proto.accessor is false')
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
warnings.warn( warnings.warn(
"The accessor of sparse table is not set, use default value." "The accessor of sparse table is not set, use default value."
) )
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
check_embedding_dim(table_proto.accessor, check_embedding_dim(table_proto.accessor,
common.table_name, common.table_name,
self.origin_main_program) self.origin_main_program)
print('accessor.ByteSize',
table_proto.accessor.ByteSize())
from google.protobuf import text_format from google.protobuf import text_format
table.accessor_proto = text_format.MessageToString( table.accessor_proto = text_format.MessageToString(
table_proto.accessor) table_proto.accessor)
print("the_one_ps table_proto:", table.accessor_proto)
else: else:
table.type = "PS_DENSE_TABLE" table.type = "PS_DENSE_TABLE"
table.table_class = "CommonDenseTable" table.table_class = "CommonDenseTable"
...@@ -1275,10 +1289,8 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1275,10 +1289,8 @@ class TheOnePSRuntime(RuntimeBase):
is_dense=False, is_dense=False,
split_dense_table=self.role_maker._is_heter_parameter_server_mode, split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True) use_origin_program=True)
print("the one ps sparses:", sparses)
sparse_names = self._save_sparse_params(executor, dirname, sparses, sparse_names = self._save_sparse_params(executor, dirname, sparses,
main_program, mode) main_program, mode)
print("the one ps sparse names:", sparse_names)
denses = self.compiled_strategy.get_the_one_recv_context( denses = self.compiled_strategy.get_the_one_recv_context(
is_dense=True, is_dense=True,
...@@ -1293,7 +1305,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1293,7 +1305,7 @@ class TheOnePSRuntime(RuntimeBase):
filter( filter(
TheOnePSRuntime.__exclude_vars(sparse_names), TheOnePSRuntime.__exclude_vars(sparse_names),
infer_program.list_vars())) infer_program.list_vars()))
print("remain_vars:", [var.name for var in remaining_vars])
for var in remaining_vars: for var in remaining_vars:
tensor = var.get_value() tensor = var.get_value()
paddle.save( paddle.save(
......
...@@ -167,6 +167,15 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -167,6 +167,15 @@ class TestPSPassWithBow(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True strategy.a_sync = True
configs = {}
configs['__emb__'] = {
"table_parameters.__emb__.accessor.embed_sgd_param.name":
"SparseNaiveSGDRule",
"table_parameters.__emb__.accessor.embedx_sgd_param.name":
"SparseAdamSGDRule",
}
strategy.sparse_table_configs = configs
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss) optimizer.minimize(loss)
......
...@@ -257,19 +257,19 @@ class TestStrategyConfig(unittest.TestCase): ...@@ -257,19 +257,19 @@ class TestStrategyConfig(unittest.TestCase):
def test_sparse_table_configs(self): def test_sparse_table_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
configs = { configs = {}
"table_parameters.accessor.embed_sgd_param.adagrad.learning_rate": configs['emb'] = {
"table_parameters.emb.accessor.embed_sgd_param.adagrad.learning_rate":
0.05, 0.05,
"table_parameters.accessor.table_accessor_save_param.num": 2, "table_parameters.emb.accessor.table_accessor_save_param.num": 2,
"table_parameters.accessor.table_accessor_save_param.param": "table_parameters.emb.accessor.table_accessor_save_param.param":
[1, 2] [1, 2]
} }
strategy.sparse_table_configs = configs strategy.sparse_table_configs = configs
self.assertEqual(strategy.sparse_table_configs.accessor.embed_sgd_param. self.assertEqual(strategy.sparse_table_configs[0]
adagrad.learning_rate, 0.05) .accessor.embed_sgd_param.adagrad.learning_rate, 0.05)
self.assertEqual( self.assertEqual(strategy.sparse_table_configs[0]
strategy.sparse_table_configs.accessor.table_accessor_save_param[ .accessor.table_accessor_save_param[0].param, 1)
0].param, 1)
strategy.adam_d2sum = True strategy.adam_d2sum = True
self.assertEqual(strategy.adam_d2sum, True) self.assertEqual(strategy.adam_d2sum, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册