Unverified commit 988c58e5 authored by zhaoyingli, committed by GitHub

[AutoParallel] Add 1F1B Pass (#54260)

* [AutoParallel] add 1F1B

* rm amp
Parent 703a64a3
......@@ -339,7 +339,7 @@ class Partitioner:
**{"grad_var_to_var": grad_var_to_var},
)
elif is_optimize_op(op):
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of 1F1B PASS
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_opt_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op
......
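The "2" in the fixed comment above is the integer value of OpRole.Optimize in Paddle's op-role enum; a minimal sanity check (using the same OpRole import that appears later in this diff):

from paddle.distributed.fleet.meta_optimizers.common import OpRole

# optimize-role ops carry op_role == 2, which the 1F1B pass relies on
assert int(OpRole.Optimize) == 2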
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import OrderedDict
from functools import reduce
......@@ -1292,6 +1293,8 @@ class Resharder:
shape_x[0] <= shape_y[0] < shape_x[1]
):
overlapped = True
if shape_x == [0, 0] and shape_y == [0, 0]:
overlapped = True
return overlapped
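For illustration, a standalone mirror of the overlap test above (a hypothetical helper, treating partitions as half-open ranges [start, end)), including the empty-partition case this commit adds:

def _is_overlapped(x, y):
    # two ranges overlap when either start falls inside the other range
    if (y[0] <= x[0] < y[1]) or (x[0] <= y[0] < x[1]):
        return True
    # case added by this commit: two empty partitions count as overlapped
    return x == [0, 0] and y == [0, 0]

assert _is_overlapped([0, 4], [2, 6])      # they share [2, 4)
assert not _is_overlapped([0, 2], [2, 4])  # touching ends do not overlap
assert _is_overlapped([0, 0], [0, 0])      # both empty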
def is_unshard(self, dims_mapping):
......@@ -1377,6 +1380,14 @@ class Resharder:
# judge whether reshard is needed based on the process_mesh
if tensor_process_mesh != op_process_mesh:
is_reshard = True
# do not reshard data tensors in the send/recv scene
if (
tensor_process_mesh != op_process_mesh
and len(tensor_process_mesh.process_ids)
== len(op_process_mesh.process_ids)
and dist_tensor.serial_tensor.is_data
):
is_reshard = False
else:
op_output_dims_mapping = dist_attr[1]
if all(
......@@ -1432,7 +1443,6 @@ class Resharder:
"""
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
......@@ -1588,6 +1598,11 @@ class Resharder:
Resharder.concat_partitions(
partition_index_list, source_partition_index
)
# TODO(zhaoyingli): Move this logic into a separate pass.
# Currently, obtaining all pp_ranks' relationships must rely on reshard:
# when reshard inserts a send/recv pair, the process_group carries the pp relationship.
# But this relationship can only be collected in 'reshard_input',
# because 'reshard_output' only has the current process_group view instead of the global view.
if int(op_role) == int(OpRole.Forward):
self.dist_context.up_down_streams.add_pair_stream(
to_send_process, target_process
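A minimal sketch of the up/down-stream bookkeeping assumed by add_pair_stream (a hypothetical stand-in, not the actual dist_context helper): each send/recv pair discovered during reshard marks the sender as upstream of the receiver, and the 1F1B pass later reads these relationships back through ups()/downs():

from collections import defaultdict

class _UpDownStream:  # hypothetical illustration only
    def __init__(self):
        self._ups = defaultdict(set)
        self._downs = defaultdict(set)

    def add_pair_stream(self, up_rank, down_rank):
        # the sending process is upstream of the receiving process
        self._ups[down_rank].add(up_rank)
        self._downs[up_rank].add(down_rank)

    def ups(self, rank):
        return sorted(self._ups[rank])

    def downs(self, rank):
        return sorted(self._downs[rank])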
......@@ -1658,10 +1673,10 @@ class Resharder:
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
min_comm_group = copy.deepcopy(group)
all_partition_index_list_copied = copy.deepcopy(
all_partition_index_list
)
target_partition_index = Resharder.compute_partition_index(
process,
complete_shape,
......@@ -1669,12 +1684,54 @@ class Resharder:
target_process_shape,
target_process_group,
)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
slices_axes.append(idx)
for _process in group:
source_partition_index = (
Resharder.compute_partition_index(
_process,
complete_shape,
source_dims_mapping,
source_process_shape,
source_process_group,
)
)
if not all(
map(
self.is_overlapped,
source_partition_index,
target_partition_index,
)
):
min_comm_group.remove(_process)
all_partition_index_list_copied.remove(
source_partition_index
)
concatenated_partition_index_list = []
for partition_index in all_partition_index_list_copied:
Resharder.concat_partitions(
concatenated_partition_index_list, partition_index
)
concatenated_partition_index = (
concatenated_partition_index_list[0]
)
to_slice_tensor_shape = dist_tensor.global_sizes()
slice_starts = []
slice_ends = []
slices_axes = []
to_slice_tensor_shape = []
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(
target_partition_index[idx][0] - item[0]
)
slice_ends.append(
target_partition_index[idx][1] - item[0]
)
slices_axes.append(idx)
to_slice_tensor_shape.append(item[1] - item[0])
slice_op_desc = SliceOpDesc(
starts=slice_starts,
ends=slice_ends,
......@@ -1703,18 +1760,18 @@ class Resharder:
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
group=min_comm_group,
shape=allgather_shape,
is_bool=(
source_tensor.dtype == paddle.bool
),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
partition_index_list=all_partition_index_list_copied
),
slice_op_desc,
]
if len(group) > 1
if len(min_comm_group) > 1
else [slice_op_desc]
)
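A 1-D worked example of the pruning and slice arithmetic above (a standalone sketch with assumed shapes, not the Resharder API): a tensor of global length 8 is split evenly across 4 ranks and the target partition is [2, 6); only the two overlapping ranks survive in min_comm_group, and the slice offsets are taken relative to the concatenated region:

def _overlaps(x, y):
    return (y[0] <= x[0] < y[1]) or (x[0] <= y[0] < x[1]) or (x == y == [0, 0])

source_partitions = {0: [0, 2], 1: [2, 4], 2: [4, 6], 3: [6, 8]}
target = [2, 6]
# prune ranks whose source partition does not overlap the target partition
min_comm_group = [r for r, p in source_partitions.items() if _overlaps(p, target)]
kept = [source_partitions[r] for r in min_comm_group]  # [[2, 4], [4, 6]]
# concatenating the kept partitions yields the region [2, 6)
concatenated = [min(p[0] for p in kept), max(p[1] for p in kept)]
slice_start = target[0] - concatenated[0]  # 0
slice_end = target[1] - concatenated[0]    # 4
print(min_comm_group, concatenated, slice_start, slice_end)  # [1, 2] [2, 6] 0 4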
......@@ -2420,7 +2477,7 @@ class Resharder:
else:
idx += 1
def _hadnle_recv(self, block, idx, var, op, send_rank, recv_rank):
def _handle_recv(self, block, idx, var, op, send_rank, recv_rank):
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
if var.dtype == paddle.bool:
......@@ -2652,7 +2709,7 @@ class Resharder:
)
elif self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._hadnle_recv(
self._handle_recv(
block,
idx,
var,
......@@ -2684,7 +2741,7 @@ class Resharder:
)
elif self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._hadnle_recv(
self._handle_recv(
block, idx, var, op, item, recv_rank
)
else:
......
......@@ -17,7 +17,14 @@ import os
from paddle.distributed.auto_parallel.static.process_group import (
remove_process_group,
)
from paddle.distributed.auto_parallel.static.utils import (
is_backward_op,
is_forward_op,
is_lr_sched_op,
is_optimize_op,
)
from paddle.distributed.fleet.fleet_executor_utils import TaskNode
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program
......@@ -32,6 +39,12 @@ __not_shape_var_type__ = [
]
def is_reshard_op(op):
return op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr(
'op_namescope'
)
@register_pass("auto_parallel_pipeline")
class PipelinePass(PassBase):
def __init__(self):
......@@ -62,7 +75,8 @@ class PipelinePass(PassBase):
self._cur_pp_stage = self._get_pp_stage(self._cur_rank)
if self._mode == "1F1B":
raise NotImplementedError("1F1B has not been implemented")
self._insert_sync_ops_for_1f1b()
self._task_1f1b()
elif self._mode == "F-Then-B":
raise NotImplementedError("F-Then-B has not been implemented")
elif self._mode == "stream":
......@@ -109,6 +123,98 @@ class PipelinePass(PassBase):
block._sync_with_cpp()
def _insert_sync_ops_for_1f1b(self):
"""
This implementation borrows heavily from Paddle/python/paddle/fluid/optimizer.py.
The difference between this function and 'PipelineOptimizer' is that
the 'send_v2' and 'recv_v2' ops have already been inserted into the program by 'reshard'.
"""
for block in self._program.blocks:
offset = 0
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if is_optimize_op(op):
first_optimize_index = index
break
# insert sync ops
for index, op in enumerate(list(block.ops)):
# NOTE: pipeline might hang when dynamic_shape is True
if op.type in ['send_v2', 'recv_v2']:
op._set_attr("dynamic_shape", False)
# set send op on comm stream
if op.type == 'send_v2':
# step1: set 'use_calc_stream' False
op._set_attr("use_calc_stream", False)
op_role = op.attr('op_role')
ring_id = op.attr('ring_id')
# step2: insert 'c_sync_calc_stream' op before 'send_v2' op
var_name = op.input_arg_names[0]
var = block.var(var_name)
block._insert_op_without_sync(
index=index + offset,
type="c_sync_calc_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': op_role},
)
offset += 1
# step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
# before the first optimize op
if int(op_role) == int(OpRole.Backward):
index = first_optimize_index + offset
new_op_role = OpRole.Optimize
else:
index = index + offset + 1
new_op_role = OpRole.Backward
sync_comm_op = block._insert_op_without_sync(
index=index,
type="c_sync_comm_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
'op_role': new_op_role,
'ring_id': ring_id,
},
)
# step4: if the 'send_v2' op is in the forward phase, set 'pipeline_flag' to distinguish
# whether the 'c_sync_comm_stream' op was inserted for pipeline.
if int(op_role) == int(OpRole.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
offset += 1
block._sync_with_cpp()
offset = 0
backward_recv_index = None
for index, op in enumerate(block.ops):
if op.type == "recv_v2" and is_backward_op(op):
backward_recv_index = index
break
if backward_recv_index is None:
continue
# replace 'c_sync_comm_stream' op with a 'nop' op;
# the 'nop' op keeps the variable alive for gc
for index, op in enumerate(list(block.ops)):
if index >= backward_recv_index:
break
if op.type == 'c_sync_comm_stream' and op.has_attr(
'pipeline_flag'
):
var_name = op.output_arg_names[0]
var = block.var(var_name)
block._remove_op(index + offset, sync=False)
offset -= 1
block._insert_op_without_sync(
index=backward_recv_index,
type="nop",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': OpRole.Backward},
)
block._sync_with_cpp()
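The index/offset bookkeeping in the function above is the subtle part: ops are inserted while iterating over a snapshot of the block, so every earlier insertion shifts later indices by one. A toy list-based sketch of the same pattern (illustrative op names only):

ops = ["fwd", "send_v2", "bwd", "optimize"]
offset = 0
for index, op in enumerate(list(ops)):  # iterate over a snapshot
    if op == "send_v2":
        # sync the calc stream before the send
        ops.insert(index + offset, "c_sync_calc_stream")
        offset += 1
        # sync the comm stream after the send
        ops.insert(index + offset + 1, "c_sync_comm_stream")
        offset += 1
print(ops)
# ['fwd', 'c_sync_calc_stream', 'send_v2', 'c_sync_comm_stream', 'bwd', 'optimize']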
def _create_param(self, dst_block, src_var):
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
......@@ -196,6 +302,183 @@ class PipelinePass(PassBase):
break
return pp_idx
def _task_1f1b(self):
# create lr, fwd, bwd, opt programs based on op_role
num_of_functionality = 4
lr_prog = Program()
fwd_prog = Program()
bwd_prog = Program()
opt_prog = Program()
for idx, src_block in enumerate(self._program.blocks):
if idx == 0:
lr_block = lr_prog.block(0)
fwd_block = fwd_prog.block(0)
bwd_block = bwd_prog.block(0)
opt_block = opt_prog.block(0)
else:
lr_block = lr_prog._create_block(
parent_idx=src_block.parent_idx
)
fwd_block = fwd_prog._create_block(
parent_idx=src_block.parent_idx
)
bwd_block = bwd_prog._create_block(
parent_idx=src_block.parent_idx
)
opt_block = opt_prog._create_block(
parent_idx=src_block.parent_idx
)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
# split the program based on the op_role
for op in src_block.ops:
if is_lr_sched_op(op):
self._create_program(src_block, lr_block, op)
elif is_forward_op(op):
self._create_program(src_block, fwd_block, op)
elif is_backward_op(op):
self._create_program(src_block, bwd_block, op)
elif is_optimize_op(op):
self._create_program(src_block, opt_block, op)
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of LRSched, Forward, Backward or Optimizer."
)
lr_prog._sync_with_cpp()
fwd_prog._sync_with_cpp()
bwd_prog._sync_with_cpp()
opt_prog._sync_with_cpp()
lr_prog._rollback()
fwd_prog._rollback()
bwd_prog._rollback()
opt_prog._rollback()
# Create task nodes.
lr_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=lr_prog,
task_id=int(self._cur_rank * num_of_functionality + 0),
node_type="Amplifier",
lazy_initialize=True,
)
lr_task_node.set_run_pre_steps(self._acc_steps)
fwd_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=fwd_prog,
task_id=int(self._cur_rank * num_of_functionality + 1),
node_type="Compute",
lazy_initialize=True,
)
bwd_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=bwd_prog,
task_id=int(self._cur_rank * num_of_functionality + 2),
node_type="Compute",
lazy_initialize=True,
)
opt_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=opt_prog,
task_id=int(self._cur_rank * num_of_functionality + 3),
node_type="Amplifier",
lazy_initialize=True,
)
opt_task_node.set_run_pre_steps(self._acc_steps)
opt_task_node.set_run_at_offset(self._acc_steps - 1)
task_nodes = {
"lr": lr_task_node,
"fwd": fwd_task_node,
"bwd": bwd_task_node,
"opt": opt_task_node,
}
# get upstream ranks and downstream ranks of cur_rank
up_down_streams = self._dist_context.up_down_streams
pp_upstream_ranks = up_down_streams.ups(self._cur_rank)
pp_downstream_ranks = up_down_streams.downs(self._cur_rank)
# set upstream/downstream for task_nodes of cur_rank
for i, (task_role, task_node) in enumerate(task_nodes.items()):
cur_id = int(self._cur_rank * num_of_functionality + i)
ups = []
downs = []
# set upstream/downstream and buffer size within the pipeline stage
pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
prev_id = cur_id - 1
next_id = cur_id + 1
if task_role != "lr":
buf_size = pp_buff_size if task_role == "bwd" else 2
ups.append((prev_id, buf_size))
if task_role != "opt":
buf_size = pp_buff_size if task_role == "fwd" else 2
downs.append((next_id, buf_size))
# set upstream/downstream and buffer size across pipeline stages
for upstream in pp_upstream_ranks:
upstream_id = int(upstream * num_of_functionality + i)
if task_role == "fwd":
if upstream != -1:
ups.append((upstream_id, 2))
elif task_role == "bwd":
if upstream != -1:
downs.append((upstream_id, 2))
for downstream in pp_downstream_ranks:
downstream_id = int(downstream * num_of_functionality + i)
if task_role == "fwd":
if downstream != -1:
downs.append((downstream_id, 2))
elif task_role == "bwd":
if downstream != -1:
ups.append((downstream_id, 2))
for up in ups:
print(
"Task:",
cur_id,
"'s upstream includes:",
up[0],
", buffer size is:",
up[1],
)
task_node.add_upstream_task(up[0], up[1])
for down in downs:
print(
"Task:",
cur_id,
"'s downstream includes:",
down[0],
", buffer size is:",
down[1],
)
task_node.add_downstream_task(down[0], down[1])
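To make the buffer-size rule concrete, a sketch of the edges built for the fwd task of the first stage in an assumed 4-stage pipeline (toy values; rank 0 is stage 0, its pp downstream is rank 1):

num_of_functionality = 4
cur_rank, downstream_rank = 0, 1
pp_stages, cur_pp_stage = 4, 0
pp_buff_size = pp_stages - cur_pp_stage  # 4: the 1F1B warm-up depth of stage 0

i = 1  # the "fwd" slot
cur_id = cur_rank * num_of_functionality + i          # task 1
ups = [(cur_id - 1, 2)]                               # lr -> fwd, buffer 2
downs = [
    (cur_id + 1, pp_buff_size),                       # fwd -> bwd, buffer 4
    (downstream_rank * num_of_functionality + i, 2),  # fwd -> next stage's fwd
]
print(ups, downs)  # [(0, 2)] [(2, 4), (5, 2)]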
# record the global task_id -> rank mapping
task_id_to_rank = {}
for i in range(self._nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
self._program._pipeline_opt = {}
self._program._pipeline_opt['fleet_opt'] = {
"tasks": list(task_nodes.values()),
"task_id_to_rank": task_id_to_rank,
"num_micro_batches": self._acc_steps,
}
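For instance, with two ranks the mapping recorded above comes out as follows (a quick standalone check):

num_of_functionality = 4
task_id_to_rank = {
    i * num_of_functionality + j: i
    for i in range(2)
    for j in range(num_of_functionality)
}
print(task_id_to_rank)
# {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}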
def _task_stream(self):
num_of_functionality = 5
start_prog = Program()
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed import ParallelEnv
from paddle.distributed.fleet import auto
paddle.enable_static()
def apply_pass(use_1f1b=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_1f1b:
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2
else:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 2
gradient_merge.avg = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Test1F1BPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
paddle.distributed.fleet.init(is_collective=True)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_1f1b=False):
reset_prog()
strategy = apply_pass(use_1f1b)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("pp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!\nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def test_1f1b_pass(self):
# naive_pp + gradient_merge training
engine_pp = self.get_engine()
history_pp = engine_pp.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_pp._strategy.pipeline.enable is False
# pp2 1f1b training
engine_1f1b = self.get_engine(True)
history_1f1b = engine_1f1b.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_1f1b._strategy.pipeline.enable is True
# NOTE: every sample drawn from the dataset is identical
if paddle.distributed.get_rank() == 1:
losses_pp = np.array(history_pp.history["loss"])
losses_1f1b = np.array(history_1f1b.history["loss"])
self.check_results(losses_pp, losses_1f1b)
if __name__ == "__main__":
unittest.main()
......@@ -62,6 +62,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_pass_generation_pipeline)
set_tests_properties(test_pass_generation_pipeline
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
......@@ -86,8 +86,14 @@ def generate_model(strategy, dropout_prob=0.0):
modeling._global_parallel_strategy = "mp"
elif strategy == "dp":
modeling._global_parallel_strategy = "dp"
elif strategy == "pp":
modeling._global_parallel_strategy = "pp"
modeling.PP_MESH_LIST = [
auto.ProcessMesh(mesh=[0]),
auto.ProcessMesh(mesh=[1]),
]
else:
raise ValueError("Only support serial, mp2 and dp2.")
raise ValueError("Only support serial, mp2, dp2 and pp2.")
gpt = GPTModel(
vocab_size=1000,
......@@ -105,6 +111,7 @@ def generate_model(strategy, dropout_prob=0.0):
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=2 if strategy == "pp" else None,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
......
......@@ -85,6 +85,9 @@ def train(fetch):
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
# dp optimization config
dp_optimization = dist_strategy.dp_optimization
dp_optimization.enable = True
# sharding config
sharding = dist_strategy.sharding
sharding.enable = True
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class Test1F1BPass(unittest.TestCase):
def test_pp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -350,32 +350,18 @@ class TransformerDecoder(nn.Layer):
output = tgt
new_caches = []
self.checkpoints = []
if _global_parallel_strategy == "pp":
auto.shard_tensor(
output,
PP_MESH_LIST[0],
[None for i in range(len(output.shape))],
)
if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(
output,
DPPP_MESH_LIST[0],
["x"] + [None for i in range(len(output.shape) - 1)],
)
if _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
output,
MPPP_MESH_LIST[0],
[None for i in range(len(output.shape))],
)
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
output,
DPMPPP_MESH_LIST[0],
["x"] + [None for i in range(len(output.shape) - 1)],
)
for i, mod in enumerate(self.layers):
if _global_parallel_strategy == "pp":
mod = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])
elif _global_parallel_strategy == "dp_pp":
mod = auto.shard_op(mod, DPPP_MESH_LIST[mod.mesh_idx])
elif _global_parallel_strategy == "mp_pp":
mod = auto.shard_op(mod, MPPP_MESH_LIST[mod.mesh_idx])
elif _global_parallel_strategy == "dp_mp_pp":
mod = auto.shard_op(mod, DPMPPP_MESH_LIST[mod.mesh_idx])
if self.use_new_recompute and self.recompute_granularity == "full":
mod = auto.recompute(mod)
......