From ff2a7b31fd3802bfabae7b3d00bcf2262ebddd40 Mon Sep 17 00:00:00 2001 From: Jiangxinz Date: Tue, 29 Jun 2021 15:02:55 +0800 Subject: [PATCH] fix undef var (#33825) --- python/paddle/distributed/collective.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index cdad59cabf1..5256749c940 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -19,6 +19,7 @@ from ..fluid.framework import Variable from ..fluid.framework import OpProtoHolder from ..fluid.framework import in_dygraph_mode from ..fluid.framework import convert_np_dtype_to_dtype_ +from ..fluid.framework import _varbase_creator from ..fluid.data_feeder import convert_dtype from ..fluid.data_feeder import check_variable_and_dtype from ..fluid.data_feeder import check_type @@ -31,6 +32,7 @@ import paddle from .fleet import fleet import paddle.fluid as fluid import paddle.fluid.core as core +import paddle.fluid.dygraph_utils as dygraph_utils __all__ = [] @@ -158,7 +160,7 @@ def get_group(id=0): """ gm = _get_group_map() - return gm[group] if group in gm else None + return gm[id] if id in gm else None def barrier(group=None): @@ -462,7 +464,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True): tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id) else: raise ValueError("Unknown parameter: {}.".format(op)) - return out check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], -- GitLab