#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test cases for the Dataset consistency inspection between use_var_list and data_generator.
"""

from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.compat as cpt
import paddle.fluid.core as core
import numpy as np
import random
import math
import os
import shutil
import tempfile
import unittest
import paddle.fluid.incubate.data_generator as dg

# Switch fluid out of dygraph (imperative) mode before the tests below declare
# their input variables with fluid.layers.data.
fluid.disable_dygraph()

# Number of ';'-separated fields in one URL record: (url_schema_len - 1)
# sparse URL fields followed by a single click-info field.
url_schema_len = 5

# Names of the sparse query slots, in the order their fields appear in each
# input line (immediately after the two header fields): four basic query
# slots, then five numbered slots for each context family.
query_schema = (
    ['Q_query_basic', 'Q_query_phrase', 'Q_quq', 'Q_timelevel'] +
    ['Q_context_title_basic%d' % idx for idx in range(1, 6)] +
    ['Q_context_title_phrase%d' % idx for idx in range(1, 6)] +
    ['Q_context_site%d' % idx for idx in range(1, 6)])


class CTRDataset(dg.MultiSlotDataGenerator):
    """Parse ';'-separated CTR sample lines into named slot records.

    A line is split on ';'. Field 1 holds "<num_pos> <num_neg>", the next
    ``len(query_schema)`` fields hold sparse query features, and they are
    followed by ``num_pos`` positive and then ``num_neg`` negative URL records
    of ``url_schema_len`` fields each (see ``_parse_url_records``).

    ``mode`` (stored as ``self.test``) selects what ``generate_sample`` yields:

    * 1: one record per positive URL (label 1) and per negative URL
      (label 0); the URL features fill both halves of the record.
    * 0: one pairwise record per (positive, negative) combination, label 1.
    * 2, 3, 4, 5: like 0, but with a malformed label prefix -- an extra int
      slot, an extra float slot, an empty label, or no label at all -- so the
      records do not line up with the 54 declared slot names.

    Any other mode yields nothing.
    """

    # Number of space-separated values in a usable click-info field.
    _CLICK_INFO_LEN = 17
    # Click-info columns left out of the dense click features. Column 5 is
    # consumed separately as the query-search context feature.
    _SKIPPED_CLICK_COLUMNS = frozenset((5, 9, 13, 14, 15, 16))
    # Slot names: the label slot "click" followed by "1" .. "53".
    _FEATURE_NAMES = ["click"] + [str(i) for i in range(1, 54)]
    # Label prefix prepended to every pairwise record, keyed by mode.
    # Only mode 0 is well-formed; modes 2-5 are deliberately inconsistent.
    _PAIRWISE_LABELS = {
        0: [[1]],
        2: [[1], [2]],
        3: [[1], [2.0]],
        4: [[], [2.0]],
        5: [],
    }

    def __init__(self, mode):
        # Run the base-class initialiser so the generator's own state exists.
        super(CTRDataset, self).__init__()
        self.test = mode

    @classmethod
    def _parse_url_records(cls, ins, bias, count, query_len):
        """Parse ``count`` consecutive URL records starting at ``ins[bias]``.

        Each record is ``url_schema_len`` fields: ``url_schema_len - 1`` sparse
        URL fields (space-separated ints) followed by one click-info field.
        Records whose click-info field is empty or does not contain exactly
        ``_CLICK_INFO_LEN`` values are skipped.

        Args:
            ins: the ';'-split fields of one input line.
            bias: index of the first field of the first record.
            count: number of records to read.
            query_len: query-length feature stored in every context feature.

        Returns:
            Parallel lists ``(url_feas, click_feas, context_feas)`` with one
            entry per kept record.
        """
        url_feas, click_feas, context_feas = [], [], []
        for k in range(count):
            start = bias + k * url_schema_len
            # The sparse URL fields are parsed before the record is checked,
            # so malformed ints raise even if the record would be skipped.
            url_fea = [[int(x) for x in ins[start + i].split(' ')]
                       for i in range(url_schema_len - 1)]
            click_info = ins[start + url_schema_len - 1]
            if click_info == '':
                continue
            item = click_info.split(' ')
            if len(item) != cls._CLICK_INFO_LEN:
                continue
            click_fea = [[max(float(value), 0.0)]
                         for col, value in enumerate(item)
                         if col not in cls._SKIPPED_CLICK_COLUMNS]
            # Column 5 becomes a log-scaled context feature capped at 1.0;
            # non-positive values are passed through unchanged.
            query_search = float(item[5])
            if query_search > 0.0:
                query_search = min(math.log(query_search), 10.0) / 10.0
            url_feas.append(url_fea)
            click_feas.append(click_fea)
            context_feas.append([[query_search], [query_len]])
        return url_feas, click_feas, context_feas

    def generate_sample(self, line):
        """Return a generator function yielding the records parsed from ``line``.

        Parsing is deferred until the returned ``reader`` is iterated. Each
        yielded record is a list of ``(slot_name, values)`` pairs; ``zip``
        truncates it to the shorter of the slot-name list and the value list.
        """

        def reader():
            ins = line.strip().split(';')
            label_counts = ins[1].split(' ')
            label_pos_num = int(label_counts[0])
            label_neg_num = int(label_counts[1])

            # Query features start right after the two header fields.
            bias = 2
            sparse_query_feature = [
                [int(x) for x in ins[bias + i].split(' ')]
                for i in range(len(query_schema))
            ]
            query_len = 0
            if query_schema:
                # Logistic squash of the first query field's token count.
                num_terms = len(ins[bias].split(' '))
                query_len = 1.0 / (1 + pow(2.7182818, 3 - 1.0 * num_terms))

            pos_bias = bias + len(query_schema)
            pos_url_feas, pos_click_feas, pos_context_feas = \
                self._parse_url_records(ins, pos_bias, label_pos_num,
                                        query_len)
            neg_bias = pos_bias + label_pos_num * url_schema_len
            neg_url_feas, neg_click_feas, neg_context_feas = \
                self._parse_url_records(ins, neg_bias, label_neg_num,
                                        query_len)

            names = self._FEATURE_NAMES
            positives = list(
                zip(pos_url_feas, pos_click_feas, pos_context_feas))
            negatives = list(
                zip(neg_url_feas, neg_click_feas, neg_context_feas))

            if self.test == 1:
                # Pointwise records with the URL features in both halves.
                # Materialised with list() like the pairwise modes, so every
                # record supports len() and can be iterated more than once.
                for url, click, ctx in positives:
                    yield list(
                        zip(names, [[1]] + sparse_query_feature + url +
                            click + ctx + url + click + ctx))
                for url, click, ctx in negatives:
                    yield list(
                        zip(names, [[0]] + sparse_query_feature + url +
                            click + ctx + url + click + ctx))
            elif self.test in self._PAIRWISE_LABELS:
                label = self._PAIRWISE_LABELS[self.test]
                for pos_url, pos_click, pos_ctx in positives:
                    pos_part = pos_url + pos_click + pos_ctx
                    for neg_url, neg_click, neg_ctx in negatives:
                        # Copy the label lists so yielded records never share
                        # them with the class-level table.
                        values = ([list(v) for v in label] +
                                  sparse_query_feature + pos_part + neg_url +
                                  neg_click + neg_ctx)
                        yield list(zip(names, values))

        return reader


class TestDataset(unittest.TestCase):
    """TestCases for Dataset."""

    # Two input lines in the format parsed by CTRDataset: the first has 1
    # positive and 9 negative URL records, the second 1 positive and 11.
    _SAMPLE_DATA = (
        "2 1;1 9;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;0;40000001;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000200;10000200;10063938;10000008;10000177;20002001 20001240 20001860 20003611 20010833 20000210 20000500 20000401 20000251 20012198 20001023 20000157;20002001 20001240 20001860 20003611 20012396 20000500 20002513 20012198 20001023 20000157;10000123;30000004;0.623 0.233 0.290 0.208 0.354 49.000 0.000 0.000 0.000 -1.000 0.569 0.679 0.733 53 17 2 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000047;30000004;0.067 0.000 0.161 0.005 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.378 0.043 0 6 0 0;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;10000200;30000001;0.407 0.111 0.196 0.095 0.181 49.000 0.000 0.000 0.000 -1.000 0.306 0.538 0.355 48 8 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 "
        "20003519 20000005;10000200;30000001;0.226 0.029 0.149 0.031 0.074 49.000 0.000 0.000 0.000 -1.000 0.220 0.531 0.286 26 6 0 0;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;10063938;30000001;0.250 0.019 0.138 0.012 0.027 49.000 0.000 0.000 0.000 -1.000 0.370 0.449 0.327 7 2 0 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000003;30000002;0.056 0.000 0.139 0.003 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.346 0.059 15 3 0 0;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;10000008;30000001;0.166 0.004 0.127 0.001 0.004 49.000 0.000 0.000 0.000 -1.000 0.103 0.417 0.394 10 3 0 0;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000177;30000001;0.094 0.008 0.157 0.012 0.059 49.000 0.000 0.000 0.000 -1.000 0.051 0.382 0.142 21 0 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;10000134;30000001;0.220 0.016 0.181 0.037 0.098 49.000 0.000 0.000 0.000 -1.000 0.192 0.453 0.199 17 1 0 0;20002001 20001240 20001860 20003611 20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000638;30000001;0.000 0.000 0.000 0.000 0.000 49.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0 0 0 0;\n"
        "2 1;1 11;20000025 20000404;20001923;20000002 20000157 20000028 20004205 20000500 20028809 20000571 20000007 20027523 20004940 20000651 20000043 20000051 20000520 20015398 20000066 20004720 20000070 20001648;40000001;20000025 20000404 20000571 20004940 20000001 20000017;20000025 20000404 20000029 20000500 20001408 20000404 20000001 20000017;0;0;0;20001923 20011130 20000027;20001923 20000029 20000500 20001408 20000404 20000027;0;0;0;10000005;10000005;0;0;0;20003316 20000392 20001979 20000474 20000025 20000194 20000025 20000404 20000019 20000109;20016528 20024913 20004748 20001923 20000019 20000109;10000015;30000002;0.572 0.043 0.401 0.352 0.562 32859.000 0.005 0.060 0.362 -1.000 0.448 0.673 0.222 16316 991 89 0;20000025 20000404 20000571 20004940 20000001 20000017;20001923 20011130 20000027;10000005;30000001;0.495 0.024 0.344 0.285 0.379 32859.000 0.002 0.050 0.362 -1.000 0.423 0.764 0.254 19929 896 72 0;20000202 20000026 20001314 20004289 20000025 20000404 20000451 20000089 20000007;20000202 20000026 20014094 20001314 20004289 20001923 20000451 20000089 20000007;10000035;30000003;0.133 0.006 0.162 0.042 0.174 32859.000 0.003 0.037 0.362 -1.000 0.363 0.542 0.122 14763 664 53 0;20000202 20000026 20001314 20004289 20000025 20000404;20000202 20000026 20014094 20001314 20004289 20001923;10000021;30000001;0.058 0.004 0.133 0.017 0.120 32859.000 0.000 0.006 0.362 -1.000 0.168 0.437 0.041 -1 -1 -1 -1;20000025 20000404 20000018 20012461 20001699 20000446 20000174 20000062 20000133 20003172 20000240 20007877 20067375 20000111 20000164 20001410 20000204 20016958;20001923 20000018 20012461 20001699 20007717 20000062 20000133 20003172 20000240 20007877 20067375 20000111 20000164 20001410 20000204 20016958;10000002;30000001;0.017 0.000 0.099 0.004 0.072 32859.000 0.000 0.009 0.362 -1.000 0.058 0.393 0.025 -1 -1 -1 -1;20000025 20000404;20001923;10000133;30000005;0.004 0.000 0.122 0.000 0.000 32859.000 0.000 0.000 0.362 -1.000 0.000 0.413 0.020 0 444 35 "
        "0;20000025 20000404;20001923;10005297;30000004;0.028 0.000 0.138 0.002 0.000 32859.000 0.000 0.000 0.362 -1.000 0.000 0.343 0.024 0 600 48 0;20000025 20000404;20001923;10000060;30000005;0.107 0.000 0.110 0.027 0.077 32859.000 0.000 0.005 0.362 -1.000 0.095 0.398 0.062 1338 491 39 0;20002960 20005534 20000043 20000025 20000404 20000025 20000007;20002960 20005534 20000043 20001923 20000025 20000007;10000020;30000003;0.041 0.000 0.122 0.012 0.101 32859.000 0.001 0.025 0.362 -1.000 0.302 0.541 0.065 9896 402 35 0;20000025 20000404 20000259 20000228 20000235 20000142;20001923 20000259 20000264 20000142;10000024;30000003;0.072 0.002 0.156 0.026 0.141 32859.000 0.002 0.032 0.362 -1.000 0.386 0.569 0.103 9896 364 35 0;20000025 20000404 20000029 20000500 20001408 20000404 20000001 20000017;20001923 20000029 20000500 20001408 20000404 20000027;10000005;30000001;0.328 0.006 0.179 0.125 0.181 32859.000 0.003 0.058 0.362 -1.000 0.300 0.445 0.141 9896 402 32 0;20000025 20000404;20001923;10012839;30000002;0.012 0.000 0.108 0.002 0.048 32859.000 0.000 0.000 0.362 -1.000 0.021 0.225 0.016 2207 120 12 0;\n"
    )

    def _build_slot_vars(self):
        """Declare the 54 input variables matching CTRDataset's record layout.

        Order: the "click" label, 19 sparse query slots, then for the positive
        URL and the negative URL in turn: 4 sparse URL slots, 11 dense click
        slots and 2 dense context slots. Names are "1" .. "53" after "click".
        """
        len_sparse_query = 19

        def sparse(idx):
            return fluid.layers.data(name=str(idx),
                                     shape=[1],
                                     dtype='int64',
                                     lod_level=1)

        def dense(idx):
            return fluid.layers.data(name=str(idx),
                                     shape=[1],
                                     dtype='float32')

        slot_data = [
            fluid.layers.data(name="click",
                              shape=[-1, 1],
                              dtype="int64",
                              lod_level=0,
                              append_batch_size=False)
        ]
        # sparse query slots followed by the positive sparse URL slots
        slot_data += [sparse(i) for i in range(1, len_sparse_query + 5)]
        # positive dense click slots followed by positive context slots
        slot_data += [
            dense(i)
            for i in range(len_sparse_query + 5, len_sparse_query + 18)
        ]
        # negative sparse URL slots
        slot_data += [
            sparse(i)
            for i in range(len_sparse_query + 18, len_sparse_query + 22)
        ]
        # negative dense click slots followed by negative context slots
        slot_data += [
            dense(i)
            for i in range(len_sparse_query + 22, len_sparse_query + 35)
        ]
        return slot_data

    def test_var_consistency_insepection(self):
        """
        Testcase for InMemoryDataset of consistency inspection of use_var_list and data_generator.

        CTRDataset mode 0 yields records that line up with the declared
        variables and must pass the check. Modes 2-5 yield records with a
        malformed label prefix and must be rejected.
        """
        # The context manager removes the directory even if a check fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            dump_a_path = os.path.join(temp_dir, 'test_run_with_dump_a.txt')
            with open(dump_a_path, "w") as f:
                f.write(self._SAMPLE_DATA)

            slot_data = self._build_slot_vars()
            dataset = paddle.distributed.InMemoryDataset()

            # case 1: consistent records must be accepted.
            dataset._check_use_var_with_data_generator(slot_data,
                                                       CTRDataset(mode=0),
                                                       dump_a_path)

            # cases 2-5: length or dtype mismatches must be reported.
            for mode in (2, 3, 4, 5):
                with self.subTest(mode=mode):
                    with self.assertRaises((ValueError, TypeError)):
                        dataset._check_use_var_with_data_generator(
                            slot_data, CTRDataset(mode=mode), dump_a_path)


if __name__ == "__main__":
    # Run the TestCase classes of this module when executed as a script.
    unittest.main()