Commit 2b6c0c09 authored by Qiao Longfei

add unit test

Parent cc6ef41d
@@ -89,6 +89,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(false);
     // for parameter prefetch
+    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
     AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<int64_t>>("height_sections",
                                   "Height for each output SelectedRows.")
...
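The `remote_prefetch` attribute added above defaults to `false`, so existing graphs are unaffected; the flag is later read back from Python when deciding which `lookup_table` ops to rewrite for prefetching. A minimal sketch of that inspection, assuming `prog` is a program that already contains an embedding built with this flag (names here are illustrative, not part of the commit):

```python
import paddle.fluid as fluid

# Assumes the program already contains fluid.layers.embedding(..., remote_prefetch=True).
prog = fluid.default_main_program()
for op in prog.global_block().ops:
    if op.type == "lookup_table" and op.attr("remote_prefetch"):
        # The transpiler (last hunk below) later fills in 'epmap',
        # 'height_sections' and 'trainer_id' on exactly these ops.
        print(op.input_arg_names)
```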
@@ -285,6 +285,7 @@ def embedding(input,
               size,
               is_sparse=False,
               is_distributed=False,
+              remote_prefetch=False,
               padding_idx=None,
               param_attr=None,
               dtype='float32'):
@@ -326,6 +327,8 @@ def embedding(input,
     """
     helper = LayerHelper('embedding', **locals())
+    if remote_prefetch:
+        assert is_sparse is True and is_distributed is False
     w = helper.create_parameter(
         attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
     tmp = helper.create_variable_for_type_inference(dtype)
@@ -339,6 +342,7 @@ def embedding(input,
         attrs={
             'is_sparse': is_sparse,
             'is_distributed': is_distributed,
+            'remote_prefetch': remote_prefetch,
             'padding_idx': padding_idx
         })
     return tmp
...
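For context, a minimal trainer-side sketch of how the new flag is passed through `fluid.layers.embedding`; the shapes and names mirror the unit test below and are purely illustrative. Note the assertion above: `remote_prefetch` requires `is_sparse=True` and `is_distributed=False`.

```python
import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(
    input=ids,
    size=[1000, 64],        # illustrative table size
    dtype='float32',
    param_attr='shared_w',
    is_sparse=True,         # required when remote_prefetch=True
    is_distributed=False,   # must be False when remote_prefetch=True
    remote_prefetch=True)   # rows are fetched from the pserver at run time
pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
```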
@@ -447,19 +447,23 @@ class TestEmptyPserverOptimizeBlocks(TranspilerTest):

 class TestDistLookupTableBase(TranspilerTest):
-    def network_with_table(self, is_sparse, is_distributed):
+    def network_with_table(self,
+                           is_sparse,
+                           is_distributed,
+                           remote_prefetch=False):
         self.table_size = 1000
         self.emb_size = 64
         self.lookup_table_name = 'shared_w'

-        def emb_pool(ids, table_name, is_distributed):
+        def emb_pool(ids, table_name, is_distributed, remote_prefetch):
             emb = fluid.layers.embedding(
                 input=ids,
                 size=[self.table_size, self.emb_size],
                 dtype='float32',
                 param_attr=table_name,
                 is_sparse=is_sparse,
-                is_distributed=is_distributed)
+                is_distributed=is_distributed,
+                remote_prefetch=remote_prefetch)
             pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
             return pool
@@ -469,9 +473,12 @@ class TestDistLookupTableBase(TranspilerTest):
             name='brand_ids', shape=[1], dtype='int64', lod_level=1)
         profile_ids = fluid.layers.data(
             name='brand_ids', shape=[1], dtype='int64', lod_level=1)
-        title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed)
-        brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed)
-        profile_emb = emb_pool(profile_ids, "profile_emb", False)
+        title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed,
+                             False)
+        brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed,
+                             False)
+        profile_emb = emb_pool(profile_ids, "profile_emb", False,
+                               remote_prefetch)
         fc0 = fluid.layers.concat(
             input=[title_emb, brand_emb, profile_emb], axis=1)
         predict = fluid.layers.fc(input=fc0,
@@ -575,6 +582,57 @@ class TestDistLookupTable(TestDistLookupTableBase):
                          startup_ops)

+class TestRemoteLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(
+            is_sparse=True, is_distributed=False, remote_prefetch=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 4 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["sum", "sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer, trainer_startup = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
+            'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'split_selected_rows', 'send',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_ids', 'send', 'send_barrier',
+            'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+        startup_ops = [
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'uniform_random',
+            'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
+            'fake_init'
+        ]
+        self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
+                         startup_ops)
+
+
 class TestAsyncLocalLookupTable(TestDistLookupTableBase):
     def net_conf(self):
         self.network_with_table(is_sparse=True, is_distributed=False)
@@ -782,5 +840,45 @@ class TestNCCL2Transpile(TranspilerTest):
             pass

+# test for remote prefetch
+class TestRemoteLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(
+            is_sparse=True, is_distributed=False, remote_prefetch=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 4)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 3 optimize for table 2 adam
+        # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+
+        trainer, _ = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
+            'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
+            'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
+            'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'split_selected_rows', 'send', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
+            'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -236,21 +236,29 @@ class DistributeTranspiler(object):
         else:
             raise ValueError("must set trainer_id > 0")

-    def _get_all_sparse_update_op(self, main_program):
+    def _get_all_remote_sparse_update_op(self, main_program):
         sparse_update_ops = []
         sparse_update_op_types = ["lookup_table"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
-                    'is_sparse') is True and not op.attr('is_distributed'):
+                    'remote_prefetch') is True and not op.attr(
+                        'is_distributed'):
                 sparse_update_ops.append(op)
         return sparse_update_ops

-    def _update_sparse_update_op(self, param_varname, height_sections,
-                                 endpint_map):
+    def _update_remote_sparse_update_op(self, param_varname, height_sections,
+                                        endpint_map):
         for op in self.sparse_update_ops:
             if param_varname in op.input_arg_names:
                 op._set_attr('epmap', endpint_map)
                 op._set_attr('height_sections', height_sections)
+                op._set_attr('trainer_id', self.trainer_id)
+
+    def _is_input_of_remote_sparse_update_op(self, param_name):
+        for op in self.sparse_update_ops:
+            if param_name in op.input_arg_names:
+                return True
+        return False

     def transpile(self,
                   trainer_id,
@@ -316,7 +324,7 @@ class DistributeTranspiler(object):
                 self.grad_name_to_param_name[grad_var.name] = param_var.name

         # get all sparse update ops
-        self.sparse_update_ops = self._get_all_sparse_update_op(
+        self.sparse_update_ops = self._get_all_remote_sparse_update_op(
             self.origin_program)
         self.sparse_param_to_height_sections = dict()
@@ -449,8 +457,8 @@ class DistributeTranspiler(object):
             if param_varname in self.sparse_param_to_height_sections:
                 height_sections = self.sparse_param_to_height_sections[
                     param_varname]
-                self._update_sparse_update_op(param_varname, height_sections,
-                                              eps)
+                self._update_remote_sparse_update_op(param_varname,
+                                                     height_sections, eps)
             else:
                 program.global_block().append_op(
                     type="recv",
@@ -481,8 +489,6 @@ class DistributeTranspiler(object):
             if len(splited_var) <= 1:
                 continue
             orig_param = program.global_block().vars[param_varname]
-            print("sparse_param_to_height_sections: " + str(
-                self.sparse_param_to_height_sections))
             if param_varname not in self.sparse_param_to_height_sections:
                 program.global_block().append_op(
                     type="concat",
@@ -1448,7 +1454,7 @@ to transpile() call.")
             for v in splited_vars:
                 height_sections.append(v.shape[0])
             sparse_param_name = self.grad_name_to_param_name[orig_var.name]
-            if sparse_param_name != self.table_name:
+            if self._is_input_of_remote_sparse_update_op(sparse_param_name):
                 self.sparse_param_to_height_sections[
                     sparse_param_name] = height_sections
             program.global_block()._insert_op(
...
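A rough end-to-end sketch of driving the transpiler on such a program; the endpoints and trainer count below are placeholders, not values from this commit. With `remote_prefetch=True` on the embedding, the transpiled trainer program carries `split_ids`/`prefetch`/`merge_ids` ops in place of a local lookup, which is exactly what the new unit test asserts.

```python
import paddle.fluid as fluid

# Placeholder endpoints and trainer count, for illustration only.
t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2)

trainer_prog = t.get_trainer_program()
pserver_prog = t.get_pserver_program("127.0.0.1:6174")
startup_prog = t.get_startup_program("127.0.0.1:6174", pserver_prog)
```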