提交 58a74dc4 编写于 作者: C chengmo

add reader debug

上级 07d791e8
......@@ -52,6 +52,16 @@ class TranspileTrainer(Trainer):
reader = dataloader_instance.dataloader(
reader_class, state, self._config_yaml)
debug_mode = envs.get_global_env("debug_mode", False, namespace)
if debug_mode:
print("--- DataLoader Debug Mode Begin , show pre 10 data ---")
for idx, line in enumerate(reader):
print(line)
if idx >= 9:
break
print("--- DataLoader Debug Mode End , show pre 10 data ---")
exit 0
reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
......@@ -98,6 +108,16 @@ class TranspileTrainer(Trainer):
]
dataset.set_filelist(file_list)
debug_mode = envs.get_global_env("debug_mode", False, namespace)
if debug_mode:
print(
"--- Dataset Debug Mode Begin , show pre 10 data of {}---".format(file_list[0]))
os.system("cat {} | {} | head -10".format(file_list[0], pipe_cmd))
print(
"--- Dataset Debug Mode End , show pre 10 data of {}---".format(file_list[0]))
exit 0
return dataset
def save(self, epoch_id, namespace, is_fleet=False):
......@@ -120,19 +140,24 @@ class TranspileTrainer(Trainer):
# print("save inference model is not supported now.")
# return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
feed_varnames = envs.get_global_env(
"save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env(
"save.inference.fetch_varnames", None, namespace)
if feed_varnames is None or fetch_varnames is None:
return
fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None, namespace)
fetch_vars = [fluid.default_main_program().global_block().vars[varname]
for varname in fetch_varnames]
dirname = envs.get_global_env(
"save.inference.dirname", None, namespace)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_inference_model(self._exe, dirname, feed_varnames, fetch_vars)
fleet.save_inference_model(
self._exe, dirname, feed_varnames, fetch_vars)
else:
fluid.io.save_inference_model(
dirname, feed_varnames, fetch_vars, self._exe)
......
......@@ -24,6 +24,7 @@ train:
batch_size: 2
class: "{workspace}/../criteo_reader.py"
train_data_path: "{workspace}/data/train"
debug_mode: False
model:
models: "{workspace}/model.py"
......
......@@ -21,10 +21,10 @@
<img align="center" src="doc/imgs/structure.png">
<p>
- PaddleRec是源于飞桨生态的搜索推荐模型一站式开箱即用工具,无论您是初学者,开发者,研究者均可便捷的使用PaddleRec完成调研,训练到预测部署的全流程工作。
- PaddleRec提供了搜索推荐任务中语义理解、召回、粗排、精排、多任务学习的全流程解决方案,包含的算法模型均在百度各个业务的实际场景中得到了验证。
- PaddleRec将各个模型及其训练预测流程规范化整理,进行易用性封装,用户只需自定义yaml文件即可快速上手使用。
- PaddleRec以飞桨深度学习框架为核心,融合了大规模分布式训练框架Fleet,以及一键式推理部署框架PaddleServing,支持推荐搜索算法的工业化应用。
- 源于飞桨生态的`搜索推荐模型`**一站式开箱即用工具**
- 适合初学者,开发者,研究者的调研,训练到预测部署的全流程解决方案
- 包含语义理解、召回、粗排、精排、多任务学习、融合等多个任务的推荐搜索算法库
- 自定义`yaml`即可快速上手使用单机训练、大规模分布式训练、离线预测、在线部署
<h2 align="center">PadlleRec概览</h2>
......@@ -37,7 +37,7 @@
<h2 align="center">便捷安装</h2>
### 环境要求
* Python >= 2.7
* Python 2.7/ 3.5 / 3.6 / 3.7
* PaddlePaddle >= 1.7.2
* 操作系统: Windows/Mac/Linux
......@@ -101,24 +101,24 @@ python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e cluster
> 部分表格占位待改(大规模稀疏)
| 方向 | 模型 | 单机CPU训练 | 单机GPU训练 | 分布式CPU训练 | 大规模稀疏 | 分布式GPU训练 | 自定义数据集 |
| :------: | :----------------------------------------------------------------------------: | :---------: | :---------: | :-----------: | :--------: | :-----------: | :----------: |
| 内容理解 | [Text-Classifcation](models/contentunderstanding/text_classification/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 内容理解 | [TagSpace](models/contentunderstanding/tagspace/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [Word2Vec](models/recall/word2vec/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [TDM](models/recall/tdm/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [SSR](models/recall/ssr/model.py) | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 召回 | [Gru4Rec](models/recall/gru4rec/model.py) | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 排序 | [CTR-Dnn](models/rank/dnn/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [DeepFm](models/rank/deepfm/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [xDeepFm](models/rank/xdeepfm/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [DIN](models/rank/din/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [Wide&Deep](models/rank/wide_deep/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 多任务 | [ESMM](models/multitask/essm/model.py) | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 多任务 | [MMOE](models/multitask/mmoe/model.py) | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 排序 | [ShareBottom](models/multitask/share-bottom/model.py) | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 匹配 | [DSSM](models/match/dssm/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 匹配 | [Simnet](models/match/multiview-simnet/model.py) | ✓ | x | ✓ | x | ✓ | ✓ |
| 方向 | 模型 | 单机CPU训练 | 单机GPU训练 | 分布式CPU训练 | 分布式GPU训练 |
| :------: | :----------------------------------------------------------------------------: | :---------: | :---------: | :-----------: | :-----------: |
| 内容理解 | [Text-Classifcation](models/contentunderstanding/text_classification/model.py) | ✓ | x | ✓ | x |
| 内容理解 | [TagSpace](models/contentunderstanding/tagspace/model.py) | ✓ | x | ✓ | x |
| 召回 | [Word2Vec](models/recall/word2vec/model.py) | ✓ | x | ✓ | x |
| 召回 | [TDM](models/recall/tdm/model.py) | ✓ | x | ✓ | x |
| 召回 | [SSR](models/recall/ssr/model.py) | ✓ | ✓ | ✓ | x |
| 召回 | [Gru4Rec](models/recall/gru4rec/model.py) | ✓ | ✓ | ✓ | x |
| 排序 | [CTR-Dnn](models/rank/dnn/model.py) | ✓ | x | ✓ | x |
| 排序 | [DeepFm](models/rank/deepfm/model.py) | ✓ | x | ✓ | x |
| 排序 | [xDeepFm](models/rank/xdeepfm/model.py) | ✓ | x | ✓ | x |
| 排序 | [DIN](models/rank/din/model.py) | ✓ | x | ✓ | x |
| 排序 | [Wide&Deep](models/rank/wide_deep/model.py) | ✓ | x | ✓ | x |
| 多任务 | [ESMM](models/multitask/essm/model.py) | ✓ | ✓ | ✓ | x |
| 多任务 | [MMOE](models/multitask/mmoe/model.py) | ✓ | ✓ | ✓ | x |
| 排序 | [ShareBottom](models/multitask/share-bottom/model.py) | ✓ | ✓ | ✓ | x |
| 匹配 | [DSSM](models/match/dssm/model.py) | ✓ | x | ✓ | x |
| 匹配 | [Simnet](models/match/multiview-simnet/model.py) | ✓ | x | ✓ | x |
......
......@@ -152,7 +152,8 @@ def cluster_engine(args):
cluster_envs["train.trainer.engine"] = "cluster"
cluster_envs["train.trainer.device"] = args.device
cluster_envs["train.trainer.platform"] = envs.get_platform()
print("launch {} engine with cluster to with model: {}".format(trainer, args.model))
print("launch {} engine with cluster to with model: {}".format(
trainer, args.model))
set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model)
......@@ -245,9 +246,11 @@ if __name__ == "__main__":
choices=["single", "local_cluster", "cluster",
"tdm_single", "tdm_local_cluster", "tdm_cluster"])
parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu")
parser.add_argument("-d", "--device", type=str,
choices=["cpu", "gpu"], default="cpu")
parser.add_argument("-b", "--backend", type=str, default=None)
parser.add_argument("-r", "--role", type=str, choices=["master", "worker"], default="master")
parser.add_argument("-r", "--role", type=str,
choices=["master", "worker"], default="master")
abs_dir = os.path.dirname(os.path.abspath(__file__))
envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册