From 1460648a77b16c15883b5fe87b559fd6f0647bdc Mon Sep 17 00:00:00 2001
From: chengduo <30176695+chengduoZH@users.noreply.github.com>
Date: Mon, 26 Aug 2019 13:54:33 +0800
Subject: [PATCH] update parallel.py (#19371)

test=release/1.5
---
 python/paddle/fluid/dygraph/parallel.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py
index c17cfc73de7..e5f57ac7cc4 100644
--- a/python/paddle/fluid/dygraph/parallel.py
+++ b/python/paddle/fluid/dygraph/parallel.py
@@ -188,16 +188,14 @@ class DataParallel(layers.Layer):
         from ..layers import nn
         for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
             grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
-            splited_vars = nn.split(
-                coalesced_grad, num_or_sections=grad_var_len, dim=0)
-            reshaped_grad_vars = []
-            for g_var, g_shape in zip(splited_vars, grad_shapes):
-                reshaped_grad_vars.append(
-                    nn.reshape(
-                        x=g_var, shape=g_shape, inplace=True))
-            for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
-                                                    reshaped_grad_vars):
-                nn.assign(input=reshaped_g_var, output=origin_g_var)
+            self._helper.main_program.current_block().append_op(
+                type='split',
+                inputs={'X': coalesced_grad},
+                outputs={'Out': origin_grad_vars},
+                attrs={'sections': grad_var_len,
+                       'axis': 0})
+            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
+                nn.reshape(x=g_var, shape=g_shape, inplace=True)
 
     def apply_collective_grads(self):
         """
-- 
GitLab
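
The hunk above replaces the nn.split / nn.reshape / nn.assign chain with a single appended 'split' op whose outputs are the original gradient variables, followed by in-place reshapes, so no intermediate splited_vars or per-variable assign copies are created. For reference, here is a minimal NumPy sketch (not Paddle code) of the underlying coalesce/split-back pattern this method implements: gradients are flattened and concatenated into one buffer for the collective all-reduce, then split by per-variable length and reshaped back to their original shapes. The helper names coalesce and split_back are illustrative only, not part of the Paddle API.

    import numpy as np

    def coalesce(grads):
        # Flatten each gradient and concatenate into one contiguous buffer.
        shapes = [g.shape for g in grads]
        lengths = [int(np.prod(s)) for s in shapes]
        buf = np.concatenate([g.reshape(-1) for g in grads])
        return buf, shapes, lengths

    def split_back(buf, shapes, lengths):
        # Split the buffer by per-variable length, then restore original shapes.
        offsets = np.cumsum(lengths)[:-1]
        pieces = np.split(buf, offsets)
        return [p.reshape(s) for p, s in zip(pieces, shapes)]

    # Usage: round-trip two gradients through the coalesced buffer.
    g0 = np.random.rand(2, 3).astype(np.float32)
    g1 = np.random.rand(4).astype(np.float32)
    buf, shapes, lengths = coalesce([g0, g1])
    restored = split_back(buf, shapes, lengths)
    assert all(np.array_equal(a, b) for a, b in zip([g0, g1], restored))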