diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index cdad59cabf11a53a106485b2ff20e63d176374ac..5256749c9405eeb72d44966cff1833027744b1d4 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -19,6 +19,7 @@ from ..fluid.framework import Variable from ..fluid.framework import OpProtoHolder from ..fluid.framework import in_dygraph_mode from ..fluid.framework import convert_np_dtype_to_dtype_ +from ..fluid.framework import _varbase_creator from ..fluid.data_feeder import convert_dtype from ..fluid.data_feeder import check_variable_and_dtype from ..fluid.data_feeder import check_type @@ -31,6 +32,7 @@ import paddle from .fleet import fleet import paddle.fluid as fluid import paddle.fluid.core as core +import paddle.fluid.dygraph_utils as dygraph_utils __all__ = [] @@ -158,7 +160,7 @@ def get_group(id=0): """ gm = _get_group_map() - return gm[group] if group in gm else None + return gm[id] if id in gm else None def barrier(group=None): @@ -462,7 +464,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True): tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id) else: raise ValueError("Unknown parameter: {}.".format(op)) - return out check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],