提交 373f6498 编写于 作者: Q Qiao Longfei

add comment and unit test

test=develop
上级 67050468
...@@ -249,6 +249,15 @@ class Optimizer(object): ...@@ -249,6 +249,15 @@ class Optimizer(object):
def _process_distribute_lookuptable(self, param_grads, loss, def _process_distribute_lookuptable(self, param_grads, loss,
startup_program): startup_program):
"""
Because distribute lookup table only support SGD optimizer for now, not support
other optimizer and regularization, so we should find the table parameter out,
and avoid to add regularization and other op for it, and add sgd optimize op
for it independently.
:param param_grads(list((Var, Var))): list of (param, grad) pair.
:param loss: the loss variable.
:param startup_program: the startup program
"""
program = loss.block.program program = loss.block.program
table_name = find_distributed_lookup_table(program) table_name = find_distributed_lookup_table(program)
table_param = None table_param = None
......
...@@ -641,7 +641,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -641,7 +641,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
# 5 save table # 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer(config) trainer, trainer_startup = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
...@@ -655,6 +655,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -655,6 +655,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
'recv', 'concat' 'recv', 'concat'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestDistLookupTableSliceSize(TestDistLookupTableBase): class TestDistLookupTableSliceSize(TestDistLookupTableBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册