未验证 提交 0b25f665 编写于 作者: J JZ-LIANG 提交者: GitHub

[Hybrid parallelism] Tensor Parallel Extra Sync (#50676)

* main code

* unitest bug

* revert cmake
上级 ca2b6095
...@@ -21,6 +21,8 @@ from paddle.distributed import fleet ...@@ -21,6 +21,8 @@ from paddle.distributed import fleet
import paddle import paddle
from . import log_util # noqa: F401 from . import log_util # noqa: F401
from . import hybrid_parallel_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401
from . import tensor_parallel_utils # noqa: F401
__all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY
from paddle.fluid import core
from paddle.static import Parameter
_supported_optimizer_type = [
"adam",
"adamax",
"adamw",
"decayed_adagrad",
"momentum",
"dgc_momentum",
"lars_momentum",
"merged_momentum",
"lamb",
"sgd",
]
def tensor_parallel_sync_filter_fn(
param, pos_emb=True, layer_norm=True, bias=True
):
"""
Layer fliter function for tensor parallelism transformer.
In tensor parallelism of transformer like model, there is 4 kind of param
that are supposed to be the same in all tensor parallel peers:
* position embedding
* scale of layer norm
* bias of layer norm
* bias of row parallel linear
set corresponding input args to select specific layers.
NOTE adopting the param name pattern for different transformer blocks.
"""
p_name = param.name
if pos_emb and p_name.startswith("pos_embedding"):
return True
elif layer_norm and p_name.endswith("_layer_norm_bias"):
return True
elif layer_norm and p_name.endswith("_layer_norm_scale"):
return True
elif bias and ".b_" in p_name and (param.is_distributed is False):
return True
else:
return False
def resolute_tensor_parallel_ring_id(program):
ops = program.global_block().ops
ring_id = None
for op in ops:
if op.type == "c_identity":
if ring_id is None:
ring_id = int(op.attr("ring_id"))
else:
assert ring_id == int(
op.attr("ring_id")
), "Found two different ring_id for Tensor Parallel: ring_id={} and ring_id={}.".format(
ring_id, int(op.attr("ring_id"))
)
assert ring_id is not None, "Could NOT found ring_id for Tensor Parallel."
return ring_id
def copy_parameters(block_, params):
for param in params:
new_p = Parameter(
block=block_,
shape=param.shape,
dtype=param.dtype,
type=param.type,
lod_level=param.lod_level
if param.type == core.VarDesc.VarType.LOD_TENSOR
else None,
stop_gradient=param.stop_gradient,
trainable=param.trainable,
optimize_attr=param.optimize_attr,
regularizer=param.regularizer,
error_clip=param.error_clip,
name=param.name,
)
assert (
param.is_distributed is False
), "Try to sync Distribted Parameter: {}".format(param)
new_p.is_distributed = False
block_.vars[new_p.name] = new_p
def insert_sync_op(
block, idx, tp_degree, sync_mode, sync_ring_id, src_rank, varname, op_role
):
if sync_mode == "broadcast":
block._insert_op_without_sync(
idx,
type='c_broadcast',
inputs={'X': varname},
outputs={'Out': varname},
attrs={
'ring_id': sync_ring_id,
'root': src_rank,
'use_calc_stream': True,
OP_ROLE_KEY: op_role,
},
)
elif sync_mode == "average":
block._insert_op_without_sync(
idx,
type='scale',
inputs={'X': varname},
outputs={'Out': varname},
attrs={'scale': 1.0 / tp_degree, OP_ROLE_KEY: op_role},
)
block._insert_op_without_sync(
idx,
type='c_allreduce_sum',
inputs={'X': varname},
outputs={'Out': varname},
attrs={
'ring_id': sync_ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: op_role,
},
)
else:
raise NotImplementedError(
'Sync mode of [{}] is NOT supported.'.format(sync_mode)
)
def insert_synchronization(
block,
params_to_sync,
tp_degree,
sync_ring_id,
sync_param,
sync_grad,
sync_moment,
sync_mode,
src_rank,
):
unsync_param_names = [p.name for p in params_to_sync]
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in _supported_optimizer_type:
assert "Param" in op.input_names
assert len(op.input("Param")) == 1
param_name = op.input("Param")[0]
op_role = op.attr(OP_ROLE_KEY)
if param_name in unsync_param_names:
unsync_param_names.remove(param_name)
# Param sync after opt
if sync_param:
assert (
"ParamOut" in op.output_names
and op.output("ParamOut")[0] == param_name
)
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
param_name,
op_role,
)
if (
"MasterParamOut" in op.output_names
and len(op.output("MasterParamOut")) == 1
):
sync_var = op.output("MasterParamOut")[0]
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
# Moment sync after opt
if sync_moment:
if (
"Moment1Out" in op.output_names
and len(op.output("Moment1Out")) == 1
):
sync_var = op.output("Moment1Out")[0]
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
if (
"Moment2Out" in op.output_names
and len(op.output("Moment2Out")) == 1
):
sync_var = op.output("Moment2Out")[0]
insert_sync_op(
block,
idx + 1,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
# Grad sync before opt
if sync_grad:
assert (
"Grad" in op.input_names and len(op.input("Grad")) == 1
)
sync_var = op.input("Grad")[0]
insert_sync_op(
block,
idx,
tp_degree,
sync_mode,
sync_ring_id,
src_rank,
sync_var,
op_role,
)
assert (
len(unsync_param_names) == 0
), "The following param is unsync by some error: {}".format(
unsync_param_names
)
def add_extra_synchronization(
program,
params_filter_fn=tensor_parallel_sync_filter_fn,
tp_degree=8,
sync_mode="broadcast",
sync_param=True,
sync_grad=False,
sync_moment=False,
src_rank=0,
sync_ring_id=None,
):
"""
Inplace add extra synchronization for input program.
program(Paddle.Program): distributed train program.
params_filter_fn(callable): function to filter out parameter for synchronization.
sync_mode(string): select from
"broadcast": parameter is sync by broadcasted from 'src_rank' to all other ranks.
"average": paramter is sync by average amonge all ranks
src_rank(int): the src used in broadcast sync_mode.
sync_param(bool): extra synchronize parameters.
sync_grad(bool): extra synchronize gradients.
sync_grad(bool): extra synchronize optimizer momentum.
sync_ring_id(int): communicator id use for synchronization, if it is None, use the ring_id of tensor parallel.
"""
logger.info("Constructing Extra Parameter Synchronization.")
logger.info(
"Tensor Parallel Degree: {}, Synchronization mode: {}".format(
tp_degree, sync_mode
)
)
# adopt for pipeline opt
if program._pipeline_opt is not None:
assert (
program._pipeline_opt['section_program'] is not None
), "Pipeline is enable but section_program is None"
program = program._pipeline_opt['section_program']
# step1: collect the param that need to be sync
params_to_sync = []
# TODO support multiple blocks with different parameter.
all_params = program.global_block().all_parameters()
for param in all_params:
if params_filter_fn(param):
params_to_sync.append(param)
logger.info(
"The following param are goning to be synchronization everytime the optimizer update phase of the program is runned: "
)
logger.info([p.name for p in params_to_sync])
# step2: resolute synchronization communicator group (ring_id)
if sync_ring_id is None:
sync_ring_id = resolute_tensor_parallel_ring_id(program)
# step3: insert synchronization
# TODO support gradient merge with different update block
block = program.global_block()
insert_synchronization(
block,
params_to_sync,
tp_degree,
sync_ring_id,
sync_param,
sync_grad,
sync_moment,
sync_mode,
src_rank,
)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import paddle.distributed.fleet as fleet
paddle.enable_static()
class TensorParallelNet(paddle.fluid.dygraph.Layer):
def __init__(self, hidden_size):
super().__init__()
self.embedding = paddle.nn.Embedding(hidden_size, hidden_size)
self.col_linear = fleet.meta_parallel.ColumnParallelLinear(
in_features=hidden_size,
out_features=hidden_size,
weight_attr=None,
has_bias=True,
gather_output=False,
# name="test_column_linear",
)
self.row_linear = fleet.meta_parallel.RowParallelLinear(
in_features=hidden_size,
out_features=hidden_size,
has_bias=True,
input_is_parallel=True,
# name="test_row_linear",
)
self.layer_norm = paddle.nn.LayerNorm(hidden_size)
def forward(self, x):
out = self.embedding(x)
out = self.col_linear(out)
out = self.row_linear(out)
output = self.layer_norm(out)
return output
def filter_fn(param, pos_emb=True, layer_norm=True, bias=True):
"""
Layer fliter function for tensor parallelism transformer.
In tensor parallelism of transformer like model, there is 4 kind of param
that are supposed to be the same in all tensor parallel peers:
* position embedding
* scale of layer norm
* bias of layer norm
* bias of row parallel linear
set corresponding input args to select specific layers.
NOTE adopting the param name pattern for different transformer blocks.
"""
p_name = param.name
if pos_emb and p_name.startswith("embedding"):
return True
elif layer_norm and p_name.startswith("layer_norm"):
return True
elif bias and ".b_" in p_name and (param.is_distributed is False):
return True
else:
return False
class TestFleetMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"
] = "127.0.0.1:36001,127.0.0.1:36002"
def test_tensor_parallel_extra_sync(self):
import paddle.distributed.fleet as fleet
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
fleet.init(is_collective=True, strategy=strategy)
main_program, startup_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, startup_program):
hidden_size = 512
input_x = paddle.static.data(
name="x", shape=[-1, hidden_size], dtype='int64'
)
model_a = TensorParallelNet(hidden_size)
y = model_a(input_x)
loss = paddle.mean(y)
optimizer = paddle.fluid.optimizer.Adam(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss)
ref_ops = [
'lookup_table_v2',
'c_identity',
'matmul_v2',
'elementwise_add',
'matmul_v2',
'c_allreduce_sum',
'elementwise_add',
'layer_norm',
'reduce_mean',
'fill_constant',
'reduce_mean_grad',
'layer_norm_grad',
'elementwise_add_grad',
'c_identity',
'matmul_v2_grad',
'elementwise_add_grad',
'matmul_v2_grad',
'c_allreduce_sum',
'lookup_table_v2_grad',
'adam',
'adam',
'adam',
'c_broadcast',
'adam',
'c_broadcast',
'adam',
'c_broadcast',
'adam',
'c_broadcast',
'adam',
]
paddle.distributed.fleet.utils.tensor_parallel_utils.add_extra_synchronization(
main_program, params_filter_fn=filter_fn
)
ops = [op.type for op in main_program.global_block().ops]
self.assertTrue(ops == ref_ops)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册