Unverified · Commit d3b98f0d · authored by LielinJiang, committed by GitHub

Fix dynamic parallel train mode for hapi (#27787)

* fix dynamic parallel for hapi

* fix code style
Parent b53970ee
@@ -638,19 +638,14 @@ class DynamicGraphAdapter(object):
         if self._nranks > 1:
             outputs = self.ddp_model.forward(* [to_variable(x) for x in inputs])
-            losses = self.model._loss(*(to_list(outputs) + labels))
-            losses = to_list(losses)
-            final_loss = fluid.layers.sum(losses)
-            final_loss = self.ddp_model.scale_loss(final_loss)
-            final_loss.backward()
-            self.ddp_model.apply_collective_grads()
         else:
             outputs = self.model.network.forward(
                 * [to_variable(x) for x in inputs])
-            losses = self.model._loss(*(to_list(outputs) + labels))
-            losses = to_list(losses)
-            final_loss = fluid.layers.sum(losses)
-            final_loss.backward()
+
+        losses = self.model._loss(*(to_list(outputs) + labels))
+        losses = to_list(losses)
+        final_loss = fluid.layers.sum(losses)
+        final_loss.backward()
+
         self.model._optimizer.minimize(final_loss)
         self.model.network.clear_gradients()
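The shape of this change can be summarized without Paddle: before the fix, the loss computation and backward pass were duplicated in both branches, and the multi-card branch additionally called `scale_loss` / `apply_collective_grads` by hand; after the fix, only the forward call depends on the number of ranks, and one shared loss/backward path follows. Below is a minimal, Paddle-free sketch of that refactored control flow — `train_batch_refactored` and its stub callables are hypothetical names for illustration, not Paddle APIs.

```python
def train_batch_refactored(nranks, ddp_forward, local_forward, loss_fn):
    """Sketch of the post-fix control flow: branch only on the forward pass.

    nranks        -- number of parallel trainers (hypothetical parameter)
    ddp_forward   -- stub for the DataParallel-wrapped forward call
    local_forward -- stub for the plain single-card forward call
    loss_fn       -- stub for the shared loss + backward path
    """
    # Only the forward pass differs between single-card and multi-card mode.
    if nranks > 1:
        outputs = ddp_forward()
    else:
        outputs = local_forward()
    # The loss path is shared; gradient scaling and the collective gradient
    # all-reduce are assumed to happen inside the DataParallel wrapper itself,
    # which is why the manual scale_loss/apply_collective_grads calls go away.
    return loss_fn(outputs)
```

The design point is that keeping a single loss/backward path removes the risk of the two branches drifting apart, which is exactly the kind of bug this commit fixes.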