diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index 27575897b547e8b257fae89452645d6f1fca6162..f572d6927783a41595e4c6944fbf2b5f385f88ec 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -879,29 +879,36 @@ class TestRemoteNce(TestDistLookupTableBase):
 
 class TestRemoteHsigmoid(TestDistLookupTableBase):
     def network_with_table(self, is_sparse, is_distributed):
-        num_total_classes = 10
+        num_total_classes = 3
 
-        input = fluid.layers.data(name="input", shape=[10], dtype="float32")
+        input = fluid.layers.data(name="input", shape=[1], dtype="float32")
         label = fluid.layers.data(name="label", shape=[1], dtype="int64")
         path_table = fluid.layers.data(
-            name='path_table', shape=[10], dtype='int64')
+            name='path_table', shape=[3], dtype='int64')
         path_code = fluid.layers.data(
-            name='path_code', shape=[10], dtype='int64')
+            name='path_code', shape=[3], dtype='int64')
         w_param = fluid.default_main_program().global_block().create_parameter(
             shape=[num_total_classes, 10],
             dtype='float32',
             name='hs_w',
             initializer=fluid.initializer.ConstantInitializer())
         b_param = fluid.default_main_program().global_block().create_parameter(
-            shape=[num_total_classes, 1],
+            shape=[3, 1],
             dtype='float32',
             name='hs_b',
             initializer=fluid.initializer.ConstantInitializer())
 
-        cost = fluid.layers.hsigmoid(
+        emb = fluid.layers.embedding(
             input=input,
+            is_sparse=is_sparse,
+            size=[3, 3],
+            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                scale=1 / math.sqrt(num_total_classes))))
+
+        cost = fluid.layers.hsigmoid(
+            input=emb,
             label=label,
-            num_classes=non_leaf_num,
+            num_classes=num_total_classes,
             path_table=path_table,
             path_code=path_code,
             is_custom=True,
@@ -918,9 +925,29 @@ class TestRemoteHsigmoid(TestDistLookupTableBase):
 
     def transpiler_test_impl(self):
         trainer, _ = self.get_trainer()
+        params_to_check = list()
         for op in trainer.blocks[0].ops:
-            if op.type == "recv":
+            if op.type == "hierarchical_sigmoid":
+                params_to_check = [op.input("W")[0], op.input("Bias")[0]]
+                for name in ["epmap", "table_names", "epmap"]:
+                    assert op.has_attr(name)
+                    if name == "epmap":
+                        assert op.attr(name)[0] == u'127.0.0.1:6174'
+                    elif name == "table_names":
+                        assert op.attr(name)[0] == u'hierarchical_sigmoid_0.w_0'
+                    else:
+                        assert op.attr(name) == 3
+            elif op.type == "lookup_table":
+                params_to_check.append(op.input("W")[0])
+            else:
                 pass
+        op_count = 0
+        for op in trainer.blocks[0].ops:
+            if op.type == "recv":
+                assert len(op.output("Out")) == 1
+                assert op.output("Out")[0] == u'hierarchical_sigmoid_0.b_0'
+                op_count += 1
+        assert op_count == 1
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 378654ab5b1514f7799ef899b99831a9c5cc4e76..f5ca3dffb737e8650141c1a4761a9722230d6bff 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -242,7 +242,7 @@ class DistributeTranspiler(object):
 
     def _get_all_remote_sparse_update_op(self, main_program):
         sparse_update_ops = []
-        sparse_update_op_types = ["lookup_table", "nce"]
+        sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
                     'remote_prefetch') is True:
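
Note: the distribute_transpiler.py hunk only extends the whitelist of op types that take part in remote sparse updates; the surrounding selection logic is unchanged. The sketch below is a minimal standalone illustration of that filter, not the transpiler itself: FakeOp and get_remote_sparse_update_ops are hypothetical stand-ins, modeling only the .type field and .attr() accessor that the real code relies on.

    # Minimal sketch of the filtering performed by _get_all_remote_sparse_update_op.
    # FakeOp is a hypothetical stand-in for Paddle's operator descriptor.
    class FakeOp(object):
        def __init__(self, op_type, attrs=None):
            self.type = op_type
            self._attrs = attrs or {}

        def attr(self, name):
            return self._attrs.get(name)


    def get_remote_sparse_update_ops(ops):
        # Mirrors the patched whitelist: hierarchical_sigmoid now also qualifies.
        sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
        return [
            op for op in ops
            if op.type in sparse_update_op_types
            and op.attr('remote_prefetch') is True
        ]


    ops = [
        FakeOp("hierarchical_sigmoid", {"remote_prefetch": True}),
        FakeOp("lookup_table", {"remote_prefetch": False}),
        FakeOp("mul"),
    ]
    assert [op.type for op in get_remote_sparse_update_ops(ops)] == [
        "hierarchical_sigmoid"
    ]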