未验证 提交 3ba8629c 编写于 作者: W wangguanzhong 提交者: GitHub

fix fleet train (#3528)

上级 e1e02efe
...@@ -256,8 +256,7 @@ class Trainer(object): ...@@ -256,8 +256,7 @@ class Trainer(object):
model = self.model model = self.model
if self.cfg.get('fleet', False): if self.cfg.get('fleet', False):
model = fleet.distributed_model(model) model = fleet.distributed_model(model)
self.optimizer = fleet.distributed_optimizer( self.optimizer = fleet.distributed_optimizer(self.optimizer)
self.optimizer).user_defined_optimizer
elif self._nranks > 1: elif self._nranks > 1:
find_unused_parameters = self.cfg[ find_unused_parameters = self.cfg[
'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册