#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TestCases for Dataset consistency insepection of use_var_list and data_generator.
"""

import paddle
import paddle.fluid as fluid
import math
import os
import tempfile
import unittest
import paddle.fluid.incubate.data_generator as dg

#paddle.enable_static()
# fluid.disable_dygraph()
fluid.disable_dygraph()
url_schema_len = 5
query_schema = [
    'Q_query_basic', 'Q_query_phrase', 'Q_quq', 'Q_timelevel',
    'Q_context_title_basic1', 'Q_context_title_basic2',
    'Q_context_title_basic3', 'Q_context_title_basic4',
    'Q_context_title_basic5', 'Q_context_title_phrase1',
    'Q_context_title_phrase2', 'Q_context_title_phrase3',
    'Q_context_title_phrase4', 'Q_context_title_phrase5', 'Q_context_site1',
    'Q_context_site2', 'Q_context_site3', 'Q_context_site4', 'Q_context_site5'
]


class CTRDataset(dg.MultiSlotDataGenerator):
    """Line-parsing data generator for the CTR consistency-check tests.

    Each input line is a ';'-separated record:
      - field 0: unused here
      - field 1: "<pos_num> <neg_num>" url counts
      - next len(query_schema) fields: sparse query features
      - then ``url_schema_len`` fields per url block, first the ``pos_num``
        positive urls, then the ``neg_num`` negative ones; the last field of
        each block is a space-separated click-statistics string (17 columns).

    ``mode`` selects the sample layout yielded by ``generate_sample`` so the
    surrounding test can probe Dataset's use_var/data_generator consistency
    check with both well-formed and deliberately malformed samples:
      - 0: valid pairwise sample: [[1]] label + 53 slots, yielded as a list
      - 1: pointwise sample with the url half duplicated, yielded as a raw
           zip object instead of a list
      - 2: one extra int label slot      ([[1], [2]])
      - 3: extra label slot, float value ([[1], [2.0]])
      - 4: empty first label slot        ([[], [2.0]])
      - 5: label slot missing entirely
    """

    def __init__(self, mode):
        # Layout selector read by generate_sample(); see class docstring.
        self.test = mode

    def generate_sample(self, line):
        """Return a ``reader()`` generator that parses one raw ``line``."""

        def parse_url_blocks(ins, bias, count, query_len):
            """Parse ``count`` url blocks of ``ins`` starting at index ``bias``.

            Returns three parallel lists: sparse url features, clamped click
            statistics, and per-url context features.  Urls with an empty or
            malformed (!= 17 columns) click field are skipped entirely.
            """
            url_feas, click_feas, context_feas = [], [], []
            for k in range(count):
                url_fea = []
                pos = 0
                for index in range(url_schema_len - 1):
                    pos = bias + k * url_schema_len + index
                    url_fea.append([int(x) for x in ins[pos].split(' ')])
                # The final field of the block holds the click statistics.
                if ins[pos + 1] == '':
                    continue
                item = ins[pos + 1].split(' ')
                if len(item) != 17:
                    continue
                # Keep 11 of the 17 stats (drop columns 5, 9, 13..16),
                # clamping negatives to 0.0.
                stat_fea = [[max(float(item[i]), 0.0)]
                            for i in range(len(item))
                            if i not in (5, 9, 13, 14, 15, 16)]
                url_feas.append(url_fea)
                click_feas.append(stat_fea)
                # Column 5 is the query search count; log-scale into [0, 1].
                query_search = float(item[5])
                if query_search > 0.0:
                    query_search = min(math.log(query_search), 10.0) / 10.0
                context_feas.append([[query_search], [query_len]])
            return url_feas, click_feas, context_feas

        def reader():
            ins = line.strip().split(';')
            label_pos_num = int(ins[1].split(' ')[0])
            label_neg_num = int(ins[1].split(' ')[1])

            # Sparse query features start right after the two header fields.
            bias = 2
            query_len = 0
            sparse_query_feature = []
            for index in range(len(query_schema)):
                pos = index + bias
                sparse_query_feature.append(
                    [int(x) for x in ins[pos].split(' ')])
                if index == 0:
                    # Squash the raw query token count through a logistic
                    # curve (e ~ 2.7182818) so it lands in (0, 1).
                    query_len = len(ins[pos].split(' '))
                    query_len = 1.0 / (1 + pow(2.7182818, 3 - 1.0 * query_len))

            # Positive url blocks, then negative url blocks right after them.
            pos_bias = 2 + len(query_schema)
            pos_url_feas, pos_click_feas, pos_context_feas = parse_url_blocks(
                ins, pos_bias, label_pos_num, query_len)
            neg_bias = pos_bias + label_pos_num * url_schema_len
            neg_url_feas, neg_click_feas, neg_context_feas = parse_url_blocks(
                ins, neg_bias, label_neg_num, query_len)

            # 54 slot names: "click" followed by "1" .. "53".
            feature_name = ["click"] + [str(i) for i in range(1, 54)]

            if self.test == 1:
                # Pointwise layout with each url half duplicated; this mode
                # deliberately yields a zip object instead of a list.
                for p in range(len(pos_url_feas)):
                    half = (pos_url_feas[p] + pos_click_feas[p] +
                            pos_context_feas[p])
                    yield zip(feature_name,
                              [[1]] + sparse_query_feature + half + half)
                for n in range(len(neg_url_feas)):
                    half = (neg_url_feas[n] + neg_click_feas[n] +
                            neg_context_feas[n])
                    yield zip(feature_name,
                              [[0]] + sparse_query_feature + half + half)
            elif self.test in (0, 2, 3, 4, 5):
                # Pairwise layout: every positive url paired with every
                # negative url.  Only the label prefix differs per mode, so
                # the consistency checker does (or does not) trip on it; see
                # the class docstring for the meaning of each prefix.
                label_prefix = {
                    0: [[1]],
                    2: [[1], [2]],
                    3: [[1], [2.0]],
                    4: [[], [2.0]],
                    5: [],
                }[self.test]
                for p in range(len(pos_url_feas)):
                    pos_half = (pos_url_feas[p] + pos_click_feas[p] +
                                pos_context_feas[p])
                    for n in range(len(neg_url_feas)):
                        neg_half = (neg_url_feas[n] + neg_click_feas[n] +
                                    neg_context_feas[n])
                        yield list(
                            zip(feature_name,
                                label_prefix + sparse_query_feature +
                                pos_half + neg_half))

        return reader


class TestDataset(unittest.TestCase):
    """  TestCases for Dataset. """

    def setUp(self):
        # No shared fixture state is required by these tests.
        pass

    def _append_slot_vars(self, slot_data, start, stop, dtype, lod_level=None):
        """Append fluid data vars named str(start) .. str(stop - 1).

        ``lod_level`` is forwarded only when given, so dense (float32) slots
        keep fluid.layers.data's default lod_level exactly as before.
        """
        for feat_name in range(start, stop):
            if lod_level is None:
                slot_data.append(
                    fluid.layers.data(name=str(feat_name),
                                      shape=[1],
                                      dtype=dtype))
            else:
                slot_data.append(
                    fluid.layers.data(name=str(feat_name),
                                      shape=[1],
                                      dtype=dtype,
                                      lod_level=lod_level))

    def test_var_consistency_insepection(self):
        """
        Testcase for InMemoryDataset consistency inspection of use_var_list
        and data_generator.
        """
        temp_dir = tempfile.TemporaryDirectory()
        dump_a_path = os.path.join(temp_dir.name, 'test_run_with_dump_a.txt')

        # Two hand-crafted sample lines matching the schema CTRDataset parses.
        with open(dump_a_path, "w") as f:
            data = "2 1;1 9;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;0;40000001;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000200;10000200;10063938;10000008;10000177;20002001 20001240 20001860 20003611 20010833 20000210 20000500 20000401 20000251 20012198 20001023 20000157;20002001 20001240 20001860 20003611 20012396 20000500 20002513 20012198 20001023 20000157;10000123;30000004;0.623 0.233 0.290 0.208 0.354 49.000 0.000 0.000 0.000 -1.000 0.569 0.679 0.733 53 17 2 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000047;30000004;0.067 0.000 0.161 0.005 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.378 0.043 0 6 0 0;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;10000200;30000001;0.407 0.111 0.196 0.095 0.181 49.000 0.000 0.000 0.000 -1.000 0.306 0.538 0.355 48 8 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;10000200;30000001;0.226 0.029 0.149 0.031 0.074 49.000 0.000 0.000 0.000 -1.000 0.220 0.531 0.286 26 6 0 0;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;10063938;30000001;0.250 0.019 0.138 0.012 0.027 49.000 0.000 0.000 0.000 -1.000 0.370 0.449 0.327 7 2 0 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000003;30000002;0.056 0.000 0.139 0.003 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.346 0.059 15 3 0 0;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;10000008;30000001;0.166 0.004 0.127 0.001 0.004 49.000 0.000 0.000 0.000 -1.000 0.103 0.417 0.394 10 3 0 0;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000177;30000001;0.094 0.008 0.157 0.012 0.059 49.000 0.000 0.000 0.000 -1.000 0.051 0.382 0.142 21 0 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;10000134;30000001;0.220 0.016 0.181 0.037 0.098 49.000 0.000 0.000 0.000 -1.000 0.192 0.453 0.199 17 1 0 0;20002001 20001240 20001860 20003611 20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000638;30000001;0.000 0.000 0.000 0.000 0.000 49.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0 0 0 0;\n"
            data += "2 1;1 11;20000025 20000404;20001923;20000002 20000157 20000028 20004205 20000500 20028809 20000571 20000007 20027523 20004940 20000651 20000043 20000051 20000520 20015398 20000066 20004720 20000070 20001648;40000001;20000025 20000404 20000571 20004940 20000001 20000017;20000025 20000404 20000029 20000500 20001408 20000404 20000001 20000017;0;0;0;20001923 20011130 20000027;20001923 20000029 20000500 20001408 20000404 20000027;0;0;0;10000005;10000005;0;0;0;20003316 20000392 20001979 20000474 20000025 20000194 20000025 20000404 20000019 20000109;20016528 20024913 20004748 20001923 20000019 20000109;10000015;30000002;0.572 0.043 0.401 0.352 0.562 32859.000 0.005 0.060 0.362 -1.000 0.448 0.673 0.222 16316 991 89 0;20000025 20000404 20000571 20004940 20000001 20000017;20001923 20011130 20000027;10000005;30000001;0.495 0.024 0.344 0.285 0.379 32859.000 0.002 0.050 0.362 -1.000 0.423 0.764 0.254 19929 896 72 0;20000202 20000026 20001314 20004289 20000025 20000404 20000451 20000089 20000007;20000202 20000026 20014094 20001314 20004289 20001923 20000451 20000089 20000007;10000035;30000003;0.133 0.006 0.162 0.042 0.174 32859.000 0.003 0.037 0.362 -1.000 0.363 0.542 0.122 14763 664 53 0;20000202 20000026 20001314 20004289 20000025 20000404;20000202 20000026 20014094 20001314 20004289 20001923;10000021;30000001;0.058 0.004 0.133 0.017 0.120 32859.000 0.000 0.006 0.362 -1.000 0.168 0.437 0.041 -1 -1 -1 -1;20000025 20000404 20000018 20012461 20001699 20000446 20000174 20000062 20000133 20003172 20000240 20007877 20067375 20000111 20000164 20001410 20000204 20016958;20001923 20000018 20012461 20001699 20007717 20000062 20000133 20003172 20000240 20007877 20067375 20000111 20000164 20001410 20000204 20016958;10000002;30000001;0.017 0.000 0.099 0.004 0.072 32859.000 0.000 0.009 0.362 -1.000 0.058 0.393 0.025 -1 -1 -1 -1;20000025 20000404;20001923;10000133;30000005;0.004 0.000 0.122 0.000 0.000 32859.000 0.000 0.000 0.362 -1.000 0.000 0.413 0.020 0 444 35 0;20000025 20000404;20001923;10005297;30000004;0.028 0.000 0.138 0.002 0.000 32859.000 0.000 0.000 0.362 -1.000 0.000 0.343 0.024 0 600 48 0;20000025 20000404;20001923;10000060;30000005;0.107 0.000 0.110 0.027 0.077 32859.000 0.000 0.005 0.362 -1.000 0.095 0.398 0.062 1338 491 39 0;20002960 20005534 20000043 20000025 20000404 20000025 20000007;20002960 20005534 20000043 20001923 20000025 20000007;10000020;30000003;0.041 0.000 0.122 0.012 0.101 32859.000 0.001 0.025 0.362 -1.000 0.302 0.541 0.065 9896 402 35 0;20000025 20000404 20000259 20000228 20000235 20000142;20001923 20000259 20000264 20000142;10000024;30000003;0.072 0.002 0.156 0.026 0.141 32859.000 0.002 0.032 0.362 -1.000 0.386 0.569 0.103 9896 364 35 0;20000025 20000404 20000029 20000500 20001408 20000404 20000001 20000017;20001923 20000029 20000500 20001408 20000404 20000027;10000005;30000001;0.328 0.006 0.179 0.125 0.181 32859.000 0.003 0.058 0.362 -1.000 0.300 0.445 0.141 9896 402 32 0;20000025 20000404;20001923;10012839;30000002;0.012 0.000 0.108 0.002 0.048 32859.000 0.000 0.000 0.362 -1.000 0.021 0.225 0.016 2207 120 12 0;\n"
            f.write(data)

        # Build the use_var list: 1 label var plus 53 feature slots whose
        # names must match what CTRDataset emits ("click", then "1".."53").
        slot_data = []
        label = fluid.layers.data(name="click",
                                  shape=[-1, 1],
                                  dtype="int64",
                                  lod_level=0,
                                  append_batch_size=False)
        slot_data.append(label)

        len_sparse_query = 19
        # (start, stop, dtype, lod_level) per slot group, in emission order.
        slot_groups = [
            # sparse_query_feat_names
            (1, len_sparse_query + 1, 'int64', 1),
            # sparse_url_feat_names
            (len_sparse_query + 1, len_sparse_query + 5, 'int64', 1),
            # dense_feat_names
            (len_sparse_query + 5, len_sparse_query + 16, 'float32', None),
            # context_feat_names
            (len_sparse_query + 16, len_sparse_query + 18, 'float32', None),
            # neg sparse_url_feat_names
            (len_sparse_query + 18, len_sparse_query + 22, 'int64', 1),
            # neg dense_feat_names
            (len_sparse_query + 22, len_sparse_query + 33, 'float32', None),
            # neg context_feat_names
            (len_sparse_query + 33, len_sparse_query + 35, 'float32', None),
        ]
        for start, stop, dtype, lod_level in slot_groups:
            self._append_slot_vars(slot_data, start, stop, dtype, lod_level)

        dataset = paddle.distributed.InMemoryDataset()

        # Case 1 (mode 0): a well-formed generator -- the check should pass.
        print("========================================")
        generator_class = CTRDataset(mode=0)
        try:
            dataset._check_use_var_with_data_generator(slot_data,
                                                       generator_class,
                                                       dump_a_path)
            print("case 1: check passed!")
        except Exception as e:
            print("warning: catch expected error")
            print(e)
        print("========================================")
        print("\n")

        # Cases 2..5 (modes 2..5): malformed generators -- each check is
        # expected to raise; the error is printed, not re-raised.
        for mode in (2, 3, 4, 5):
            print("========================================")
            generator_class = CTRDataset(mode=mode)
            try:
                dataset._check_use_var_with_data_generator(slot_data,
                                                           generator_class,
                                                           dump_a_path)
            except Exception as e:
                print("warning: case %d catch expected error" % mode)
                print(e)
            print("========================================")
            if mode != 5:
                print("\n")

        temp_dir.cleanup()


if __name__ == '__main__':
    unittest.main()