Unverified commit bd559a24, authored by ShenLiang, committed by GitHub

fix bug of sync_parameters (#33955)

Parent 9254183d
@@ -17,7 +17,10 @@ import six
 import numpy as np
 import warnings
 from collections import OrderedDict
+import itertools
+import warnings
+import paddle
 from paddle.fluid import core
 from paddle.fluid import framework
 from paddle.fluid.dygraph import layers
@@ -26,9 +29,7 @@ from paddle.fluid.dygraph import to_variable, no_grad
 from paddle.utils import deprecated
 from ..layers import collective
 from paddle.fluid.dygraph import base as imperative_base
-import warnings
-import paddle
-import itertools
+from paddle.fluid.framework import ParamBase

 __all__ = ["prepare_context", "ParallelEnv", "DataParallel"]

@@ -353,8 +354,9 @@ def sync_params_buffers(model,
             raise TypeError("The data type of '%s' must be Varbase" %
                             param.name)

         # is_distributed param not need to sync when in mp mode
-        if is_model_parallel and param.is_distributed:
-            continue
+        if is_model_parallel and isinstance(param, ParamBase):
+            if param.is_distributed:
+                continue

         model_vars.append(param.detach())
     if len(model_vars) == 0:
...