提交 b2f789c6 编写于 作者: J JiabinYang

add test transpiler dist test, test=develop

上级 b5fa9164
......@@ -879,29 +879,36 @@ class TestRemoteNce(TestDistLookupTableBase):
class TestRemoteHsigmoid(TestDistLookupTableBase):
def network_with_table(self, is_sparse, is_distributed):
num_total_classes = 10
num_total_classes = 3
input = fluid.layers.data(name="input", shape=[10], dtype="float32")
input = fluid.layers.data(name="input", shape=[1], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
path_table = fluid.layers.data(
name='path_table', shape=[10], dtype='int64')
name='path_table', shape=[3], dtype='int64')
path_code = fluid.layers.data(
name='path_code', shape=[10], dtype='int64')
name='path_code', shape=[3], dtype='int64')
w_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 10],
dtype='float32',
name='hs_w',
initializer=fluid.initializer.ConstantInitializer())
b_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 1],
shape=[3, 1],
dtype='float32',
name='hs_b',
initializer=fluid.initializer.ConstantInitializer())
cost = fluid.layers.hsigmoid(
emb = fluid.layers.embedding(
input=input,
is_sparse=is_sparse,
size=[3, 3],
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(num_total_classes))))
cost = fluid.layers.hsigmoid(
input=emb,
label=label,
num_classes=non_leaf_num,
num_classes=num_total_classes,
path_table=path_table,
path_code=path_code,
is_custom=True,
......@@ -918,9 +925,29 @@ class TestRemoteHsigmoid(TestDistLookupTableBase):
def transpiler_test_impl(self):
trainer, _ = self.get_trainer()
params_to_check = list()
for op in trainer.blocks[0].ops:
if op.type == "recv":
if op.type == "hierarchical_sigmoid":
params_to_check = [op.input("W")[0], op.input("Bias")[0]]
for name in ["epmap", "table_names", "epmap"]:
assert op.has_attr(name)
if name == "epmap":
assert op.attr(name)[0] == u'127.0.0.1:6174'
elif name == "table_names":
assert op.attr(name)[0] == u'hierarchical_sigmoid_0.w_0'
else:
assert op.attr(name) == 3
elif op.type == "lookup_table":
params_to_check.append(op.input("W")[0])
else:
pass
op_count = 0
for op in trainer.blocks[0].ops:
if op.type == "recv":
assert len(op.output("Out")) == 1
assert op.output("Out")[0] == u'hierarchical_sigmoid_0.b_0'
op_count += 1
assert op_count == 1
if __name__ == "__main__":
......
......@@ -242,7 +242,7 @@ class DistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce"]
sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册