提交 f8d0d84f 编写于 作者: Q qiaolongfei

fix multi card

上级 5d305070
...@@ -220,7 +220,10 @@ def _callback_lookup_(op): ...@@ -220,7 +220,10 @@ def _callback_lookup_(op):
:return: callback function :return: callback function
""" """
if op.type == 'parallel_do' and op.attr('use_nccl'): if op.type == 'parallel_do' and op.attr('use_nccl'):
all_vars = op.block.vars
param_names = set(op.input('parameters')) param_names = set(op.input('parameters'))
param_names = filter(lambda name: all_vars[name].stop_gradient is False,
param_names)
param_grad_names = [n + "@GRAD" for n in param_names] param_grad_names = [n + "@GRAD" for n in param_names]
class ParallelDoCallBack(object): class ParallelDoCallBack(object):
......
...@@ -294,8 +294,7 @@ class ParallelDo(object): ...@@ -294,8 +294,7 @@ class ParallelDo(object):
params = list(set(params)) params = list(set(params))
param_list = [parent_block.var(name) for name in params] return [parent_block.var(name) for name in params]
return filter(lambda param: param.stop_gradient is False, param_list)
def complete_op(self): def complete_op(self):
main_program = self.helper.main_program main_program = self.helper.main_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册