超大规模稀疏参数(DistributedLookupTable)的本地预测+增量使用
Created by: seiriosPlus
涉及到超大规模稀疏参数使用的模型,在保存参数的时候(save_inference_model/save_persistables)会和一般模型不太一样, 其中超大规模稀疏参数会被直接保存在PSERVER端(每个PSERVER端均有一份且只包含当前PSERVER上的参数),其他参数还是正常保存于TRAINER端。
涉及到的相关代码位于:
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/utils/lookup_table_utils.py
- 本地预测 本地预测说明: 本地预测是指,将带有超大规模稀疏参数的多机训练的参数在本地做合并,用一个节点来做预测的方式。如果拥有大内存(能够存放得下超大规模稀疏参数)的服务器,用本地预测的方式能够快速的对训练效果进行验证。
本地预测的一般步骤:
-
将训练过程中得到的模型参数(包括TRAINER端的参数和所有PSERVER端的参数)做合并。
-
组建inference的网络,需要包含加入Optimizer+minimize
-
加入Distributed Transpiler的信息,构建多机的网络
-
t.get_trainer_program(wait_port=False)获得TRAINER端的program
-
通过lookup_table_utils.convert_dist_to_sparse_program(main_program)将超大规模稀疏参数转换为本地的稀疏参数
-
加入输入数据的FEED和FETCH(参考save_inference_model),调用lookup_table_utils.get_inference_model(inference_program, feeds, fetch_targets)保存一个用于本地预测的网络
-
调用lookup_table_utils.load_persistables_for_inference(executor=exe, dirname=model_dir, program=inference_program, lookup_table_var_name='dis_emb'),进行本地的模型和参数加载,需要显示的指定超大规模稀疏参数的参数名
-
多机预测 略
-
多机增量训练 略