Unverified commit e64823c1 authored by Yulong Ao, committed by GitHub

[Auto Parallel] Use a fast completion for data parallelism (#43585)

* [Auto Parallel] Use a fast completion for data parallelism

* remove unused cuSparse function

* [Auto Parallel] Fix some bugs of the fast dp completion

* [Auto Parallel] Add the cmake statements

* [Auto Parallel] Make the unittest adapt to the new interface

* [Auto Parallel] Modify the timeout of the unittest

* [Auto Parallel] Remove unnecessary comments
Co-authored-by: zhouwei25 <zhouwei25@baidu.com>
Parent 74259bac
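What the fast path amounts to: when the whole job is plain data parallel, the completer skips the SSA-graph pass, places every tensor and op on a single 1-D process mesh spanning all ranks, and shards only the batch dimension (dim 0) of non-parameter tensors, leaving parameters replicated. A minimal, self-contained sketch of that dims_mapping convention, assuming a two-rank world; this is illustrative only, not the Paddle implementation:

# Illustrative sketch only (not the actual Paddle code path).
# dims_mapping[i] == 0  -> tensor dim i is split along mesh axis 0
# dims_mapping[i] == -1 -> tensor dim i is replicated on every rank
ranks = [0, 1]                        # assumed 2-GPU world process group
global_shape = [8, 1024]              # e.g. an activation of [batch, hidden]
dims_mapping = [0, -1]                # shard the batch dim, replicate the rest
local_shape = [d // len(ranks) if m == 0 else d
               for d, m in zip(global_shape, dims_mapping)]
print(local_shape)                    # [4, 1024] held by each rank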
......@@ -27,6 +27,7 @@ from .dist_op import DistributedOperator
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .process_mesh import ProcessMesh
from .process_group import get_world_process_group
from paddle.distributed.fleet.meta_optimizers.common import OpRole
......@@ -765,16 +766,29 @@ class Completer:
else:
self._dist_context._serial_main_program = serial_main_program
self._dist_context.initialize()
start_time = time.time()
# print("start time", start_time, flush=True)
if not self._dist_context.data_parallel:
self._dist_context.initialize(with_graph=True)
# self._dist_context.validate_dist_attr_for_program()
self._prepare()
self._update_process_mesh()
self._prepare()
self._update_dims_mapping()
self._update_process_mesh()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
else:
self._dist_context.initialize(with_graph=False)
self._update_dims_mapping()
# A fast and special completion for data parallelism
self._update_dist_attr_for_dp()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
# print_program_with_dist_attr(self._dist_context.serial_main_program,
# self._dist_context)
# NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient
self._complete_high_order_grad_annotation(serial_main_program)
......@@ -784,8 +798,107 @@ class Completer:
self._dist_context.validate_dist_attr_for_program()
end_time = time.time()
# print("end time", end_time, flush=True)
# print("elapsed time", end_time - start_time, flush=True)
return serial_main_program
def _update_dist_attr_for_dp(self):
# TODO: we must ensure the world process group contains all ranks
ranks = get_world_process_group().ranks
process_mesh = ProcessMesh(ranks)
for dist_tensor in self._dist_context._dist_tensors_for_program.values(
):
serial_tensor = dist_tensor.serial_tensor
tensor_dist_attr = dist_tensor.dist_attr
tensor_dist_attr.process_mesh = process_mesh
for dist_op in self._dist_context._dist_ops_for_program.values():
serial_op = dist_op.serial_op
op_desc = serial_op.desc
op_dist_attr = dist_op.dist_attr
op_dist_attr.process_mesh = process_mesh
original_op_dist_attr = copy.deepcopy(op_dist_attr)
input_xshape_arg_names = []
if "XShape" in op_desc.input_names():
input_xshape_arg_names = op_desc.input("XShape")
for arg_name in serial_op.input_arg_names:
serial_tensor = dist_op.get_serial_input(arg_name)
if not serial_tensor.is_parameter:
if arg_name not in input_xshape_arg_names:
old_dims_mapping = op_dist_attr.get_input_dims_mapping(
arg_name)
if len(old_dims_mapping) > 0:
new_dims_mapping = [0] + [
-1 for _ in range(len(old_dims_mapping) - 1)
]
op_dist_attr.set_input_dims_mapping(
arg_name, new_dims_mapping)
else:
old_dims_mapping = op_dist_attr.get_input_dims_mapping(
arg_name)
if len(old_dims_mapping) > 1:
new_dims_mapping = [-1, 0] + [
-1 for _ in range(len(old_dims_mapping) - 2)
]
op_dist_attr.set_input_dims_mapping(
arg_name, new_dims_mapping)
# Set tensor's dims_mapping by the op's
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_tensor)
tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
arg_name)
output_xshape_arg_names = []
if "XShape" in op_desc.output_names():
output_xshape_arg_names = op_desc.output("XShape")
for arg_name in serial_op.output_arg_names:
serial_tensor = dist_op.get_serial_output(arg_name)
if not serial_tensor.is_parameter:
if arg_name not in output_xshape_arg_names:
old_dims_mapping = op_dist_attr.get_output_dims_mapping(
arg_name)
if len(old_dims_mapping) > 0:
new_dims_mapping = [0] + [
-1 for _ in range(len(old_dims_mapping) - 1)
]
op_dist_attr.set_output_dims_mapping(
arg_name, new_dims_mapping)
else:
old_dims_mapping = op_dist_attr.get_output_dims_mapping(
arg_name)
if len(old_dims_mapping) > 1:
new_dims_mapping = [-1, 0] + [
-1 for _ in range(len(old_dims_mapping) - 2)
]
op_dist_attr.set_output_dims_mapping(
arg_name, new_dims_mapping)
# Set tensor's dims_mapping by the op's
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_tensor)
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
arg_name)
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, partial=False)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
for op_dist_impl in op_dist_impls:
op_dist_impl.update_dims_mapping(dist_op)
if op_dist_impl.is_auto_compatible(dist_op) \
and dist_op.validate_dist_attr():
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
else:
dist_op.dist_attr = original_op_dist_attr
def _complete_high_order_grad_annotation(self, serial_main_program=None):
"""
NOTE:
......
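Note on the XShape branches above: ops such as reshape2 emit an XShape output whose first dimension is a placeholder, so the batch axis sits at position 1 rather than 0. A hypothetical helper that mirrors that rule (the function name and structure are mine, not part of the patch):

def dp_dims_mapping(ndims, is_xshape=False):
    # Mirrors _update_dist_attr_for_dp: shard the batch axis over mesh
    # axis 0 and replicate everything else. XShape tensors carry a leading
    # placeholder dim, so their batch axis is dim 1 instead of dim 0.
    if is_xshape:
        if ndims <= 1:
            return [-1] * ndims           # too small to shard, keep replicated
        return [-1, 0] + [-1] * (ndims - 2)
    if ndims == 0:
        return []                         # scalar, nothing to shard
    return [0] + [-1] * (ndims - 1)

print(dp_dims_mapping(2))                 # [0, -1]      e.g. a reshape input
print(dp_dims_mapping(3, is_xshape=True)) # [-1, 0, -1]  its XShape output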
......@@ -110,6 +110,7 @@ class DistributedContext:
# self._tensor_id_to_tensor_node_ids = {}
self._is_initialized = False
#TODO: need a better way to remove the following flag
self._need_copy_dist_attr_to_graph = False
self._backup_pass_context_stack = []
self._backup_block_state_stack = []
......@@ -121,6 +122,9 @@ class DistributedContext:
# flag whether scale gradient with dp size
self._gradient_scale = True
# A flag indicating whether the parallelism in use is data parallel
self._data_parallel = False
@property
def serial_main_program(self):
return self._serial_main_program
......@@ -198,6 +202,14 @@ class DistributedContext:
def gradient_scale(self, gs):
self._gradient_scale = gs
@property
def data_parallel(self):
return self._data_parallel
@data_parallel.setter
def data_parallel(self, dp):
self._data_parallel = dp
def _backup_serial_info(self, mode):
self._backup_serial_main_program_stack.append(
self._serial_main_program.clone())
......@@ -335,7 +347,7 @@ class DistributedContext:
if dist:
self._restore_dist_info(dist_mode)
def initialize(self):
def initialize(self, with_graph=True):
if not self._is_initialized:
if not self._serial_main_program:
self._serial_main_program = self._original_serial_main_program
......@@ -366,13 +378,16 @@ class DistributedContext:
self._dist_ops_for_program)
self._tensors_ids = list(self._dist_tensors_for_program.keys())
self._ops_ids = list(self._dist_ops_for_program.keys())
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_main_program.desc))
self._init_dist_attr_for_graph()
self._is_initialized = True
self._need_copy_dist_attr_to_graph = False
if self._need_copy_dist_attr_to_graph:
if with_graph:
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_main_program.desc))
self._init_dist_attr_for_graph()
self._need_copy_dist_attr_to_graph = False
if self._need_copy_dist_attr_to_graph and with_graph:
self.copy_dist_attr_from_program_to_graph()
def add_process_mesh(self, process_mesh):
......@@ -522,6 +537,8 @@ class DistributedContext:
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
else:
default_ctx = self
# Copy the data parallel flag from the default context
self._data_parallel = default_ctx.data_parallel
for block in self._serial_main_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
......
......@@ -44,7 +44,7 @@ from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .process_group import get_all_process_groups, get_world_process_group
from .process_group import new_process_group, get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
......@@ -155,8 +155,10 @@ class Engine:
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# We build the world process group because data parallelism
# needs all ranks by default.
new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True
# self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
......
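For context, the world group registered here is what the fast completion later reads back through get_world_process_group().ranks in completion.py above; the TODO there is about making sure that group really holds all ranks. A hedged sketch of the hand-off, assuming a two-GPU launch; how new_process_group merges ranks into the world group is internal to process_group.py:

# Illustrative only; assumes a 2-GPU launch ("--gpus 0,1").
from paddle.distributed.auto_parallel.process_group import (
    new_process_group, get_world_process_group)

nranks = 2
new_process_group(list(range(nranks)))    # register ranks [0, 1]
ranks = get_world_process_group().ranks   # expected to now contain [0, 1]
# The fast completion then builds ProcessMesh(ranks) and shards only dim 0.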
......@@ -27,10 +27,14 @@ class Planner:
# NOTE: [HighOrderGrad]. There are grad ops in the forward phase, and they need
# the dependency of backward-forward ops in forward completion.
# TODO: The id mapping will be lost if we clone the original program.
default_ctx = get_default_distributed_context()
self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.initialize()
if not default_ctx.data_parallel:
# Use SSA graph for complex parallelism
self._dist_context.initialize(with_graph=True)
else:
# Use the program for data parallelism
self._dist_context.initialize(with_graph=False)
self._completer = Completer(self._dist_context)
......
......@@ -20,6 +20,11 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 80)
py_test_modules(test_engine_api_dp MODULES test_engine_api_dp ENVS
${dist_ENVS})
set_tests_properties(test_engine_api_dp
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
......
......@@ -90,7 +90,7 @@ class MLPLayer(nn.Layer):
def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
PP_MESH_0})(input)[0]
out = self.linear0(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
PP_MESH_1})(out)[0]
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import time
import tempfile
import copy
import os
import numpy as np
import subprocess
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
batch_size = 1
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10
paddle.seed(44)
class MyDataset(Dataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
return out
def train(fetch):
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
fetches=fetches)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
temp_dir.cleanup()
if __name__ == "__main__":
train(True)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestEngineAPI(unittest.TestCase):
def test_engine_api(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "engine_api_dp.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", "--log_dir", tmp_dir.name,
launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()