未验证 提交 9b65d4ce 编写于 作者: L lilong12 提交者: GitHub

bug fix, test=develop (#32752)

上级 97a95526
......@@ -147,7 +147,7 @@ class HybridCommunicateGroup(object):
debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
"mp_degree: %d, pp_degree: %d" % (self.global_rank, self._dp_degree,
self._mp_degree,self._pp_degree)
- debug_str += "dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % (
+ debug_str += ", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % (
self._dp_group, self._mp_group, self._pp_group, self._check_group)
logger.info(debug_str)
......
......@@ -136,9 +136,9 @@ class PipelineParallel(MetaParallelBase):
self._recv_activations(cache_id)
if isinstance(self.caches['inputs'][cache_id], tuple):
- inputs = tuple(t.clone() for t in self.caches['inputs'][cache_id])
+ inputs = tuple(t for t in self.caches['inputs'][cache_id])
else:
- inputs = self.caches['inputs'][cache_id].clone()
+ inputs = self.caches['inputs'][cache_id]
self._clear_grads(inputs)
outputs = self._layers.forward(inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册