Unverified commit 4bacf2ab authored by JZ-LIANG, committed by GitHub

Extra Sync for Tensor Parallel (#50637)

Parent 9025fddd
@@ -21,22 +21,28 @@ from paddle.distributed import fleet
 import paddle
 from . import log_util  # noqa: F401
 from . import hybrid_parallel_util  # noqa: F401
+from . import tensor_parallel_utils  # noqa: F401

 __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa


 def recompute(function, *args, **kwargs):
     """
-    recompute intermediate activations to save then memory.
+    recompute intermediate activations to save the memory.

     Parameters:
         function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
             whose intermediate activations will be released to save memory in forward stage and will be recomputed
             in backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-            indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-            restored when the forward recalculation of backpropagation is performed. The default
-            preserve_rng_state is True.
+        **kwargs(Dict): Kwargs should only contain two kinds of key-value params: part of the function's own
+            key-value params, plus ``preserve_rng_state`` and ``use_reentrant``. The key-value pair of
+            ``preserve_rng_state`` indicates whether to save the forward rng. If it is True, the last forward
+            rng value will be restored when the forward recalculation of backpropagation is performed; its
+            default value is True. The key-value pair of ``use_reentrant`` indicates which implementation of
+            recompute will be used: ``use_reentrant=True`` means the PyLayer implementation,
+            ``use_reentrant=False`` means the Hook implementation; its default value is True.

     Returns:
         Output of function on args.
...
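As a point of reference, here is a minimal usage sketch of the recompute API documented above. It is illustrative only and not part of this commit; it assumes a dygraph (imperative) context, and the layer structure and sizes are made up for the example.

import paddle
from paddle.distributed.fleet.utils import recompute

layer = paddle.nn.Sequential(
    paddle.nn.Linear(64, 64),
    paddle.nn.ReLU(),
    paddle.nn.Linear(64, 64),
)
x = paddle.randn([8, 64])
x.stop_gradient = False

# Activations inside `layer` are released in the forward pass and recomputed
# during backward. `use_reentrant` selects the PyLayer (True, default) or
# Hook (False) implementation described in the docstring above.
out = recompute(layer, x, use_reentrant=False)
out.mean().backward()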
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY
from paddle.fluid import core
from paddle.fluid.framework import Parameter
_supported_optimizer_type = [
"adam",
"adamax",
"adamw",
"decayed_adagrad",
"momentum",
"dgc_momentum",
"lars_momentum",
"merged_momentum",
"lamb",
"sgd",
]
def tensor_parallel_sync_filter_fn(
param, pos_emb=True, layer_norm=True, bias=True
):
"""
    Parameter filter function for tensor parallel transformers.
    In tensor parallelism of a transformer-like model, there are 4 kinds of params
    that are supposed to be the same on all tensor parallel peers:
    * position embedding
    * scale of layer norm
    * bias of layer norm
    * bias of row parallel linear
    Set the corresponding input args to select specific layers.
    NOTE: adapt the param name pattern for different transformer blocks.
"""
p_name = param.name
if pos_emb and p_name.startswith("pos_embedding"):
return True
elif layer_norm and p_name.endswith("_layer_norm_bias"):
return True
elif layer_norm and p_name.endswith("_layer_norm_scale"):
return True
elif bias and ".b_" in p_name and (param.is_distributed is False):
return True
else:
return False
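# Illustrative note (not part of the original commit): expected behaviour of the
# filter above on some hypothetical parameter names that follow the naming
# pattern referred to in the NOTE of the docstring:
#   pos_embedding_0.w_0         -> True   (position embedding)
#   encoder_0_layer_norm_scale  -> True   (scale of layer norm)
#   encoder_0_layer_norm_bias   -> True   (bias of layer norm)
#   encoder_0_fc.b_0            -> True   (bias, param.is_distributed is False)
#   encoder_0_fc.w_0            -> False  (distributed weight stays sharded)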
def resolute_tensor_parallel_ring_id(program):
ops = program.global_block().ops
ring_id = None
for op in ops:
if op.type == "c_identity":
if ring_id is None:
ring_id = int(op.attr("ring_id"))
else:
assert ring_id == int(
op.attr("ring_id")
), "Found two different ring_id for Tensor Parallel: ring_id={} and ring_id={}.".format(
ring_id, int(op.attr("ring_id"))
)
    assert ring_id is not None, "Could NOT find ring_id for Tensor Parallel."
return ring_id
def copy_parameters(block_, params):
for param in params:
new_p = Parameter(
block=block_,
shape=param.shape,
dtype=param.dtype,
type=param.type,
lod_level=param.lod_level
if param.type == core.VarDesc.VarType.LOD_TENSOR
else None,
stop_gradient=param.stop_gradient,
trainable=param.trainable,
optimize_attr=param.optimize_attr,
regularizer=param.regularizer,
error_clip=param.error_clip,
name=param.name,
)
assert (
param.is_distributed is False
), "Try to sync Distribted Parameter: {}".format(param)
new_p.is_distributed = False
block_.vars[new_p.name] = new_p
def insert_sync_op(
block, idx, tp_degree, sync_mode, sync_ring_id, src_rank, varname, op_role
):
if sync_mode == "broadcast":
block._insert_op_without_sync(
idx,
type='c_broadcast',
inputs={'X': varname},
outputs={'Out': varname},
attrs={
'ring_id': sync_ring_id,
'root': src_rank,
'use_calc_stream': True,
OP_ROLE_KEY: op_role,
},
)
elif sync_mode == "average":
block._insert_op_without_sync(
idx,
type='scale',
inputs={'X': varname},
outputs={'Out': varname},
attrs={'scale': 1.0 / tp_degree, OP_ROLE_KEY: op_role},
)
block._insert_op_without_sync(
idx,
type='c_allreduce_sum',
inputs={'X': varname},
outputs={'Out': varname},
attrs={
'ring_id': sync_ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: op_role,
},
)
else:
raise NotImplementedError(
'Sync mode of [{}] is NOT supported.'.format(sync_mode)
)
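# Illustrative note (not part of the original commit): the "average" mode above
# implements a mean across the ring as scale(1/tp_degree) followed by
# c_allreduce_sum. E.g. with tp_degree=4 and per-rank values [1.0, 2.0, 3.0, 6.0]:
#   after scale:           [0.25, 0.5, 0.75, 1.5]
#   after c_allreduce_sum: every rank holds 0.25 + 0.5 + 0.75 + 1.5 = 3.0,
# which equals mean([1.0, 2.0, 3.0, 6.0]) = 3.0.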
def insert_synchronization(
block,
params_to_sync,
tp_degree,
sync_ring_id,
sync_param,
sync_grad,
sync_moment,
sync_mode,
src_rank,
):
unsync_param_names = [p.name for p in params_to_sync]
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in _supported_optimizer_type:
assert "Param" in op.input_names
assert len(op.input("Param")) == 1
param_name = op.input("Param")[0]
op_role = op.attr(OP_ROLE_KEY)
if param_name in unsync_param_names:
unsync_param_names.remove(param_name)
# Param sync after opt
if sync_param:
assert (
"ParamOut" in op.output_names
and op.output("ParamOut")[0] == param_name
)
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
param_name,
op_role,
)
if (
"MasterParamOut" in op.output_names
and len(op.output("MasterParamOut")) == 1
):
sync_var = op.output("MasterParamOut")[0]
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
# Moment sync after opt
if sync_moment:
if (
"Moment1Out" in op.output_names
and len(op.output("Moment1Out")) == 1
):
sync_var = op.output("Moment1Out")[0]
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
if (
"Moment2Out" in op.output_names
and len(op.output("Moment2Out")) == 1
):
sync_var = op.output("Moment2Out")[0]
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
# Grad sync before opt
if sync_grad:
assert (
"Grad" in op.input_names and len(op.input("Grad")) == 1
)
sync_var = op.input("Grad")[0]
insert_sync_op(
block,
idx,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
assert (
len(unsync_param_names) == 0
), "The following param is unsync by some error: {}".format(
unsync_param_names
)
def add_extra_synchronization(
program,
params_filter_fn=tensor_parallel_sync_filter_fn,
tp_degree=8,
sync_mode="broadcast",
sync_param=True,
sync_grad=False,
sync_moment=False,
src_rank=0,
sync_ring_id=None,
):
"""
    Add extra synchronization to the input program in place.
    program(Paddle.Program): distributed train program.
    params_filter_fn(callable): function to filter out parameters for synchronization.
    tp_degree(int): degree of the tensor parallel group; used by the "average" sync_mode.
    sync_mode(string): select from
        "broadcast": parameters are synced by broadcasting from 'src_rank' to all other ranks.
        "average": parameters are synced by averaging among all ranks.
    src_rank(int): the source rank used in the broadcast sync_mode.
    sync_param(bool): whether to synchronize parameters.
    sync_grad(bool): whether to synchronize gradients.
    sync_moment(bool): whether to synchronize optimizer moments.
    sync_ring_id(int): communicator id used for synchronization; if None, the tensor parallel ring_id is used.
"""
logger.info("Constructing Extra Parameter Synchronization.")
logger.info(
"Tensor Parallel Degree: {}, Synchronization mode: {}".format(
tp_degree, sync_mode
)
)
    # adapt for pipeline opt
if program._pipeline_opt is not None:
assert (
program._pipeline_opt['section_program'] is not None
), "Pipeline is enable but section_program is None"
program = program._pipeline_opt['section_program']
# step1: collect the param that need to be sync
params_to_sync = []
# TODO support multiple blocks with different parameter.
all_params = program.global_block().all_parameters()
for param in all_params:
if params_filter_fn(param):
params_to_sync.append(param)
    logger.info(
        "The following params are going to be synchronized every time the optimizer update phase of the program is run: "
    )
logger.info([p.name for p in params_to_sync])
# step2: resolute synchronization communicator group (ring_id)
if sync_ring_id is None:
sync_ring_id = resolute_tensor_parallel_ring_id(program)
# step3: insert synchronization
# TODO support gradient merge with different update block
block = program.global_block()
insert_synchronization(
block,
params_to_sync,
tp_degree,
sync_ring_id,
sync_param,
sync_grad,
sync_moment,
sync_mode,
src_rank,
)
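A hedged usage sketch for add_extra_synchronization follows (illustrative, not part of the commit). The placeholder program is empty only so the snippet runs stand-alone; in practice main_program is the tensor-parallel training program produced by fleet, as in the unit test below, and tp_degree / sync_ring_id must match your own setup.

import paddle
from paddle.distributed.fleet.utils import tensor_parallel_utils as tp_utils

paddle.enable_static()

# Placeholder: in practice, pass the program produced by
# fleet.distributed_optimizer(...).minimize(...) under tensor parallelism.
main_program = paddle.static.Program()

tp_utils.add_extra_synchronization(
    main_program,
    params_filter_fn=tp_utils.tensor_parallel_sync_filter_fn,
    tp_degree=2,          # should match tensor_parallel_degree in the fleet strategy
    sync_mode="average",  # "broadcast" (default, from src_rank) or "average" (allreduce mean)
    sync_param=True,      # sync parameters right after each optimizer op
    sync_grad=False,      # optionally sync gradients right before each optimizer op
    sync_moment=False,    # optionally sync optimizer moments after each optimizer op
    sync_ring_id=0,       # pass a communicator ring id explicitly, or None to reuse
                          # the tensor parallel ring_id found in the program
)

With sync_mode="broadcast" (the default, and the mode exercised by the unit test below), parameters are instead copied from src_rank after every optimizer step, which is what the inserted c_broadcast ops in the test's ref_ops correspond to.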
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import paddle.distributed.fleet as fleet
paddle.enable_static()
class TensorParallelNet(paddle.fluid.dygraph.Layer):
def __init__(self, hidden_size):
super().__init__()
self.embedding = paddle.nn.Embedding(hidden_size, hidden_size)
self.col_linear = fleet.meta_parallel.ColumnParallelLinear(
in_features=hidden_size,
out_features=hidden_size,
weight_attr=None,
has_bias=True,
gather_output=False,
# name="test_column_linear",
)
self.row_linear = fleet.meta_parallel.RowParallelLinear(
in_features=hidden_size,
out_features=hidden_size,
has_bias=True,
input_is_parallel=True,
# name="test_row_linear",
)
self.layer_norm = paddle.nn.LayerNorm(hidden_size)
def forward(self, x):
out = self.embedding(x)
out = self.col_linear(out)
out = self.row_linear(out)
output = self.layer_norm(out)
return output
def filter_fn(param, pos_emb=True, layer_norm=True, bias=True):
"""
    Parameter filter function for tensor parallel transformers.
    In tensor parallelism of a transformer-like model, there are 4 kinds of params
    that are supposed to be the same on all tensor parallel peers:
    * position embedding
    * scale of layer norm
    * bias of layer norm
    * bias of row parallel linear
    Set the corresponding input args to select specific layers.
    NOTE: adapt the param name pattern for different transformer blocks.
"""
p_name = param.name
if pos_emb and p_name.startswith("embedding"):
return True
elif layer_norm and p_name.startswith("layer_norm"):
return True
elif bias and ".b_" in p_name and (param.is_distributed is False):
return True
else:
return False
class TestFleetMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"
] = "127.0.0.1:36001,127.0.0.1:36002"
def test_tensor_parallel_extra_sync(self):
import paddle.distributed.fleet as fleet
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
fleet.init(is_collective=True, strategy=strategy)
main_program, startup_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, startup_program):
hidden_size = 512
input_x = paddle.static.data(
name="x", shape=[-1, hidden_size], dtype='int64'
)
model_a = TensorParallelNet(hidden_size)
y = model_a(input_x)
loss = paddle.mean(y)
optimizer = paddle.fluid.optimizer.Adam(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss)
ref_ops = [
'lookup_table_v2',
'c_identity',
'matmul_v2',
'elementwise_add',
'matmul_v2',
'c_allreduce_sum',
'elementwise_add',
'layer_norm',
'reduce_mean',
'fill_constant',
'reduce_mean_grad',
'layer_norm_grad',
'elementwise_add_grad',
'c_identity',
'matmul_v2_grad',
'elementwise_add_grad',
'matmul_v2_grad',
'c_allreduce_sum',
'lookup_table_v2_grad',
'adam',
'adam',
'adam',
'c_broadcast',
'adam',
'c_broadcast',
'adam',
'c_broadcast',
'adam',
'c_broadcast',
'adam',
]
paddle.distributed.fleet.utils.tensor_parallel_utils.add_extra_synchronization(
main_program, params_filter_fn=filter_fn
)
ops = [op.type for op in main_program.global_block().ops]
        self.assertEqual(ops, ref_ops)
if __name__ == "__main__":
unittest.main()