提交 33fc1750 编写于 作者: C chengmo

Add reader debug mode (rename the `debug_mode` reader option to `reader_debug_mode`)

上级 013c12ff
......@@ -42,7 +42,7 @@ class TranspileTrainer(Trainer):
namespace = "train.reader"
class_name = "TrainReader"
else:
dataloader = self.model._infer_data_loader
readerdataloader = self.model._infer_data_loader
namespace = "evaluate.reader"
class_name = "EvaluateReader"
......@@ -52,22 +52,22 @@ class TranspileTrainer(Trainer):
reader = dataloader_instance.dataloader(
reader_class, state, self._config_yaml)
debug_mode = envs.get_global_env("debug_mode", False, namespace)
reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader)
else:
dataloader.set_sample_generator(reader, batch_size)
debug_mode = envs.get_global_env("reader_debug_mode", False, namespace)
if debug_mode:
print("--- DataLoader Debug Mode Begin , show pre 10 data ---")
for idx, line in enumerate(reader):
for idx, line in enumerate(reader()):
print(line)
if idx >= 9:
break
print("--- DataLoader Debug Mode End , show pre 10 data ---")
exit(0)
reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader)
else:
dataloader.set_sample_generator(reader, batch_size)
return dataloader
def _get_dataset(self, state="TRAIN"):
......@@ -109,7 +109,7 @@ class TranspileTrainer(Trainer):
dataset.set_filelist(file_list)
debug_mode = envs.get_global_env("debug_mode", False, namespace)
debug_mode = envs.get_global_env("reader_debug_mode", False, namespace)
if debug_mode:
print(
"--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0]))
......
......@@ -24,7 +24,7 @@ train:
batch_size: 2
class: "{workspace}/../criteo_reader.py"
train_data_path: "{workspace}/data/train"
debug_mode: False
reader_debug_mode: False
model:
models: "{workspace}/model.py"
......
......@@ -25,6 +25,7 @@ train:
class: "{workspace}/tdm_reader.py"
train_data_path: "{workspace}/data/train"
test_data_path: "{workspace}/data/test"
reader_debug_mode: False
model:
models: "{workspace}/model.py"
......
......@@ -33,7 +33,8 @@ class EvaluateReader(Reader):
This function needs to be implemented by the user, based on data format
"""
features = (line.strip('\n')).split('\t')
input_emb = map(float, features[0].split(' '))
input_emb = features[0].split(' ')
input_emb = [float(i) for i in input_emb]
feature_name = ["input_emb"]
yield zip(feature_name, [input_emb])
......
......@@ -33,7 +33,8 @@ class TrainReader(Reader):
This function needs to be implemented by the user, based on data format
"""
features = (line.strip('\n')).split('\t')
input_emb = map(float, features[0].split(' '))
input_emb = features[0].split(' ')
input_emb = [float(i) for i in input_emb]
item_label = [int(features[1])]
feature_name = ["input_emb", "item_label"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册