diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index c17cfc73de7b5767f842701aba62cf9b29ecd156..e5f57ac7cc4c7414567f91be19a900e088c60633 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -188,16 +188,14 @@ class DataParallel(layers.Layer): from ..layers import nn for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars: grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes] - splited_vars = nn.split( - coalesced_grad, num_or_sections=grad_var_len, dim=0) - reshaped_grad_vars = [] - for g_var, g_shape in zip(splited_vars, grad_shapes): - reshaped_grad_vars.append( - nn.reshape( - x=g_var, shape=g_shape, inplace=True)) - for origin_g_var, reshaped_g_var in zip(origin_grad_vars, - reshaped_grad_vars): - nn.assign(input=reshaped_g_var, output=origin_g_var) + self._helper.main_program.current_block().append_op( + type='split', + inputs={'X': coalesced_grad}, + outputs={'Out': origin_grad_vars}, + attrs={'sections': grad_var_len, + 'axis': 0}) + for g_var, g_shape in zip(origin_grad_vars, grad_shapes): + nn.reshape(x=g_var, shape=g_shape, inplace=True) def apply_collective_grads(self): """