diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 8f38ba447fcb3d59a2c609dacff7c921f01935fd..470a4d83aac3fedb2135d44567fe31688894b093 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -147,7 +147,7 @@ class HybridCommunicateGroup(object): debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \ "mp_degree: %d, pp_degree: %d" % (self.global_rank, self._dp_degree, self._mp_degree,self._pp_degree) - debug_str += "dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % ( + debug_str += ", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % ( self._dp_group, self._mp_group, self._pp_group, self._check_group) logger.info(debug_str) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 11180054afbfc4730917600a8bd65ea99ec7c3ec..8fb29a4485df068a3c7c1623306c4706f2f45287 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -125,9 +125,9 @@ class PipelineParallel(MetaParallelBase): self._recv_activations(cache_id) if isinstance(self.caches['inputs'][cache_id], tuple): - inputs = tuple(t.clone() for t in self.caches['inputs'][cache_id]) + inputs = tuple(t for t in self.caches['inputs'][cache_id]) else: - inputs = self.caches['inputs'][cache_id].clone() + inputs = self.caches['inputs'][cache_id] self._clear_grads(inputs) outputs = self._layers.forward(inputs)