diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index 1904e27bdf146894f3c51c32c055e222ebadb99f..3dc8bfd6e04fb3da2b217f84126fac8101102945 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -42,7 +42,7 @@ class TranspileTrainer(Trainer): namespace = "train.reader" class_name = "TrainReader" else: - dataloader = self.model._infer_data_loader + readerdataloader = self.model._infer_data_loader namespace = "evaluate.reader" class_name = "EvaluateReader" @@ -52,22 +52,22 @@ class TranspileTrainer(Trainer): reader = dataloader_instance.dataloader( reader_class, state, self._config_yaml) - debug_mode = envs.get_global_env("debug_mode", False, namespace) + reader_class = envs.lazy_instance_by_fliename(reader_class, class_name) + reader_ins = reader_class(self._config_yaml) + if hasattr(reader_ins, 'generate_batch_from_trainfiles'): + dataloader.set_sample_list_generator(reader) + else: + dataloader.set_sample_generator(reader, batch_size) + + debug_mode = envs.get_global_env("reader_debug_mode", False, namespace) if debug_mode: print("--- DataLoader Debug Mode Begin , show pre 10 data ---") - for idx, line in enumerate(reader): + for idx, line in enumerate(reader()): print(line) if idx >= 9: break print("--- DataLoader Debug Mode End , show pre 10 data ---") exit(0) - - reader_class = envs.lazy_instance_by_fliename(reader_class, class_name) - reader_ins = reader_class(self._config_yaml) - if hasattr(reader_ins, 'generate_batch_from_trainfiles'): - dataloader.set_sample_list_generator(reader) - else: - dataloader.set_sample_generator(reader, batch_size) return dataloader def _get_dataset(self, state="TRAIN"): @@ -109,7 +109,7 @@ class TranspileTrainer(Trainer): dataset.set_filelist(file_list) - debug_mode = envs.get_global_env("debug_mode", False, namespace) + debug_mode = envs.get_global_env("reader_debug_mode", False, namespace) if debug_mode: print( "--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0])) diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml index 27eb639190cf98ffe275d0dd49514346ceae11b0..adda027fb23e97bc83b8a2548317d903dc7f6e7b 100755 --- a/models/rank/dnn/config.yaml +++ b/models/rank/dnn/config.yaml @@ -24,7 +24,7 @@ train: batch_size: 2 class: "{workspace}/../criteo_reader.py" train_data_path: "{workspace}/data/train" - debug_mode: False + reader_debug_mode: False model: models: "{workspace}/model.py" diff --git a/models/recall/tdm/config.yaml b/models/recall/tdm/config.yaml index 80f0678bca1c8d8277e20fa6f999dc434bbd97d4..2b2ec9f9ce83b728a1b3ce18fb0e01a6a06f4e11 100755 --- a/models/recall/tdm/config.yaml +++ b/models/recall/tdm/config.yaml @@ -25,6 +25,7 @@ train: class: "{workspace}/tdm_reader.py" train_data_path: "{workspace}/data/train" test_data_path: "{workspace}/data/test" + reader_debug_mode: False model: models: "{workspace}/model.py" diff --git a/models/recall/tdm/tdm_evaluate_reader.py b/models/recall/tdm/tdm_evaluate_reader.py index cb32426685c16a2917c61c4be4d928892cb5c238..844e441fbda303ea4a5ab3c0f549711579dbf5d5 100644 --- a/models/recall/tdm/tdm_evaluate_reader.py +++ b/models/recall/tdm/tdm_evaluate_reader.py @@ -33,7 +33,8 @@ class EvaluateReader(Reader): This function needs to be implemented by the user, based on data format """ features = (line.strip('\n')).split('\t') - input_emb = map(float, features[0].split(' ')) + input_emb = features[0].split(' ') + input_emb = [float(i) for i in input_emb] feature_name = ["input_emb"] yield zip(feature_name, [input_emb]) diff --git a/models/recall/tdm/tdm_reader.py b/models/recall/tdm/tdm_reader.py index 5a24fbb40b350f843e37876de7707b69c07f5555..0b8ada9ea4d695aafd38c1e87831c9939e483618 100755 --- a/models/recall/tdm/tdm_reader.py +++ b/models/recall/tdm/tdm_reader.py @@ -33,7 +33,8 @@ class TrainReader(Reader): This function needs to be implemented by the user, based on data format """ features = (line.strip('\n')).split('\t') - input_emb = map(float, features[0].split(' ')) + input_emb = features[0].split(' ') + input_emb = [float(i) for i in input_emb] item_label = [int(features[1])] feature_name = ["input_emb", "item_label"]