未验证 提交 7f8eb5a2 编写于 作者: W wangguanzhong 提交者: GitHub

fix fleet train (#3527)

上级 51574ea1
...@@ -301,8 +301,7 @@ class Trainer(object): ...@@ -301,8 +301,7 @@ class Trainer(object):
model = self.model model = self.model
if self.cfg.get('fleet', False): if self.cfg.get('fleet', False):
model = fleet.distributed_model(model) model = fleet.distributed_model(model)
self.optimizer = fleet.distributed_optimizer( self.optimizer = fleet.distributed_optimizer(self.optimizer)
self.optimizer).user_defined_optimizer
elif self._nranks > 1: elif self._nranks > 1:
find_unused_parameters = self.cfg[ find_unused_parameters = self.cfg[
'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册