# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import paddle
from paddle.dataset.common import download
from paddle.distributed.fleet.dataset import TreeIndex

# The tests below declare their inputs with ``paddle.static.data`` (a
# static-graph API), so switch Paddle into static-graph mode at import time.
paddle.enable_static()


def create_feeds():
    """Declare the static-graph input placeholders for a tree-based model.

    Creates four ``int64`` placeholders via ``paddle.static.data``, each with
    shape ``[-1, 1]`` and ``lod_level=1``, named (in order) ``item_id``,
    ``unit_id``, ``label`` and ``labels``.

    Returns:
        list: the created variables, in the order listed above
        (``[user_input, item, label, labels]``).
    """
    # Order matters: callers index into the returned feed list positionally.
    feed_names = ["item_id", "unit_id", "label", "labels"]
    return [
        paddle.static.data(
            name=name, shape=[-1, 1], dtype="int64", lod_level=1
        )
        for name in feed_names
    ]


class TestTreeIndex(unittest.TestCase):
    """Exercise the query API of ``TreeIndex`` on a small pre-built tree."""

    def test_tree_index(self):
        # Fetch the mini tree; the MD5 argument pins the expected file contents.
        path = download(
            "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb",
            "tree_index_unittest",
            "e2ba4561c2e9432b532df40546390efa",
        )

        tree = TreeIndex("demo", path)
        height = tree.height()
        branch = tree.branch()
        self.assertEqual(height, 5)
        self.assertEqual(branch, 2)
        self.assertEqual(tree.total_node_nums(), 25)
        self.assertEqual(tree.emb_size(), 30)

        # get_layer_codes: collect the codes and node ids of every layer,
        # indexed by layer number (0 .. height - 1).
        layer_node_codes = []
        layer_node_ids = []
        for layer in range(height):
            codes = tree.get_layer_codes(layer)
            layer_node_codes.append(codes)
            layer_node_ids.append(
                [node.id() for node in tree.get_nodes(codes)]
            )

        # The leaf ids should match the ids found in the deepest layer
        # (compared by their sum).
        all_leaf_ids = [node.id() for node in tree.get_all_leafs()]
        self.assertEqual(sum(all_leaf_ids), sum(layer_node_ids[-1]))

        # get_travel: entry i of the travel path of the first leaf must lie
        # in layer (height - 1 - i), i.e. the path runs from the deepest
        # layer upwards.
        travel_codes = tree.get_travel_codes(all_leaf_ids[0])
        travel_ids = [node.id() for node in tree.get_nodes(travel_codes)]
        for i in range(height):
            self.assertIn(travel_ids[i], layer_node_ids[height - 1 - i])
            self.assertIn(travel_codes[i], layer_node_codes[height - 1 - i])

        # get_ancestor: the ancestor at layer (height - 2) is the second
        # entry of the travel path.
        ancestor_codes = tree.get_ancestor_codes([all_leaf_ids[0]], height - 2)
        ancestor_ids = [node.id() for node in tree.get_nodes(ancestor_codes)]
        self.assertEqual(ancestor_ids[0], travel_ids[1])
        self.assertEqual(ancestor_codes[0], travel_codes[1])

        # get_pi_relation: maps the leaf id to its ancestor code at that layer.
        pi_relation = tree.get_pi_relation([all_leaf_ids[0]], height - 2)
        self.assertEqual(pi_relation[all_leaf_ids[0]], ancestor_codes[0])

        # get_travel_path: returns the travel path from travel_codes[0] up
        # to, but excluding, travel_codes[-1].
        travel_path_codes = tree.get_travel_path(
            travel_codes[0], travel_codes[-1]
        )
        travel_path_ids = [
            node.id() for node in tree.get_nodes(travel_path_codes)
        ]
        self.assertEqual(travel_path_ids + [travel_ids[-1]], travel_ids)
        self.assertEqual(travel_path_codes + [travel_codes[-1]], travel_codes)

        # get_children: the descendants (at the deepest layer) of the
        # leaf's layer-(height - 2) ancestor must include the leaf itself.
        children_codes = tree.get_children_codes(travel_codes[1], height - 1)
        children_ids = [node.id() for node in tree.get_nodes(children_codes)]
        self.assertIn(all_leaf_ids[0], children_ids)


class TestIndexSampler(unittest.TestCase):
    """Check ``InMemoryDataset.tdm_sample`` on a tiny generated slot file."""

    def setUp(self):
        # Scratch directory for the generated input file; removed in tearDown.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_layerwise_sampler(self):
        path = download(
            "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb",
            "tree_index_unittest",
            "e2ba4561c2e9432b532df40546390efa",
        )

        # Per-layer sample counts passed to tdm_sample (17 entries, all zero).
        tdm_layer_counts = [0] * 17

        # Two identical input lines for the three int64 slots declared below.
        file_name = os.path.join(
            self.temp_dir.name, "test_in_memory_dataset_tdm_sample_run.txt"
        )
        with open(file_name, "w", encoding="utf-8") as f:
            f.write("1 1 1 15 15 15\n" * 2)

        slots_vars = [
            paddle.static.data(name=slot, shape=[-1, 1], dtype="int64")
            for slot in ["slot1", "slot2", "slot3"]
        ]

        dataset = paddle.distributed.InMemoryDataset()
        dataset.init(
            batch_size=1,
            pipe_command="cat",
            download_cmd="cat",
            use_var=slots_vars,
        )
        dataset.set_filelist([file_name])
        dataset.load_into_memory()

        dataset.tdm_sample(
            'demo',
            tree_path=path,
            tdm_layer_counts=tdm_layer_counts,
            start_sample_layer=1,
            with_hierachy=False,
            seed=0,
            id_slot=2,
        )
        self.assertEqual(dataset.get_shuffle_data_size(), 8)


if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()