From 33fc1750e524d61bfa0e0805ef84c9254f5a285e Mon Sep 17 00:00:00 2001 From: chengmo Date: Wed, 13 May 2020 20:05:32 +0800 Subject: [PATCH] add reader debug mode --- core/trainers/transpiler_trainer.py | 22 +++++++++++----------- models/rank/dnn/config.yaml | 2 +- models/recall/tdm/config.yaml | 1 + models/recall/tdm/tdm_evaluate_reader.py | 3 ++- models/recall/tdm/tdm_reader.py | 3 ++- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index 1904e27b..3dc8bfd6 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -42,7 +42,7 @@ class TranspileTrainer(Trainer): namespace = "train.reader" class_name = "TrainReader" else: - dataloader = self.model._infer_data_loader + readerdataloader = self.model._infer_data_loader namespace = "evaluate.reader" class_name = "EvaluateReader" @@ -52,22 +52,22 @@ class TranspileTrainer(Trainer): reader = dataloader_instance.dataloader( reader_class, state, self._config_yaml) - debug_mode = envs.get_global_env("debug_mode", False, namespace) + reader_class = envs.lazy_instance_by_fliename(reader_class, class_name) + reader_ins = reader_class(self._config_yaml) + if hasattr(reader_ins, 'generate_batch_from_trainfiles'): + dataloader.set_sample_list_generator(reader) + else: + dataloader.set_sample_generator(reader, batch_size) + + debug_mode = envs.get_global_env("reader_debug_mode", False, namespace) if debug_mode: print("--- DataLoader Debug Mode Begin , show pre 10 data ---") - for idx, line in enumerate(reader): + for idx, line in enumerate(reader()): print(line) if idx >= 9: break print("--- DataLoader Debug Mode End , show pre 10 data ---") exit(0) - - reader_class = envs.lazy_instance_by_fliename(reader_class, class_name) - reader_ins = reader_class(self._config_yaml) - if hasattr(reader_ins, 'generate_batch_from_trainfiles'): - dataloader.set_sample_list_generator(reader) - else: - dataloader.set_sample_generator(reader, batch_size) return dataloader def _get_dataset(self, state="TRAIN"): @@ -109,7 +109,7 @@ class TranspileTrainer(Trainer): dataset.set_filelist(file_list) - debug_mode = envs.get_global_env("debug_mode", False, namespace) + debug_mode = envs.get_global_env("reader_debug_mode", False, namespace) if debug_mode: print( "--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0])) diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml index 27eb6391..adda027f 100755 --- a/models/rank/dnn/config.yaml +++ b/models/rank/dnn/config.yaml @@ -24,7 +24,7 @@ train: batch_size: 2 class: "{workspace}/../criteo_reader.py" train_data_path: "{workspace}/data/train" - debug_mode: False + reader_debug_mode: False model: models: "{workspace}/model.py" diff --git a/models/recall/tdm/config.yaml b/models/recall/tdm/config.yaml index 80f0678b..2b2ec9f9 100755 --- a/models/recall/tdm/config.yaml +++ b/models/recall/tdm/config.yaml @@ -25,6 +25,7 @@ train: class: "{workspace}/tdm_reader.py" train_data_path: "{workspace}/data/train" test_data_path: "{workspace}/data/test" + reader_debug_mode: False model: models: "{workspace}/model.py" diff --git a/models/recall/tdm/tdm_evaluate_reader.py b/models/recall/tdm/tdm_evaluate_reader.py index cb324266..844e441f 100644 --- a/models/recall/tdm/tdm_evaluate_reader.py +++ b/models/recall/tdm/tdm_evaluate_reader.py @@ -33,7 +33,8 @@ class EvaluateReader(Reader): This function needs to be implemented by the user, based on data format """ features = (line.strip('\n')).split('\t') - input_emb = map(float, features[0].split(' ')) + input_emb = features[0].split(' ') + input_emb = [float(i) for i in input_emb] feature_name = ["input_emb"] yield zip(feature_name, [input_emb]) diff --git a/models/recall/tdm/tdm_reader.py b/models/recall/tdm/tdm_reader.py index 5a24fbb4..0b8ada9e 100755 --- a/models/recall/tdm/tdm_reader.py +++ b/models/recall/tdm/tdm_reader.py @@ -33,7 +33,8 @@ class TrainReader(Reader): This function needs to be implemented by the user, based on data format """ features = (line.strip('\n')).split('\t') - input_emb = map(float, features[0].split(' ')) + input_emb = features[0].split(' ') + input_emb = [float(i) for i in input_emb] item_label = [int(features[1])] feature_name = ["input_emb", "item_label"] -- GitLab