From 5199c7446bdb1e079479b737a2a4f7703b417660 Mon Sep 17 00:00:00 2001
From: lilong12
Date: Wed, 8 Sep 2021 10:47:19 +0800
Subject: [PATCH] support weight sharing for pipeline (#35351)

* support weight sharing
---
 python/paddle/fluid/optimizer.py | 33 ++++++++++++++++++--------------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index fc5c30684b2..8b2af328f52 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -22,7 +22,7 @@ from collections import defaultdict
 
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
-from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
+from paddle.fluid.framework import Program, Variable, Parameter, name_scope, default_main_program, default_startup_program, device_guard
 
 from . import framework
 from . import layers
@@ -4234,14 +4234,14 @@ class PipelineOptimizer(object):
             self._device = "gpu"
         if framework.in_dygraph_mode():
             raise Exception("In dygraph, don't support PipelineOptimizer.")
-        if not isinstance(optimizer, Optimizer) and not isinstance(
-                optimizer, paddle.optimizer.Optimizer) and not isinstance(
-                    optimizer, paddle.fluid.contrib.mixed_precision.decorator.
-                    OptimizerWithMixedPrecision):
+        valid_optimizers = (Optimizer, paddle.optimizer.Optimizer,
+                            paddle.fluid.contrib.mixed_precision.decorator.
+                            OptimizerWithMixedPrecision)
+        if not isinstance(optimizer, valid_optimizers):
             raise ValueError("The 'optimizer' parameter for "
                              "PipelineOptimizer must be an instance of "
-                             "Optimizer, but the given type is {}.".format(
-                                 type(optimizer)))
+                             "{}, but the given type is {}.".format(
+                                 valid_optimizers, type(optimizer)))
         self._optimizer = optimizer
 
         # Get the original optimizer defined by users, such as SGD
@@ -4774,14 +4774,13 @@
                 # skip data var
                 if var.is_data: continue
                 prev_device = None
-                generate_ops = self.output_var_to_op.get(var_name)
-                if generate_ops is None:
+
+                prev_op = self._find_prev_op(index, var_name)
+                if prev_op is None:
                     if var_name not in self._param_device_map:
                         continue
                     prev_device = self._param_device_map[var_name]
 
-                prev_op = self._find_prev_op(index, var_name)
-
                 if not prev_device:
                     prev_device = prev_op.attr(self._op_device_key) \
                         if prev_op else None
@@ -4928,9 +4927,14 @@
                             self._op_role_key: op_role,
                         })
                     extra_index_info['index'] += 1
+                    prefix_name = var.name.split('@')[0]
+                    prefix_var = block.var(prefix_name)
+                    is_param = True if isinstance(prefix_var,
+                                                  Parameter) else False
                     block._insert_op_without_sync(
                         index=index + extra_index_info['index'],
-                        type='send_v2' if not use_mp else 'partial_send',
+                        type='send_v2'
+                        if not use_mp or is_param else 'partial_send',
                         inputs={'X': var},
                         attrs={
                             self._op_device_key: prev_dev,
@@ -4966,7 +4970,8 @@
                     extra_index_info['index'] += 1
                     block._insert_op_without_sync(
                         index=index + extra_index_info['index'],
-                        type='recv_v2' if not use_mp else 'partial_recv',
+                        type='recv_v2'
+                        if not use_mp or is_param else 'partial_recv',
                         outputs={'Out': [var]},
                         attrs={
                             'out_shape': var_shape,
@@ -4981,7 +4986,7 @@
                             'id': self.mp_rank,
                         })
                     extra_index_info['index'] += 1
-                    if use_mp:
+                    if use_mp and not is_param:
                         block._insert_op_without_sync(
                             index=index + extra_index_info['index'],
                             type='partial_allgather',
-- 
GitLab
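
Note: the heart of this change is the op-type selection for cross-stage communication. When pipeline parallelism is combined with tensor (mp) parallelism, ordinary activations are exchanged with partial_send/partial_recv followed by partial_allgather, but a shared weight must travel whole via send_v2/recv_v2, which is what the Parameter-prefix check enables. The standalone sketch below illustrates only that selection logic; FakeVariable, FakeParameter, choose_send_op, and the plain dict standing in for block.var() are hypothetical stand-ins, not part of the patch or of Paddle's API.

    # Minimal sketch of the op-type selection introduced by this patch
    # (assumed/simplified names; the real code lives in
    # PipelineOptimizer and works on Paddle's Block/Parameter objects).

    class FakeVariable:                   # stand-in for an ordinary variable
        pass

    class FakeParameter(FakeVariable):    # stand-in for fluid.framework.Parameter
        pass

    def choose_send_op(var_name, block_vars, use_mp):
        """Pick the send op type the way the patch does.

        Derived names such as 'linear_0.w_0@GRAD' are reduced to their
        prefix before the '@'; if that prefix resolves to a Parameter, the
        variable is a (possibly shared) weight and is sent whole with
        send_v2 instead of being split across the mp group.
        """
        prefix_name = var_name.split('@')[0]
        prefix_var = block_vars.get(prefix_name)
        is_param = isinstance(prefix_var, FakeParameter)
        return 'send_v2' if not use_mp or is_param else 'partial_send'

    if __name__ == '__main__':
        block_vars = {'linear_0.w_0': FakeParameter(), 'tmp_3': FakeVariable()}
        # A shared weight is sent whole even when mp is enabled.
        assert choose_send_op('linear_0.w_0', block_vars, use_mp=True) == 'send_v2'
        # An ordinary activation still uses partial_send under mp.
        assert choose_send_op('tmp_3', block_vars, use_mp=True) == 'partial_send'

The receive side mirrors this choice (recv_v2 vs. partial_recv), and the trailing partial_allgather is likewise skipped for parameters, since they were never partitioned in the first place.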