Commit f2226441, authored by caozhou, committed by GitHub

【Auto Parallel】Update Planner (#39201)

* update planner

* update unittest

* update dist matmul

* update auto converter
Parent 2b9bb8bb
@@ -165,6 +165,18 @@ def _is_auto_compatible_for_matmul(dist_op):
     if y_dims_mapping_len == 1:
         y_dims_mapping.insert(1, -1)
+    # NOTE: Partition is not supported if matmul op has trans.
+    if op_desc.type() == "matmul_v2":
+        if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
+            if x_dims_mapping[-2:] != [-1, -1] or y_dims_mapping[
+                    -2:] != [-1, -1]:
+                return False
+    elif op_desc.type() == "matmul":
+        if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
+            if x_dims_mapping[-2:] != [-1, -1] or y_dims_mapping[
+                    -2:] != [-1, -1]:
+                return False
+
     # Deal with dim > 2 and take care of broadcasting
     if out_dims_mapping_len > 2:
         broadcast_x_dims_mapping = []
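The added guard rejects sharded placements whenever the matmul carries a transpose flag. A minimal plain-Python sketch of the same rule (hypothetical helper with dict-based attributes, not the Paddle op_desc API) shows what gets rejected:

    def _trans_requires_replication(op_type, attrs, x_dims_mapping, y_dims_mapping):
        # attrs is a plain dict standing in for op_desc.attr(...)
        if op_type == "matmul_v2":
            has_trans = attrs.get("trans_x") or attrs.get("trans_y")
        elif op_type == "matmul":
            has_trans = attrs.get("transpose_X") or attrs.get("transpose_Y")
        else:
            has_trans = False
        if has_trans:
            # with a transpose, the two trailing axes of X and Y must stay replicated (-1)
            return x_dims_mapping[-2:] == [-1, -1] and y_dims_mapping[-2:] == [-1, -1]
        return True

    # [0, -1]: the row axis of X is sharded along mesh dim 0, so the placement is rejected.
    assert not _trans_requires_replication("matmul_v2", {"trans_y": True}, [0, -1], [-1, -1])
    assert _trans_requires_replication("matmul_v2", {"trans_y": True}, [-1, -1], [-1, -1])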
@@ -550,7 +562,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         # TODO infer logic comm presentation
         matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
-            Weight_var.name)[1]
+            Weight_var.name)[-1]
         assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -775,7 +787,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         # TODO infer logic comm presentation
         matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
-            Weight_var.name)[0]
+            Weight_var.name)[-2]
         assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -1064,7 +1076,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         # TODO infer logic comm presentation
         matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
-            Weight_var.name)[1]
+            Weight_var.name)[-1]
         assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -1283,7 +1295,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         # TODO infer logic comm presentation
         matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
-            Weight_var.name)[0]
+            Weight_var.name)[-2]
         assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
...
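The four hunks above make the same change: the sharded axis of the weight is now read from the end of its dims_mapping ([-1] for the column-parallel case, [-2] for the row-parallel case) rather than from the front. A small plain-Python illustration (toy lists, no Paddle API) of why the negative index is the safer choice:

    # For a 2-D weight both forms agree, but if the dims_mapping ever carries extra
    # leading entries (a hypothetical example here), [1]/[0] would read the wrong axis
    # while [-1]/[-2] always pick the true column/row dimension.
    w_dims_mapping_2d = [-1, 0]          # 2-D weight, column axis sharded on mesh dim 0
    w_dims_mapping_nd = [-1, -1, 0]      # same weight with an extra leading entry

    assert w_dims_mapping_2d[1] == w_dims_mapping_2d[-1] == 0        # both indices work for 2-D
    assert w_dims_mapping_nd[-1] == 0 and w_dims_mapping_nd[1] != 0  # only the negative index stays correct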
@@ -84,34 +84,22 @@ class PlanFilter:
     @staticmethod
     def check_dims_mapping_for_special_op(op, op_dist_attr, vars):
-        if op.type == "layer_norm":
-            bias_dims_mapping = op_dist_attr.get_input_dims_mapping(
-                op.input("Bias")[0])
-            scale_dims_mapping = op_dist_attr.get_input_dims_mapping(
-                op.input("Scale")[0])
-            x_dims_mapping = op_dist_attr.get_input_dims_mapping(
-                op.input("X")[0])
-            mean_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                op.output("Mean")[0])
-            variance_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                op.output("Variance")[0])
-            y_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                op.output("Y")[0])
-            if x_dims_mapping != y_dims_mapping:
-                return False
-            if scale_dims_mapping[0] != x_dims_mapping[-1]:
-                return False
-            if bias_dims_mapping[0] != y_dims_mapping[-1]:
-                return False
-            if mean_dims_mapping[0] != x_dims_mapping[0]:
-                return False
-            if variance_dims_mapping[0] != x_dims_mapping[0]:
-                return False
+        # NOTE: Those ops has some partition limits, and will be solved when corresponding dist op implemented in the future.
+        if op.type == "elementwise_add" or op.type == 'layer_norm' or op.type == "softmax_with_cross_entropy":
+            for name in op.input_arg_names:
+                for item in op_dist_attr.get_input_dims_mapping(name):
+                    if item != -1:
+                        return False
+            for name in op.output_arg_names:
+                for item in op_dist_attr.get_output_dims_mapping(name):
+                    if item != -1:
+                        return False
+        if op.type == "lookup_table_v2":
+            for name in op.input_arg_names:
+                if name == 'pos_embeddings':
+                    for item in op_dist_attr.get_input_dims_mapping(name):
+                        if item != -1:
+                            return False
         return True
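The rewritten filter simply requires fully replicated placements for the listed ops until dedicated dist op implementations land. A minimal stand-alone sketch of that predicate (plain Python, hypothetical helper, not the PlanFilter API):

    def _is_fully_replicated(dims_mappings):
        # every entry of every dims_mapping must be -1, i.e. no axis is sharded
        return all(d == -1 for mapping in dims_mappings for d in mapping)

    assert _is_fully_replicated([[-1, -1], [-1]])       # accepted by the filter
    assert not _is_fully_replicated([[0, -1], [-1]])    # any sharded axis is rejected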
...@@ -426,13 +414,14 @@ class MCMC(SearchAlgorithm): ...@@ -426,13 +414,14 @@ class MCMC(SearchAlgorithm):
var_name) == dims_mapping: var_name) == dims_mapping:
dist_context.set_op_dist_attr_for_program( dist_context.set_op_dist_attr_for_program(
search_op, op_dist_attr) search_op, op_dist_attr)
for name in search_op.output_arg_names:
tensor_dist_attr = TensorDistributedAttribute( tensor_dist_attr = TensorDistributedAttribute(
) )
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping( tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name) name)
dist_context.set_tensor_dist_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
vars[var_name], tensor_dist_attr) vars[name], tensor_dist_attr)
has_changed = True has_changed = True
break break
if has_changed: if has_changed:
......
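The loop introduced above records a tensor dist attr for every output of the matched op, not just the single candidate variable. A tiny mock (plain Python objects, assumed example output names, no Paddle API) of the resulting behavior:

    class _MockTensorDistAttr:
        def __init__(self):
            self.process_mesh = None
            self.dims_mapping = None

    # assumed example: an op with three outputs, as layer_norm has (Y, Mean, Variance)
    output_arg_names = ["Y", "Mean", "Variance"]
    output_dims_mappings = {"Y": [0, -1], "Mean": [0], "Variance": [0]}

    recorded = {}
    for name in output_arg_names:
        attr = _MockTensorDistAttr()
        attr.process_mesh = [0, 1]                      # placeholder process mesh
        attr.dims_mapping = output_dims_mappings[name]
        recorded[name] = attr                           # stands in for set_tensor_dist_attr_for_program

    # every output now carries a tensor dist attr, not only the one matched variable
    assert set(recorded) == set(output_arg_names)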
@@ -593,8 +593,10 @@ def load_parameter_into_program(param_dict, program):
         param_dict(dict): parameters' name and value.
         program(Program): the program to be updated
     """
-    _check_param_dict(param_dict)
+    assert isinstance(param_dict, dict)
     assert program and isinstance(program, paddle.fluid.framework.Program)
+    if not param_dict:
+        return
     program.set_state_dict(param_dict)
@@ -705,7 +707,6 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
         dist_param_dict(dict): parameters' value of current rank.
     """
     assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None."
-    assert _check_dist_attr(cur_dist_attr), "'pre_dist_attr' cannot be None."
     assert isinstance(dist_param_dict, dict), \
         "The type of 'dist_param_dict' should be 'dict', but got {}.".format(
             str(type(dist_param_dict)))
@@ -720,6 +721,9 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
                 "The value of 'dist_param_dict' is parameter's value of all ranks, "
                 "and its type should be 'list(numpy.ndarray)'.")
+    if cur_dist_attr is None:
+        return {}
+
     param_not_in_pre = []
     param_not_in_cur = []
     logging.info("Start to merge and slice parameters.")
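The converter changes add two early exits: load_parameter_into_program now silently skips an empty state dict, and merge_and_slice_parameter returns an empty dict when cur_dist_attr is None instead of asserting on it. A hedged plain-Python mirror of that control flow (mock program object and simplified checks, not the Paddle functions themselves):

    def load_params_sketch(param_dict, program):
        assert isinstance(param_dict, dict)
        if not param_dict:
            return                        # nothing to load: no-op instead of an assertion error
        program.set_state_dict(param_dict)

    def merge_and_slice_sketch(dist_param_dict, pre_dist_attr, cur_dist_attr):
        assert pre_dist_attr, "'pre_dist_attr' cannot be None."
        if cur_dist_attr is None:
            return {}                     # no target distribution: return an empty mapping
        return dict(dist_param_dict)      # placeholder for the real merge-and-slice work

    assert merge_and_slice_sketch({"w": [1, 2]}, pre_dist_attr={"w": {}}, cur_dist_attr=None) == {}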
@@ -1268,6 +1272,7 @@ def get_all_distributed_main_program(serial_program_info, dist_context,
         used_dist_context._dist_op_context = DistributedOperatorContext()
         _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
             rank_id, used_dist_context)
+        # print("dist_main_program: ", dist_main_program)
         all_dist_main_program.append(dist_main_program)
     return all_dist_main_program
...
@@ -5,4 +5,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
     py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
     set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
+    py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
+    set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
 endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.static as static
from paddle.distributed import fleet
import sys
import numpy as np
import paddle.distributed.auto_parallel as auto
from auto_parallel_relaunch_model import mlp_pretrain_forward
from auto_parallel_relaunch_model import batch_generator_creator
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
def get_gpt_model(train_program, start_program, place, batch_size, sequence_len,
vocab_size):
modeling.init_global()
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(
name="tokens", shape=[batch_size, sequence_len], dtype='int64')
position_ids = paddle.static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64')
loss_mask = paddle.static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float32')
data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
def gen_data():
np.random.seed(2021)
tokens = []
position_ids = []
attention_mask = []
labels = []
loss_mask = []
for _ in range(batch_size):
tokens.append(np.random.randint(vocab_size, size=sequence_len))
position_ids.append(np.arange(sequence_len))
attention_mask.append([np.tril(np.ones(sequence_len))])
labels.append(np.random.randint(vocab_size, size=sequence_len))
loss_mask.append(np.ones(sequence_len))
return tokens, position_ids, attention_mask, labels, loss_mask
return train_program, start_program, loss, gen_data
def train():
dist_strategy = fleet.DistributedStrategy()
# init parallel optimizer
dist_strategy.auto_search = True
fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
place = paddle.set_device("gpu")
gpus = [0, 1]
batch_size = 8
sequence_len = 512
vocab_size = 1000
train_program, start_program, loss, gen_data = get_gpt_model(
train_program, start_program, place, batch_size, sequence_len,
vocab_size)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
places = static.cuda_places()
exe = paddle.static.Executor(places[0])
exe.run(distributed_startup_program)
for step in range(10):
tokens, position_ids, attention_mask, labels, loss_mask = gen_data()
if loss.name in distributed_main_program.global_block().vars:
loss_print, = exe.run(distributed_main_program,
feed={
"tokens": tokens,
"position_ids": position_ids,
"attention_mask": attention_mask,
"labels": labels,
"loss_mask": loss_mask
},
fetch_list=[loss])
print("step: %s, loss: %f" % (step, loss_print[0]))
else:
exe.run(distributed_main_program,
feed={
"tokens": tokens,
"position_ids": position_ids,
"attention_mask": attention_mask,
"labels": labels,
"loss_mask": loss_mask
})
print("step: %s, loss: %s" % (step, "None"))
if __name__ == "__main__":
train()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import json
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestPlannerReLaunch(unittest.TestCase):
def test_relaunch_with_planner(self):
from test_auto_parallel_relaunch import cluster_json
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
launch_model_path = os.path.join(
file_dir, "auto_parallel_relaunch_with_gpt_planner.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--cluster_topo_path", cluster_json_path,
"--enable_auto_mapping", "True", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
rank_mapping_json_path = os.path.join(file_dir,
"auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()
@@ -34,6 +34,7 @@ paddle.enable_static()
 def init_global():
     global _global_parallel_strategy
+    _global_parallel_strategy = None
     global _global_process_mesh
     global PP_MESH_LIST
     global DPPP_MESH_LIST
...