# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ TestCases for Dataset consistency insepection of use_var_list and data_generator. """ from __future__ import print_function import paddle import paddle.fluid as fluid import paddle.compat as cpt import paddle.fluid.core as core import numpy as np import random import math import os import shutil import tempfile import unittest import paddle.fluid.incubate.data_generator as dg #paddle.enable_static() # fluid.disable_dygraph() fluid.disable_dygraph() url_schema_len = 5 query_schema = [ 'Q_query_basic', 'Q_query_phrase', 'Q_quq', 'Q_timelevel', 'Q_context_title_basic1', 'Q_context_title_basic2', 'Q_context_title_basic3', 'Q_context_title_basic4', 'Q_context_title_basic5', 'Q_context_title_phrase1', 'Q_context_title_phrase2', 'Q_context_title_phrase3', 'Q_context_title_phrase4', 'Q_context_title_phrase5', 'Q_context_site1', 'Q_context_site2', 'Q_context_site3', 'Q_context_site4', 'Q_context_site5' ] class CTRDataset(dg.MultiSlotDataGenerator): def __init__(self, mode): self.test = mode def generate_sample(self, line): def reader(): ins = line.strip().split(';') label_pos_num = int(ins[1].split(' ')[0]) label_neg_num = int(ins[1].split(' ')[1]) #query fea parse bias = 2 query_len = 0 sparse_query_feature = [] for index in range(len(query_schema)): pos = index + bias sparse_query_feature.append( [int(x) for x in ins[pos].split(' ')]) if index == 0: query_len = len(ins[pos].split(' ')) query_len = 1.0 / (1 + pow(2.7182818, 3 - 1.0 * query_len)) #positive url fea parse bias = 2 + len(query_schema) pos_url_feas = [] pos_click_feas = [] pos_context_feas = [] for k in range(label_pos_num): pos_url_fea = [] pos = 0 for index in range(url_schema_len - 1): pos = bias + k * (url_schema_len) + index pos_url_fea.append([int(x) for x in ins[pos].split(' ')]) #click info if (ins[pos + 1] == ''): continue item = ins[pos + 1].split(' ') if len(item) != 17: continue stat_fea = [[max(float(item[i]), 0.0)] for i in range(len(item)) \ if not (i == 5 or i == 9 or i == 13 or i == 14 or i ==15 or i ==16)] pos_url_feas.append(pos_url_fea) pos_click_feas.append(stat_fea) query_serach = float(item[5]) if query_serach > 0.0: query_serach = min(math.log(query_serach), 10.0) / 10.0 pos_context_fea = [[query_serach], [query_len]] pos_context_feas.append(pos_context_fea) #negative url fea parse bias = 2 + len(query_schema) + label_pos_num * (url_schema_len) neg_url_feas = [] neg_click_feas = [] neg_context_feas = [] for k in range(label_neg_num): neg_url_fea = [] pos = 0 for index in range(url_schema_len - 1): pos = bias + k * (url_schema_len) + index neg_url_fea.append([int(x) for x in ins[pos].split(' ')]) if (ins[pos + 1] == ''): continue item = ins[pos + 1].split(' ') #zdf_tmp if len(item) != 17: continue #print ins[pos + 1] stat_fea = [[max(float(item[i]), 0.0)] for i in range(len(item)) \ if not (i == 5 or i == 9 or i == 13 or i == 14 or i == 15 or i == 16)] neg_click_feas.append(stat_fea) neg_url_feas.append(neg_url_fea) query_serach = float(item[5]) if query_serach > 0.0: query_serach = min(math.log(query_serach), 10.0) / 10.0 neg_context_fea = [[query_serach], [query_len]] neg_context_feas.append(neg_context_fea) #make train data if self.test == 1: for p in range(len(pos_url_feas)): # feature_name = ["click"] + query_schema + url_schema[:4] + click_info_schema[:11] + context_schema[:2] feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) pos_url_fea = pos_url_feas[p] pos_click_fea = pos_click_feas[p] pos_context_fea = pos_context_feas[p] yield zip(feature_name, [[1]] + sparse_query_feature + pos_url_fea + pos_click_fea + pos_context_fea + pos_url_fea + pos_click_fea + pos_context_fea) for n in range(len(neg_url_feas)): feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) neg_url_fea = neg_url_feas[n] neg_click_fea = neg_click_feas[n] neg_context_fea = neg_context_feas[n] yield zip(feature_name, [[0]] + sparse_query_feature + neg_url_fea + neg_click_fea + neg_context_fea + neg_url_fea + neg_click_fea + neg_context_fea) elif self.test == 0: for p in range(len(pos_url_feas)): #feature_name = ["click"] + query_schema + url_schema[:4] + click_info_schema[:11] + context_schema[:2] + url_schema[4:] + click_info_schema[11:] + context_schema[2:] feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) #print("#######") #print(feature_name) #print("#######") pos_url_fea = pos_url_feas[p] pos_click_fea = pos_click_feas[p] pos_context_fea = pos_context_feas[p] for n in range(len(neg_url_feas)): # prob = get_rand() # if prob < sample_rate: neg_url_fea = neg_url_feas[n] neg_click_fea = neg_click_feas[n] neg_context_fea = neg_context_feas[n] #print("q:", query_feas) #print("pos:", pos_url_fea) #print("neg:", neg_url_fea) # yield zip(feature_name[:3], sparse_query_feature[:3]) yield list(zip(feature_name, [[1]] + sparse_query_feature + pos_url_fea + pos_click_fea + pos_context_fea + \ neg_url_fea + neg_click_fea + neg_context_fea)) elif self.test == 2: for p in range(len(pos_url_feas)): #feature_name = ["click"] + query_schema + url_schema[:4] + click_info_schema[:11] + context_schema[:2] + url_schema[4:] + click_info_schema[11:] + context_schema[2:] feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) #print("#######") #print(feature_name) #print("#######") pos_url_fea = pos_url_feas[p] pos_click_fea = pos_click_feas[p] pos_context_fea = pos_context_feas[p] for n in range(len(neg_url_feas)): # prob = get_rand() # if prob < sample_rate: neg_url_fea = neg_url_feas[n] neg_click_fea = neg_click_feas[n] neg_context_fea = neg_context_feas[n] #print("q:", query_feas) #print("pos:", pos_url_fea) #print("neg:", neg_url_fea) # yield zip(feature_name[:3], sparse_query_feature[:3]) yield list(zip(feature_name, [[1], [2]] + sparse_query_feature + pos_url_fea + pos_click_fea + pos_context_fea + \ neg_url_fea + neg_click_fea + neg_context_fea)) elif self.test == 3: for p in range(len(pos_url_feas)): #feature_name = ["click"] + query_schema + url_schema[:4] + click_info_schema[:11] + context_schema[:2] + url_schema[4:] + click_info_schema[11:] + context_schema[2:] feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) #print("#######") #print(feature_name) #print("#######") pos_url_fea = pos_url_feas[p] pos_click_fea = pos_click_feas[p] pos_context_fea = pos_context_feas[p] for n in range(len(neg_url_feas)): # prob = get_rand() # if prob < sample_rate: neg_url_fea = neg_url_feas[n] neg_click_fea = neg_click_feas[n] neg_context_fea = neg_context_feas[n] #print("q:", query_feas) #print("pos:", pos_url_fea) #print("neg:", neg_url_fea) # yield zip(feature_name[:3], sparse_query_feature[:3]) yield list(zip(feature_name, [[1], [2.0]] + sparse_query_feature + pos_url_fea + pos_click_fea + pos_context_fea + \ neg_url_fea + neg_click_fea + neg_context_fea)) elif self.test == 4: for p in range(len(pos_url_feas)): #feature_name = ["click"] + query_schema + url_schema[:4] + click_info_schema[:11] + context_schema[:2] + url_schema[4:] + click_info_schema[11:] + context_schema[2:] feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) #print("#######") #print(feature_name) #print("#######") pos_url_fea = pos_url_feas[p] pos_click_fea = pos_click_feas[p] pos_context_fea = pos_context_feas[p] for n in range(len(neg_url_feas)): # prob = get_rand() # if prob < sample_rate: neg_url_fea = neg_url_feas[n] neg_click_fea = neg_click_feas[n] neg_context_fea = neg_context_feas[n] #print("q:", query_feas) #print("pos:", pos_url_fea) #print("neg:", neg_url_fea) # yield zip(feature_name[:3], sparse_query_feature[:3]) yield list(zip(feature_name, [[], [2.0]] + sparse_query_feature + pos_url_fea + pos_click_fea + pos_context_fea + \ neg_url_fea + neg_click_fea + neg_context_fea)) elif self.test == 5: for p in range(len(pos_url_feas)): #feature_name = ["click"] + query_schema + url_schema[:4] + click_info_schema[:11] + context_schema[:2] + url_schema[4:] + click_info_schema[11:] + context_schema[2:] feature_name = ["click"] for i in range(1, 54): feature_name.append(str(i)) #print("#######") #print(feature_name) #print("#######") pos_url_fea = pos_url_feas[p] pos_click_fea = pos_click_feas[p] pos_context_fea = pos_context_feas[p] for n in range(len(neg_url_feas)): # prob = get_rand() # if prob < sample_rate: neg_url_fea = neg_url_feas[n] neg_click_fea = neg_click_feas[n] neg_context_fea = neg_context_feas[n] #print("q:", query_feas) #print("pos:", pos_url_fea) #print("neg:", neg_url_fea) # yield zip(feature_name[:3], sparse_query_feature[:3]) yield list(zip(feature_name, sparse_query_feature + pos_url_fea + pos_click_fea + pos_context_fea + \ neg_url_fea + neg_click_fea + neg_context_fea)) return reader class TestDataset(unittest.TestCase): """ TestCases for Dataset. """ def setUp(self): pass # use_data_loader = False # epoch_num = 10 # drop_last = False def test_var_consistency_insepection(self): """ Testcase for InMemoryDataset of consistency insepection of use_var_list and data_generator. """ temp_dir = tempfile.TemporaryDirectory() dump_a_path = os.path.join(temp_dir.name, 'test_run_with_dump_a.txt') with open(dump_a_path, "w") as f: # data = "\n" # data += "\n" data = "2 1;1 9;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;0;40000001;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000200;10000200;10063938;10000008;10000177;20002001 20001240 20001860 20003611 20010833 20000210 20000500 20000401 20000251 20012198 20001023 20000157;20002001 20001240 20001860 20003611 20012396 20000500 20002513 20012198 20001023 20000157;10000123;30000004;0.623 0.233 0.290 0.208 0.354 49.000 0.000 0.000 0.000 -1.000 0.569 0.679 0.733 53 17 2 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000047;30000004;0.067 0.000 0.161 0.005 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.378 0.043 0 6 0 0;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;10000200;30000001;0.407 0.111 0.196 0.095 0.181 49.000 0.000 0.000 0.000 -1.000 0.306 0.538 0.355 48 8 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;10000200;30000001;0.226 0.029 0.149 0.031 0.074 49.000 0.000 0.000 0.000 -1.000 0.220 0.531 0.286 26 6 0 0;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;10063938;30000001;0.250 0.019 0.138 0.012 0.027 49.000 0.000 0.000 0.000 -1.000 0.370 0.449 0.327 7 2 0 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000003;30000002;0.056 0.000 0.139 0.003 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.346 0.059 15 3 0 0;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;10000008;30000001;0.166 0.004 0.127 0.001 0.004 49.000 0.000 0.000 0.000 -1.000 0.103 0.417 0.394 10 3 0 0;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000177;30000001;0.094 0.008 0.157 0.012 0.059 49.000 0.000 0.000 0.000 -1.000 0.051 0.382 0.142 21 0 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;10000134;30000001;0.220 0.016 0.181 0.037 0.098 49.000 0.000 0.000 0.000 -1.000 0.192 0.453 0.199 17 1 0 0;20002001 20001240 20001860 20003611 20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000638;30000001;0.000 0.000 0.000 0.000 0.000 49.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0 0 0 0;\n" data += "2 1;1 11;20000025 20000404;20001923;20000002 20000157 20000028 20004205 20000500 20028809 20000571 20000007 20027523 20004940 20000651 20000043 20000051 20000520 20015398 20000066 20004720 20000070 20001648;40000001;20000025 20000404 20000571 20004940 20000001 20000017;20000025 20000404 20000029 20000500 20001408 20000404 20000001 20000017;0;0;0;20001923 20011130 20000027;20001923 20000029 20000500 20001408 20000404 20000027;0;0;0;10000005;10000005;0;0;0;20003316 20000392 20001979 20000474 20000025 20000194 20000025 20000404 20000019 20000109;20016528 20024913 20004748 20001923 20000019 20000109;10000015;30000002;0.572 0.043 0.401 0.352 0.562 32859.000 0.005 0.060 0.362 -1.000 0.448 0.673 0.222 16316 991 89 0;20000025 20000404 20000571 20004940 20000001 20000017;20001923 20011130 20000027;10000005;30000001;0.495 0.024 0.344 0.285 0.379 32859.000 0.002 0.050 0.362 -1.000 0.423 0.764 0.254 19929 896 72 0;20000202 20000026 20001314 20004289 20000025 20000404 20000451 20000089 20000007;20000202 20000026 20014094 20001314 20004289 20001923 20000451 20000089 20000007;10000035;30000003;0.133 0.006 0.162 0.042 0.174 32859.000 0.003 0.037 0.362 -1.000 0.363 0.542 0.122 14763 664 53 0;20000202 20000026 20001314 20004289 20000025 20000404;20000202 20000026 20014094 20001314 20004289 20001923;10000021;30000001;0.058 0.004 0.133 0.017 0.120 32859.000 0.000 0.006 0.362 -1.000 0.168 0.437 0.041 -1 -1 -1 -1;20000025 20000404 20000018 20012461 20001699 20000446 20000174 20000062 20000133 20003172 20000240 20007877 20067375 20000111 20000164 20001410 20000204 20016958;20001923 20000018 20012461 20001699 20007717 20000062 20000133 20003172 20000240 20007877 20067375 20000111 20000164 20001410 20000204 20016958;10000002;30000001;0.017 0.000 0.099 0.004 0.072 32859.000 0.000 0.009 0.362 -1.000 0.058 0.393 0.025 -1 -1 -1 -1;20000025 20000404;20001923;10000133;30000005;0.004 0.000 0.122 0.000 0.000 32859.000 0.000 0.000 0.362 -1.000 0.000 0.413 0.020 0 444 35 0;20000025 20000404;20001923;10005297;30000004;0.028 0.000 0.138 0.002 0.000 32859.000 0.000 0.000 0.362 -1.000 0.000 0.343 0.024 0 600 48 0;20000025 20000404;20001923;10000060;30000005;0.107 0.000 0.110 0.027 0.077 32859.000 0.000 0.005 0.362 -1.000 0.095 0.398 0.062 1338 491 39 0;20002960 20005534 20000043 20000025 20000404 20000025 20000007;20002960 20005534 20000043 20001923 20000025 20000007;10000020;30000003;0.041 0.000 0.122 0.012 0.101 32859.000 0.001 0.025 0.362 -1.000 0.302 0.541 0.065 9896 402 35 0;20000025 20000404 20000259 20000228 20000235 20000142;20001923 20000259 20000264 20000142;10000024;30000003;0.072 0.002 0.156 0.026 0.141 32859.000 0.002 0.032 0.362 -1.000 0.386 0.569 0.103 9896 364 35 0;20000025 20000404 20000029 20000500 20001408 20000404 20000001 20000017;20001923 20000029 20000500 20001408 20000404 20000027;10000005;30000001;0.328 0.006 0.179 0.125 0.181 32859.000 0.003 0.058 0.362 -1.000 0.300 0.445 0.141 9896 402 32 0;20000025 20000404;20001923;10012839;30000002;0.012 0.000 0.108 0.002 0.048 32859.000 0.000 0.000 0.362 -1.000 0.021 0.225 0.016 2207 120 12 0;\n" # data += "" f.write(data) slot_data = [] label = fluid.layers.data(name="click", shape=[-1, 1], dtype="int64", lod_level=0, append_batch_size=False) slot_data.append(label) # sprase_query_feat_names len_sparse_query = 19 for feat_name in range(1, len_sparse_query + 1): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='int64', lod_level=1)) # sparse_url_feat_names for feat_name in range(len_sparse_query + 1, len_sparse_query + 5): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='int64', lod_level=1)) # dense_feat_names for feat_name in range(len_sparse_query + 5, len_sparse_query + 16): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='float32')) # context_feat_namess for feat_name in range(len_sparse_query + 16, len_sparse_query + 18): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='float32')) # neg sparse_url_feat_names for feat_name in range(len_sparse_query + 18, len_sparse_query + 22): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='int64', lod_level=1)) # neg dense_feat_names for feat_name in range(len_sparse_query + 22, len_sparse_query + 33): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='float32')) # neg context_feat_namess for feat_name in range(len_sparse_query + 33, len_sparse_query + 35): slot_data.append( fluid.layers.data(name=str(feat_name), shape=[1], dtype='float32')) dataset = paddle.distributed.InMemoryDataset() print("========================================") generator_class = CTRDataset(mode=0) try: dataset._check_use_var_with_data_generator(slot_data, generator_class, dump_a_path) print("case 1: check passed!") except Exception as e: print("warning: catch expected error") print(e) print("========================================") print("\n") print("========================================") generator_class = CTRDataset(mode=2) try: dataset._check_use_var_with_data_generator(slot_data, generator_class, dump_a_path) except Exception as e: print("warning: case 2 catch expected error") print(e) print("========================================") print("\n") print("========================================") generator_class = CTRDataset(mode=3) try: dataset._check_use_var_with_data_generator(slot_data, generator_class, dump_a_path) except Exception as e: print("warning: case 3 catch expected error") print(e) print("========================================") print("\n") print("========================================") generator_class = CTRDataset(mode=4) try: dataset._check_use_var_with_data_generator(slot_data, generator_class, dump_a_path) except Exception as e: print("warning: case 4 catch expected error") print(e) print("========================================") print("\n") print("========================================") generator_class = CTRDataset(mode=5) try: dataset._check_use_var_with_data_generator(slot_data, generator_class, dump_a_path) except Exception as e: print("warning: case 5 catch expected error") print(e) print("========================================") temp_dir.cleanup() if __name__ == '__main__': unittest.main()