未验证 提交 a6a2618e 编写于 作者: P parap1uie-s 提交者: GitHub

Fix hAPI bug of not compatible with LayerHook (#47001)

* Fix hAPI bug of not compatible with LayerHook
上级 eb32746a
......@@ -715,11 +715,9 @@ class DynamicGraphAdapter(object):
**self._amp_custom_lists,
level=self._amp_level):
if self._nranks > 1:
outputs = self.ddp_model.forward(
*[to_variable(x) for x in inputs])
outputs = self.ddp_model(*[to_variable(x) for x in inputs])
else:
outputs = self.model.network.forward(
*[to_variable(x) for x in inputs])
outputs = self.model.network(*[to_variable(x) for x in inputs])
losses = self.model._loss(*(to_list(outputs) + labels))
losses = to_list(losses)
......@@ -754,7 +752,7 @@ class DynamicGraphAdapter(object):
labels = labels or []
labels = [to_variable(l) for l in to_list(labels)]
outputs = self.model.network.forward(*[to_variable(x) for x in inputs])
outputs = self.model.network(*[to_variable(x) for x in inputs])
# Transfrom data to expected device
expected_device = paddle.device.get_device()
......@@ -809,7 +807,7 @@ class DynamicGraphAdapter(object):
self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)]
self._input_info = _update_input_info(inputs)
outputs = self.model.network.forward(*inputs)
outputs = self.model.network(*inputs)
if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册