未验证 提交 21d7b786 编写于 作者: C chengduo 提交者: GitHub

update parallel_helper (#17691)

test=develop
上级 7a401da5
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from ..layers import collective from ..layers import collective
from ..framework import Parameter
__parallel_ctx__clz__ = None __parallel_ctx__clz__ = None
...@@ -39,5 +39,5 @@ def _init_parallel_ctx(): ...@@ -39,5 +39,5 @@ def _init_parallel_ctx():
def _broadcast_parameters(parameters): def _broadcast_parameters(parameters):
for param in parameters: for param in parameters:
if param.trainable: if isinstance(param, Parameter) and param.trainable:
collective._broadcast(param, 0, sync_mode=True) collective._broadcast(param, 0, sync_mode=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册