提交 82571360 编写于 作者: D dongdaxiang

fix bug

test=develop
上级 a659b37a
......@@ -78,7 +78,7 @@ class MultiTrainer(TrainerDesc):
def _gen_trainer_desc(self):
super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self._device_worker._set_infer(self.infer_)
self._device_worker._set_infer(self._infer)
self._device_worker._gen_worker_desc(self.proto_desc)
......@@ -96,6 +96,6 @@ class DistMultiTrainer(TrainerDesc):
self.proto_desc.class_name = "DistMultiTrainer"
if self._program == None:
raise RuntimeError("None Program")
self._device_worker._set_infer(self.infer_)
self._device_worker._set_program(self.program_)
self._device_worker._set_infer(self._infer)
self._device_worker._set_program(self._program)
self._device_worker._gen_worker_desc(self.proto_desc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册