未验证 提交 bd559a24 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of sync_parameters (#33955)

上级 9254183d
......@@ -17,7 +17,10 @@ import six
import numpy as np
import warnings
from collections import OrderedDict
import itertools
import warnings
import paddle
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
......@@ -26,9 +29,7 @@ from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
from ..layers import collective
from paddle.fluid.dygraph import base as imperative_base
import warnings
import paddle
import itertools
from paddle.fluid.framework import ParamBase
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
......@@ -353,8 +354,9 @@ def sync_params_buffers(model,
raise TypeError("The data type of '%s' must be Varbase" %
param.name)
# is_distributed param not need to sync when in mp mode
if is_model_parallel and param.is_distributed:
continue
if is_model_parallel and isinstance(param, ParamBase):
if param.is_distributed:
continue
model_vars.append(param.detach())
if len(model_vars) == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册