Unverified commit 5199c744, authored by lilong12, committed by GitHub

support weight sharing for pipeline (#35351)

* support weight sharing
Parent 18a963a5
@@ -22,7 +22,7 @@ from collections import defaultdict
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
-from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
+from paddle.fluid.framework import Program, Variable, Parameter, name_scope, default_main_program, default_startup_program, device_guard
 from . import framework
 from . import layers
@@ -4234,14 +4234,14 @@ class PipelineOptimizer(object):
         self._device = "gpu"
         if framework.in_dygraph_mode():
             raise Exception("In dygraph, don't support PipelineOptimizer.")
-        if not isinstance(optimizer, Optimizer) and not isinstance(
-                optimizer, paddle.optimizer.Optimizer) and not isinstance(
-                optimizer, paddle.fluid.contrib.mixed_precision.decorator.
-                OptimizerWithMixedPrecision):
+        valid_optimizers = (Optimizer, paddle.optimizer.Optimizer,
+                            paddle.fluid.contrib.mixed_precision.decorator.
+                            OptimizerWithMixedPrecision)
+        if not isinstance(optimizer, valid_optimizers):
             raise ValueError("The 'optimizer' parameter for "
                              "PipelineOptimizer must be an instance of "
-                             "Optimizer, but the given type is {}.".format(
-                                 type(optimizer)))
+                             "{}, but the given type is {}.".format(
+                                 valid_optimizers, type(optimizer)))
         self._optimizer = optimizer
 
         # Get the original optimizer defined by users, such as SGD
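The refactor in the hunk above leans on the fact that isinstance() accepts a tuple of classes, so the three chained checks collapse into a single test and the same tuple doubles as the class list in the error message. A minimal standalone sketch of the pattern (the optimizer classes here are placeholders, not the Paddle types named in the diff):

# Minimal sketch of validating a type against a tuple of allowed classes.
# SGD and Adam are toy stand-ins, not the Paddle optimizer types above.
class SGD:
    pass

class Adam:
    pass

valid_optimizers = (SGD, Adam)

def check_optimizer(optimizer):
    # isinstance() accepts a tuple, so one call covers every allowed class.
    if not isinstance(optimizer, valid_optimizers):
        raise ValueError("The 'optimizer' parameter must be an instance of "
                         "{}, but the given type is {}.".format(
                             valid_optimizers, type(optimizer)))

check_optimizer(SGD())  # passes
# check_optimizer("sgd")  # would raise ValueError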
@@ -4774,14 +4774,13 @@ class PipelineOptimizer(object):
                 # skip data var
                 if var.is_data: continue
                 prev_device = None
 
-                generate_ops = self.output_var_to_op.get(var_name)
-                if generate_ops is None:
+                prev_op = self._find_prev_op(index, var_name)
+                if prev_op is None:
                     if var_name not in self._param_device_map:
                         continue
                     prev_device = self._param_device_map[var_name]
-                prev_op = self._find_prev_op(index, var_name)
 
                 if not prev_device:
                     prev_device = prev_op.attr(self._op_device_key) \
                         if prev_op else None
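This hunk changes how the producer of a boundary variable is found: the pass now calls self._find_prev_op() first and falls back to self._param_device_map only when no producing op exists, which is exactly the case for a parameter shared from another pipeline stage. A rough standalone sketch of that fallback, using plain dicts as stand-ins for the real PipelineOptimizer state (the variable and device names are illustrative):

# Sketch of the "previous device" lookup with a parameter fallback,
# assuming dict stand-ins for the real PipelineOptimizer bookkeeping.
def find_prev_device(var_name, prev_op_of, param_device_map, op_device_of):
    prev_op = prev_op_of.get(var_name)
    prev_device = None
    if prev_op is None:
        # No producing op in this program section: only a parameter
        # (e.g. a shared weight) can still resolve, via its device map.
        if var_name not in param_device_map:
            return None, None
        prev_device = param_device_map[var_name]
    if not prev_device:
        prev_device = op_device_of.get(prev_op)
    return prev_op, prev_device

# A shared weight with no producing op resolves through the device map.
prev_op, dev = find_prev_device("shared_embedding.w_0",
                                prev_op_of={},
                                param_device_map={"shared_embedding.w_0": "gpu:0"},
                                op_device_of={})
assert prev_op is None and dev == "gpu:0"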
@@ -4928,9 +4927,14 @@ class PipelineOptimizer(object):
                         self._op_role_key: op_role,
                     })
                 extra_index_info['index'] += 1
+                prefix_name = var.name.split('@')[0]
+                prefix_var = block.var(prefix_name)
+                is_param = True if isinstance(prefix_var,
+                                              Parameter) else False
                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
-                    type='send_v2' if not use_mp else 'partial_send',
+                    type='send_v2'
+                    if not use_mp or is_param else 'partial_send',
                     inputs={'X': var},
                     attrs={
                         self._op_device_key: prev_dev,
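The four added lines above derive an is_param flag from the variable name: any '@' suffix (for example '@GRAD') is stripped, the bare name is looked up in the block, and the variable counts as a parameter when that lookup returns a Parameter. A small sketch of the same prefix check, with a plain dict standing in for the framework Block and toy Parameter/Variable classes (the names 'linear_0.w_0' and 'tmp_3' are illustrative only):

# Sketch of the is_param check; a dict replaces the real framework Block
# and Parameter/Variable are toy classes, so the names are illustrative.
class Parameter:
    pass

class Variable:
    pass

toy_block = {"linear_0.w_0": Parameter(), "tmp_3": Variable()}

def is_param(var_name):
    # Strip an '@...' suffix such as '@GRAD' to recover the base name,
    # then check whether the base variable is a Parameter.
    prefix_name = var_name.split('@')[0]
    prefix_var = toy_block.get(prefix_name)
    return isinstance(prefix_var, Parameter)

print(is_param("linear_0.w_0@GRAD"))  # True  -> use send_v2 / recv_v2
print(is_param("tmp_3"))              # False -> partial_send under tensor parallelism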
@@ -4966,7 +4970,8 @@ class PipelineOptimizer(object):
                 extra_index_info['index'] += 1
                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
-                    type='recv_v2' if not use_mp else 'partial_recv',
+                    type='recv_v2'
+                    if not use_mp or is_param else 'partial_recv',
                     outputs={'Out': [var]},
                     attrs={
                         'out_shape': var_shape,
@@ -4981,7 +4986,7 @@ class PipelineOptimizer(object):
                         'id': self.mp_rank,
                     })
                 extra_index_info['index'] += 1
-                if use_mp:
+                if use_mp and not is_param:
                     block._insert_op_without_sync(
                         index=index + extra_index_info['index'],
                         type='partial_allgather',
...
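Read together, the send, recv and allgather hunks implement a single rule: under tensor (mp) parallelism an ordinary boundary activation is moved piecewise with partial_send / partial_recv and reassembled with partial_allgather, while a shared parameter is always moved whole with send_v2 / recv_v2 and needs no allgather. A compact summary sketch of that decision, returning the op types the pass would insert for one boundary variable (this condenses the conditions in the diff; it is not code from the repository):

# Summary of which communication ops get inserted for one boundary
# variable, condensed from the conditions in the hunks above.
def boundary_comm_ops(use_mp, is_param):
    send_type = 'send_v2' if not use_mp or is_param else 'partial_send'
    recv_type = 'recv_v2' if not use_mp or is_param else 'partial_recv'
    ops = [send_type, recv_type]
    if use_mp and not is_param:
        # Only partially transferred activations need reassembling.
        ops.append('partial_allgather')
    return ops

print(boundary_comm_ops(use_mp=True, is_param=True))
# ['send_v2', 'recv_v2']  <- shared weight: sent whole
print(boundary_comm_ops(use_mp=True, is_param=False))
# ['partial_send', 'partial_recv', 'partial_allgather']
print(boundary_comm_ops(use_mp=False, is_param=False))
# ['send_v2', 'recv_v2']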