Unverified · commit 1460648a · authored by chengduo · committed by GitHub

update parallel.py (#19371)

test=release/1.5
Parent 6fbd224e
@@ -188,16 +188,14 @@ class DataParallel(layers.Layer):
         from ..layers import nn
         for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
             grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
-            splited_vars = nn.split(
-                coalesced_grad, num_or_sections=grad_var_len, dim=0)
-            reshaped_grad_vars = []
-            for g_var, g_shape in zip(splited_vars, grad_shapes):
-                reshaped_grad_vars.append(
-                    nn.reshape(
-                        x=g_var, shape=g_shape, inplace=True))
-            for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
-                                                    reshaped_grad_vars):
-                nn.assign(input=reshaped_g_var, output=origin_g_var)
+            self._helper.main_program.current_block().append_op(
+                type='split',
+                inputs={'X': coalesced_grad},
+                outputs={'Out': origin_grad_vars},
+                attrs={'sections': grad_var_len,
+                       'axis': 0})
+            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
+                nn.reshape(x=g_var, shape=g_shape, inplace=True)
 
     def apply_collective_grads(self):
         """
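The new code replaces the `nn.split` → `nn.reshape` → `nn.assign` chain with a single `split` op whose `Out` outputs are the original gradient variables, followed by an in-place `nn.reshape`, so the split results land directly in `origin_grad_vars` without intermediate variables or extra assign copies. Below is a minimal NumPy sketch (not PaddlePaddle API) of the coalesce/split round-trip this block performs on gradients; all names in it are illustrative, and the coalescing step that produces `coalesced_grad` is not shown in this hunk.

```python
# Minimal NumPy sketch of the gradient coalesce/split round-trip:
# per-parameter gradients are flattened into one buffer (so a single
# collective allreduce can run on it), then split back by section
# lengths and reshaped to their original shapes. Names are illustrative.
import numpy as np

grads = [np.random.rand(3, 4), np.random.rand(5), np.random.rand(2, 2)]
grad_shapes = [g.shape for g in grads]

# Coalesce: flatten and concatenate into one contiguous buffer.
coalesced = np.concatenate([g.ravel() for g in grads])

# ... a single collective allreduce would operate on `coalesced` here ...

# Split: one split by per-gradient element counts, then reshape each
# piece back to its original shape (the role of the new 'split' op plus
# nn.reshape(..., inplace=True) in the diff above).
sections = [int(np.prod(s)) for s in grad_shapes]
offsets = np.cumsum(sections)[:-1]
pieces = np.split(coalesced, offsets)
restored = [p.reshape(s) for p, s in zip(pieces, grad_shapes)]

for orig, back in zip(grads, restored):
    assert np.array_equal(orig, back)  # round-trip preserves every gradient
```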