diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 2be062962ec9d33493bbcfc9a6c06239f2d4dee7..a905e1dba846754de487b360d54c75a988596a14 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -17,7 +17,10 @@ import six import numpy as np import warnings from collections import OrderedDict +import itertools +import warnings +import paddle from paddle.fluid import core from paddle.fluid import framework from paddle.fluid.dygraph import layers @@ -26,9 +29,7 @@ from paddle.fluid.dygraph import to_variable, no_grad from paddle.utils import deprecated from ..layers import collective from paddle.fluid.dygraph import base as imperative_base -import warnings -import paddle -import itertools +from paddle.fluid.framework import ParamBase __all__ = ["prepare_context", "ParallelEnv", "DataParallel"] @@ -353,8 +354,9 @@ def sync_params_buffers(model, raise TypeError("The data type of '%s' must be Varbase" % param.name) # is_distributed param not need to sync when in mp mode - if is_model_parallel and param.is_distributed: - continue + if is_model_parallel and isinstance(param, ParamBase): + if param.is_distributed: + continue model_vars.append(param.detach()) if len(model_vars) == 0: