diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 194e4bd66755557a75b10cfb907878f0985e7597..0f09440e4337c969376126b9fcad64b3be0a756f 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -24,12 +24,6 @@ from paddle.distributed.fleet.utils.log_util import logger from paddle.fluid.framework import _global_flags from paddle.fluid.wrapped_decorator import wrap_decorator -protobuf_version = google.protobuf.__version__ -if protobuf_version >= "4.21.0": - from google._upb import _message -else: - from google.protobuf.pyext import _message - __all__ = [] non_auto_func_called = True @@ -2512,10 +2506,19 @@ class DistributedStrategy: self.strategy, f.name + "_configs" ) config_fields = my_configs.DESCRIPTOR.fields + protobuf_version = google.protobuf.__version__ + if protobuf_version >= "4.21.0": + RepeatedScalarContainer = ( + google._upb._message.RepeatedScalarContainer + ) + else: + RepeatedScalarContainer = ( + google.protobuf.pyext._message.RepeatedScalarContainer + ) for ff in config_fields: if isinstance( getattr(my_configs, ff.name), - _message.RepeatedScalarContainer, + RepeatedScalarContainer, ): values = getattr(my_configs, ff.name) for i, v in enumerate(values):