提交 33fc1750 编写于 作者: C chengmo

add reader debug mode

上级 013c12ff
...@@ -42,7 +42,7 @@ class TranspileTrainer(Trainer): ...@@ -42,7 +42,7 @@ class TranspileTrainer(Trainer):
namespace = "train.reader" namespace = "train.reader"
class_name = "TrainReader" class_name = "TrainReader"
else: else:
dataloader = self.model._infer_data_loader readerdataloader = self.model._infer_data_loader
namespace = "evaluate.reader" namespace = "evaluate.reader"
class_name = "EvaluateReader" class_name = "EvaluateReader"
...@@ -52,22 +52,22 @@ class TranspileTrainer(Trainer): ...@@ -52,22 +52,22 @@ class TranspileTrainer(Trainer):
reader = dataloader_instance.dataloader( reader = dataloader_instance.dataloader(
reader_class, state, self._config_yaml) reader_class, state, self._config_yaml)
debug_mode = envs.get_global_env("debug_mode", False, namespace) reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader)
else:
dataloader.set_sample_generator(reader, batch_size)
debug_mode = envs.get_global_env("reader_debug_mode", False, namespace)
if debug_mode: if debug_mode:
print("--- DataLoader Debug Mode Begin , show pre 10 data ---") print("--- DataLoader Debug Mode Begin , show pre 10 data ---")
for idx, line in enumerate(reader): for idx, line in enumerate(reader()):
print(line) print(line)
if idx >= 9: if idx >= 9:
break break
print("--- DataLoader Debug Mode End , show pre 10 data ---") print("--- DataLoader Debug Mode End , show pre 10 data ---")
exit(0) exit(0)
reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader)
else:
dataloader.set_sample_generator(reader, batch_size)
return dataloader return dataloader
def _get_dataset(self, state="TRAIN"): def _get_dataset(self, state="TRAIN"):
...@@ -109,7 +109,7 @@ class TranspileTrainer(Trainer): ...@@ -109,7 +109,7 @@ class TranspileTrainer(Trainer):
dataset.set_filelist(file_list) dataset.set_filelist(file_list)
debug_mode = envs.get_global_env("debug_mode", False, namespace) debug_mode = envs.get_global_env("reader_debug_mode", False, namespace)
if debug_mode: if debug_mode:
print( print(
"--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0])) "--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0]))
......
...@@ -24,7 +24,7 @@ train: ...@@ -24,7 +24,7 @@ train:
batch_size: 2 batch_size: 2
class: "{workspace}/../criteo_reader.py" class: "{workspace}/../criteo_reader.py"
train_data_path: "{workspace}/data/train" train_data_path: "{workspace}/data/train"
debug_mode: False reader_debug_mode: False
model: model:
models: "{workspace}/model.py" models: "{workspace}/model.py"
......
...@@ -25,6 +25,7 @@ train: ...@@ -25,6 +25,7 @@ train:
class: "{workspace}/tdm_reader.py" class: "{workspace}/tdm_reader.py"
train_data_path: "{workspace}/data/train" train_data_path: "{workspace}/data/train"
test_data_path: "{workspace}/data/test" test_data_path: "{workspace}/data/test"
reader_debug_mode: False
model: model:
models: "{workspace}/model.py" models: "{workspace}/model.py"
......
...@@ -33,7 +33,8 @@ class EvaluateReader(Reader): ...@@ -33,7 +33,8 @@ class EvaluateReader(Reader):
This function needs to be implemented by the user, based on data format This function needs to be implemented by the user, based on data format
""" """
features = (line.strip('\n')).split('\t') features = (line.strip('\n')).split('\t')
input_emb = map(float, features[0].split(' ')) input_emb = features[0].split(' ')
input_emb = [float(i) for i in input_emb]
feature_name = ["input_emb"] feature_name = ["input_emb"]
yield zip(feature_name, [input_emb]) yield zip(feature_name, [input_emb])
......
...@@ -33,7 +33,8 @@ class TrainReader(Reader): ...@@ -33,7 +33,8 @@ class TrainReader(Reader):
This function needs to be implemented by the user, based on data format This function needs to be implemented by the user, based on data format
""" """
features = (line.strip('\n')).split('\t') features = (line.strip('\n')).split('\t')
input_emb = map(float, features[0].split(' ')) input_emb = features[0].split(' ')
input_emb = [float(i) for i in input_emb]
item_label = [int(features[1])] item_label = [int(features[1])]
feature_name = ["input_emb", "item_label"] feature_name = ["input_emb", "item_label"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册