Unverified commit 1460648a, authored by chengduo, committed via GitHub

update parallel.py (#19371)

test=release/1.5
Parent commit: 6fbd224e
......@@ -188,16 +188,14 @@ class DataParallel(layers.Layer):
from ..layers import nn
for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
splited_vars = nn.split(
coalesced_grad, num_or_sections=grad_var_len, dim=0)
reshaped_grad_vars = []
for g_var, g_shape in zip(splited_vars, grad_shapes):
reshaped_grad_vars.append(
nn.reshape(
x=g_var, shape=g_shape, inplace=True))
for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
reshaped_grad_vars):
nn.assign(input=reshaped_g_var, output=origin_g_var)
self._helper.main_program.current_block().append_op(
type='split',
inputs={'X': coalesced_grad},
outputs={'Out': origin_grad_vars},
attrs={'sections': grad_var_len,
'axis': 0})
for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
nn.reshape(x=g_var, shape=g_shape, inplace=True)
def apply_collective_grads(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment