未验证 提交 5199c744 编写于 作者: L lilong12 提交者: GitHub

Support weight sharing in PipelineOptimizer (#35351)

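* Support weight sharing: when the variable being transferred is a Parameter, always use send_v2/recv_v2 and skip partial_allgather, even if use_mp is set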
上级 18a963a5
......@@ -22,7 +22,7 @@ from collections import defaultdict
import paddle
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from paddle.fluid.framework import Program, Variable, Parameter, name_scope, default_main_program, default_startup_program, device_guard
from . import framework
from . import layers
......@@ -4234,14 +4234,14 @@ class PipelineOptimizer(object):
self._device = "gpu"
if framework.in_dygraph_mode():
raise Exception("In dygraph, don't support PipelineOptimizer.")
if not isinstance(optimizer, Optimizer) and not isinstance(
optimizer, paddle.optimizer.Optimizer) and not isinstance(
optimizer, paddle.fluid.contrib.mixed_precision.decorator.
OptimizerWithMixedPrecision):
valid_optimizers = (Optimizer, paddle.optimizer.Optimizer,
paddle.fluid.contrib.mixed_precision.decorator.
OptimizerWithMixedPrecision)
if not isinstance(optimizer, valid_optimizers):
raise ValueError("The 'optimizer' parameter for "
"PipelineOptimizer must be an instance of "
"Optimizer, but the given type is {}.".format(
type(optimizer)))
"{}, but the given type is {}.".format(
valid_optimizers, type(optimizer)))
self._optimizer = optimizer
# Get the original optimizer defined by users, such as SGD
......@@ -4774,14 +4774,13 @@ class PipelineOptimizer(object):
# skip data var
if var.is_data: continue
prev_device = None
generate_ops = self.output_var_to_op.get(var_name)
if generate_ops is None:
prev_op = self._find_prev_op(index, var_name)
if prev_op is None:
if var_name not in self._param_device_map:
continue
prev_device = self._param_device_map[var_name]
prev_op = self._find_prev_op(index, var_name)
if not prev_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
......@@ -4928,9 +4927,14 @@ class PipelineOptimizer(object):
self._op_role_key: op_role,
})
extra_index_info['index'] += 1
prefix_name = var.name.split('@')[0]
prefix_var = block.var(prefix_name)
is_param = True if isinstance(prefix_var,
Parameter) else False
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='send_v2' if not use_mp else 'partial_send',
type='send_v2'
if not use_mp or is_param else 'partial_send',
inputs={'X': var},
attrs={
self._op_device_key: prev_dev,
......@@ -4966,7 +4970,8 @@ class PipelineOptimizer(object):
extra_index_info['index'] += 1
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='recv_v2' if not use_mp else 'partial_recv',
type='recv_v2'
if not use_mp or is_param else 'partial_recv',
outputs={'Out': [var]},
attrs={
'out_shape': var_shape,
......@@ -4981,7 +4986,7 @@ class PipelineOptimizer(object):
'id': self.mp_rank,
})
extra_index_info['index'] += 1
if use_mp:
if use_mp and not is_param:
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='partial_allgather',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册