并行模式读取非并行训练模型参数的问题
Created by: mottled233
您好,我在训练时采用了单卡gpu训练,训练结束后调用 io.save_persistables保存参数。 在预测阶段读取模型时,我计划采用多卡并行,使用 CompiledProgram(self.predict_program). with_data_parallel 进行并行处理,但在读取参数时出现问题(读取使用 io.load_vars)。
我的程序中,如果先读取参数,之后CompiledProgram,会报错 Error: Cannot find fetched variable(qas_ids).(Perhaps the main_program is not set to ParallelExecutor) at (/paddle/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc:145)
如果在先CompiledProgram,后读取参数,会报错 TypeError: program's type should be Program
请问怎么解决?