未验证 提交 7825ae9c 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #14190 from jacquesqiao/dist-table-support-multi-table

Dist table support multi table
...@@ -411,12 +411,12 @@ class TestDistLookupTableBase(TranspilerTest): ...@@ -411,12 +411,12 @@ class TestDistLookupTableBase(TranspilerTest):
self.emb_size = 64 self.emb_size = 64
self.lookup_table_name = 'shared_w' self.lookup_table_name = 'shared_w'
def emb_pool(ids): def emb_pool(ids, table_name, is_distributed):
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
input=ids, input=ids,
size=[self.table_size, self.emb_size], size=[self.table_size, self.emb_size],
dtype='float32', dtype='float32',
param_attr=self.lookup_table_name, # share parameter param_attr=table_name,
is_sparse=is_sparse, is_sparse=is_sparse,
is_distributed=is_distributed) is_distributed=is_distributed)
pool = fluid.layers.sequence_pool(input=emb, pool_type='average') pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
...@@ -426,9 +426,13 @@ class TestDistLookupTableBase(TranspilerTest): ...@@ -426,9 +426,13 @@ class TestDistLookupTableBase(TranspilerTest):
name='title_ids', shape=[1], dtype='int64', lod_level=1) name='title_ids', shape=[1], dtype='int64', lod_level=1)
brand_ids = fluid.layers.data( brand_ids = fluid.layers.data(
name='brand_ids', shape=[1], dtype='int64', lod_level=1) name='brand_ids', shape=[1], dtype='int64', lod_level=1)
title_emb = emb_pool(title_ids) profile_ids = fluid.layers.data(
brand_emb = emb_pool(brand_ids) name='brand_ids', shape=[1], dtype='int64', lod_level=1)
fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1) title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed)
brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed)
profile_emb = emb_pool(profile_ids, "profile_emb", False)
fc0 = fluid.layers.concat(
input=[title_emb, brand_emb, profile_emb], axis=1)
predict = fluid.layers.fc(input=fc0, predict = fluid.layers.fc(input=fc0,
size=2, size=2,
act=None, act=None,
...@@ -449,7 +453,7 @@ class TestLocalLookupTable(TestDistLookupTableBase): ...@@ -449,7 +453,7 @@ class TestLocalLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep) pserver1, startup1 = self.get_pserver(self.pserver1_ep)
self.assertEqual(len(pserver1.blocks), 3) self.assertEqual(len(pserver1.blocks), 4)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
...@@ -459,16 +463,23 @@ class TestLocalLookupTable(TestDistLookupTableBase): ...@@ -459,16 +463,23 @@ class TestLocalLookupTable(TestDistLookupTableBase):
self.assertEqual([op.type for op in pserver1.blocks[2].ops], self.assertEqual([op.type for op in pserver1.blocks[2].ops],
["sum", "scale", "adam", "scale", "scale"]) ["sum", "scale", "adam", "scale", "scale"])
# 3 optimize for table 2 adam
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["sum", "scale", "adam", "scale", "scale"])
trainer, _ = self.get_trainer() trainer, _ = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'fill_constant', 'mean_grad', 'cross_entropy_grad', 'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'split_selected_rows', 'send', 'sequence_pool_grad',
'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat' 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
'recv', 'recv', 'recv', 'fetch_barrier', 'concat', 'concat'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
...@@ -480,39 +491,45 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -480,39 +491,45 @@ class TestDistLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep) pserver1, startup1 = self.get_pserver(self.pserver1_ep)
self.assertEqual(len(pserver1.blocks), 5) self.assertEqual(len(pserver1.blocks), 6)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
["sum", "scale", "adam", "scale", "scale"]) ["sum", "scale", "adam", "scale", "scale"])
# 2 optimize for table sgd # 2 optimize for table adam
self.assertEqual([op.type for op in pserver1.blocks[2].ops], self.assertEqual([op.type for op in pserver1.blocks[2].ops],
["sum", "scale", "adam", "scale", "scale"])
# 3 optimize for table sgd
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["sum", "sgd"]) ["sum", "sgd"])
# 3 prefetch -> lookup_sparse_table for data0 # 4 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[3].ops], self.assertEqual([op.type for op in pserver1.blocks[4].ops],
["lookup_sparse_table"]) ["lookup_sparse_table"])
# 4 save table # 5 save table
self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"]) self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, trainer_startup = self.get_trainer() trainer, trainer_startup = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'sequence_pool', 'concat', 'mul', 'elementwise_add', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad', 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids', 'lookup_table_grad', 'split_selected_rows', 'send',
'send', 'send_barrier', 'recv', 'recv', 'fetch_barrier' 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
'lookup_table_grad', 'sum', 'split_ids', 'send', 'send_barrier',
'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [ startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fetch_barrier', 'fake_init' 'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
] ]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops], self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops) startup_ops)
...@@ -526,7 +543,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): ...@@ -526,7 +543,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 3) self.assertEqual(len(pserver1.blocks), 4)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
...@@ -535,17 +552,23 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): ...@@ -535,17 +552,23 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num # NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[2].ops], self.assertEqual([op.type for op in pserver1.blocks[2].ops],
["adam", "scale", "scale"]) ["adam", "scale", "scale"])
# 3 optimize for table adam
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["adam", "scale", "scale"])
trainer, _ = self.get_trainer(config) trainer, _ = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'fill_constant', 'mean_grad', 'cross_entropy_grad', 'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv', 'split_selected_rows', 'send', 'sequence_pool_grad',
'recv', 'recv', 'concat' 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
'sum', 'split_selected_rows', 'send', 'recv', 'recv', 'recv',
'recv', 'concat', 'concat'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
...@@ -559,29 +582,34 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -559,29 +582,34 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 5) self.assertEqual(len(pserver1.blocks), 6)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
["adam", "scale", "scale"]) ["adam", "scale", "scale"])
# 2 optimize for table sgd # 2 optimize for table adam
self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"]) self.assertEqual([op.type for op in pserver1.blocks[2].ops],
# 3 prefetch -> lookup_sparse_table for data0 ["adam", "scale", "scale"])
self.assertEqual([op.type for op in pserver1.blocks[3].ops], # 3 optimize for table sgd
self.assertEqual([op.type for op in pserver1.blocks[3].ops], ["sgd"])
# 4 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
["lookup_sparse_table"]) ["lookup_sparse_table"])
# 4 save table # 5 save table
self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"]) self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer(config) trainer, _ = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'sequence_pool', 'concat', 'mul', 'elementwise_add', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad', 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids', 'lookup_table_grad', 'split_selected_rows', 'send',
'send', 'recv', 'recv' 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
'lookup_table_grad', 'sum', 'split_ids', 'send', 'recv', 'recv',
'recv', 'concat'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
......
...@@ -1065,7 +1065,12 @@ to transpile() call.") ...@@ -1065,7 +1065,12 @@ to transpile() call.")
continue_search_lookup_table_op = False continue_search_lookup_table_op = False
all_ops = program.global_block().ops all_ops = program.global_block().ops
for op in all_ops: for op in all_ops:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE and self.table_name == op.input(
"W")[0]:
if not op.attr('is_distributed'):
raise RuntimeError(
"lookup_table_op that lookup an distributed embedding table"
"should set is_distributed to true")
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list( lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册