提交 b2f789c6 编写于 作者: J JiabinYang

add test transpiler dist test, test=develop

上级 b5fa9164
...@@ -879,29 +879,36 @@ class TestRemoteNce(TestDistLookupTableBase): ...@@ -879,29 +879,36 @@ class TestRemoteNce(TestDistLookupTableBase):
class TestRemoteHsigmoid(TestDistLookupTableBase): class TestRemoteHsigmoid(TestDistLookupTableBase):
def network_with_table(self, is_sparse, is_distributed): def network_with_table(self, is_sparse, is_distributed):
num_total_classes = 10 num_total_classes = 3
input = fluid.layers.data(name="input", shape=[10], dtype="float32") input = fluid.layers.data(name="input", shape=[1], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
path_table = fluid.layers.data( path_table = fluid.layers.data(
name='path_table', shape=[10], dtype='int64') name='path_table', shape=[3], dtype='int64')
path_code = fluid.layers.data( path_code = fluid.layers.data(
name='path_code', shape=[10], dtype='int64') name='path_code', shape=[3], dtype='int64')
w_param = fluid.default_main_program().global_block().create_parameter( w_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 10], shape=[num_total_classes, 10],
dtype='float32', dtype='float32',
name='hs_w', name='hs_w',
initializer=fluid.initializer.ConstantInitializer()) initializer=fluid.initializer.ConstantInitializer())
b_param = fluid.default_main_program().global_block().create_parameter( b_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 1], shape=[3, 1],
dtype='float32', dtype='float32',
name='hs_b', name='hs_b',
initializer=fluid.initializer.ConstantInitializer()) initializer=fluid.initializer.ConstantInitializer())
cost = fluid.layers.hsigmoid( emb = fluid.layers.embedding(
input=input, input=input,
is_sparse=is_sparse,
size=[3, 3],
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(num_total_classes))))
cost = fluid.layers.hsigmoid(
input=emb,
label=label, label=label,
num_classes=non_leaf_num, num_classes=num_total_classes,
path_table=path_table, path_table=path_table,
path_code=path_code, path_code=path_code,
is_custom=True, is_custom=True,
...@@ -918,9 +925,29 @@ class TestRemoteHsigmoid(TestDistLookupTableBase): ...@@ -918,9 +925,29 @@ class TestRemoteHsigmoid(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
trainer, _ = self.get_trainer() trainer, _ = self.get_trainer()
params_to_check = list()
for op in trainer.blocks[0].ops: for op in trainer.blocks[0].ops:
if op.type == "recv": if op.type == "hierarchical_sigmoid":
params_to_check = [op.input("W")[0], op.input("Bias")[0]]
for name in ["epmap", "table_names", "epmap"]:
assert op.has_attr(name)
if name == "epmap":
assert op.attr(name)[0] == u'127.0.0.1:6174'
elif name == "table_names":
assert op.attr(name)[0] == u'hierarchical_sigmoid_0.w_0'
else:
assert op.attr(name) == 3
elif op.type == "lookup_table":
params_to_check.append(op.input("W")[0])
else:
pass pass
op_count = 0
for op in trainer.blocks[0].ops:
if op.type == "recv":
assert len(op.output("Out")) == 1
assert op.output("Out")[0] == u'hierarchical_sigmoid_0.b_0'
op_count += 1
assert op_count == 1
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -242,7 +242,7 @@ class DistributeTranspiler(object): ...@@ -242,7 +242,7 @@ class DistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program): def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = [] sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce"] sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr( if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True: 'remote_prefetch') is True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册