Unverified commit 325fdf1d, authored by caozhou, committed by GitHub

[Auto Parallel] Update rule based tuner (#51908)

* add patterns

* update rule based tuner

* add forward sub program completion

* add unittest

* add bwd sub program completion
Parent commit: 13b8b5e0
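In outline, the workflow these changes introduce is the one exercised end to end by the new test_rule_based_tuner unit test added below. A minimal sketch of that call sequence, wrapped in an illustrative helper (run_rule_based_tuner is my own name); it assumes a serial train_program, start_program, loss, and optimizer have already been built (the test builds a small GPT model for this), and the trailing comments are my reading of the new method names rather than documented semantics:

from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import RuleBasedTuner

def run_rule_based_tuner(train_program, start_program, loss, optimizer):
    dist_context = DistributedContext(
        serial_main_prog=train_program,
        serial_startup_prog=start_program,
        serial_optimizer=optimizer,
        serial_loss=loss,
    )
    dist_context.initialize()
    tuner = RuleBasedTuner(dist_context)
    tuner.cluster_operators()            # group the serial ops into layers
    tuner.gen_full_program()             # build the full (forward + backward) program
    tuner.match_program(dist_context.serial_main_program)  # match against the registered _PATTERNS
    tuner.gen_fwd_sub_programs_by_clone()
    tuner.complete_sub_fwd_programs(ProcessMesh([0, 1]))   # complete forward sub-programs per mesh
    tuner.complete_sub_bwd_programs()                      # then the backward sub-programs
    return tuner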
@@ -64,6 +64,7 @@ class DistributedContext:
         fetch_vars={},
         cluster=None,
         strategy=None,
+        json_config=None,
     ):
         # Data members related to original programs (unchanged)
         self._original_serial_main_program = serial_main_prog
@@ -129,6 +130,8 @@ class DistributedContext:
         # A flag indicates whether the used parallelism is data parallel
         self._data_parallel = False

+        self._json_config = json_config
+
     @property
     def serial_main_program(self):
         return self._serial_main_program
@@ -181,6 +184,10 @@ class DistributedContext:
     def process_meshes(self):
         return self._process_meshes

+    @process_meshes.setter
+    def process_meshes(self, val):
+        self._process_meshes = val
+
     @property
     def pass_context(self):
         return self._pass_context
@@ -397,7 +404,7 @@ class DistributedContext:
         if dist:
             self._restore_dist_info(dist_mode)

-    def initialize(self, with_graph=True, with_cpp=False):
+    def initialize(self, with_graph=True, with_cpp=False, no_default=False):
         if not self._is_initialized:
             if not self._serial_main_program:
                 if self._original_serial_main_program:
@@ -418,7 +425,7 @@ class DistributedContext:
             if not self._serial_fetch_vars:
                 self._restore_serial_fetch_vars()

-            self._init_dist_attr_for_program()
+            self._init_dist_attr_for_program(no_default)
             # Backup the original distributed information for later restore
             self._original_dist_tensors_for_program = copy.deepcopy(
                 self._dist_tensors_for_program
......
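A note on the DistributedContext changes above: the constructor gains a json_config argument and initialize() gains a no_default flag. A minimal sketch of how they are passed (build_context and its arguments are hypothetical; json_config's expected schema is not shown in this change):

from paddle.distributed.auto_parallel.dist_context import DistributedContext

def build_context(main_prog, startup_prog, json_config=None):
    dist_context = DistributedContext(
        serial_main_prog=main_prog,
        serial_startup_prog=startup_prog,
        json_config=json_config,   # new in this change; schema not documented in the diff
    )
    # no_default is simply forwarded to _init_dist_attr_for_program (see the hunk above)
    dist_context.initialize(no_default=False)
    return dist_context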
@@ -174,7 +174,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                         varname
                     )
                     mesh_shape = process_mesh.shape
-                    batch_size_axis = var_dim_mapping[0]
                     parallel_axis = batch_size_axis
                     attrs = {"use_calc_stream": True}
                     var_names = [varname + "@GRAD"]
......
@@ -278,6 +278,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         for mapping in ids_dims_mapping[1:]:
             if is_dim_shard(mapping):
                 return False
+        if is_dim_shard(ids_dims_mapping[0]) and is_dim_shard(
+            w_dims_mapping[-2]
+        ):
+            if ids_dims_mapping[0] == w_dims_mapping[-2]:
+                return False
+
         return True

     def is_output_compatible(self, dist_op):
......
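The new DistributedEmbeddingImpl check above rejects the case where the ids' first dimension and the embedding table's row dimension are sharded along the same mesh axis. A small self-contained illustration, using a stand-in for is_dim_shard (a dimension counts as sharded when its mapping names a mesh axis rather than being -1) and hypothetical dims mappings:

def is_dim_shard(mapping):      # stand-in: -1 means replicated, >= 0 names a mesh axis
    return mapping != -1

ids_dims_mapping = [0, -1]      # ids batch dim sharded along mesh axis 0
w_dims_mapping = [0, -1]        # embedding rows sharded along the same mesh axis

incompatible = (
    is_dim_shard(ids_dims_mapping[0])
    and is_dim_shard(w_dims_mapping[-2])
    and ids_dims_mapping[0] == w_dims_mapping[-2]
)
print(incompatible)             # True -> the impl now reports the inputs as not compatible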
@@ -1507,7 +1507,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         processes = process_mesh.process_ids
         # col parallel: matmul + allreduce
         if backward_op.attr("trans_y"):
-            Y_var_dim_mapping.reverse()
+            Y_var_dim_mapping = list(reversed(Y_var_dim_mapping))
             assert Y_var_dim_mapping[0] < 0
             parallel_axis = Y_var_dim_mapping[1]
......
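The one-line DistributedMatmulV2Impl0 change above swaps an in-place list.reverse() for a reversed copy, presumably so the dims mapping held elsewhere (for example in the op's dist_attr) is not flipped as a side effect. A small illustration of the difference in plain Python, with hypothetical values:

shared = [-1, 0]                 # imagine this list is also referenced by dist_attr

local = list(reversed(shared))   # new behaviour: work on a copy
print(local, shared)             # [0, -1] [-1, 0]

shared.reverse()                 # old behaviour: mutates the shared list itself
print(shared)                    # [0, -1]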
@@ -12,10 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from paddle.distributed.fleet.meta_optimizers.common import OpRole
+
+from ..cost import (
+    _g_op_cost_factory,
+    build_comp_costs_from_descs,
+    build_comp_desc_from_dist_op,
+    build_dp_costs,
+)
 from ..utils import compute_compatible_and_update_dim_mapping
 from .common import (
     DistributedOperatorImpl,
     DistributedOperatorImplContainer,
+    is_parameter_related,
     register_distributed_operator_impl,
     register_distributed_operator_impl_container,
 )
@@ -42,6 +51,84 @@ class DistributedScaleImpl(DistributedOperatorImpl):
     def is_input_compatible(self, dist_op):
         return True

+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        """Calculate the cost by the op role."""
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(
+            dist_op=dist_op, dist_context=ctx
+        )
+        processes = dist_op.dist_attr.process_mesh.process_ids
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(
+            _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
+        )
+        res_cost = [cost_mapping]
+
+        return res_cost
+
+    def calc_bwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        res = []
+        desc_mapping = build_comp_desc_from_dist_op(
+            dist_op=dist_op, dist_context=ctx
+        )
+        dist_attr = dist_op.dist_attr
+        process_mesh = dist_attr.process_mesh
+        processes = process_mesh.process_ids
+        backward_op = dist_op.serial_op
+        op_type = backward_op.type
+        cost_mapping = build_comp_costs_from_descs(
+            _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
+        )
+        res.append(cost_mapping)
+
+        main_block = backward_op.block
+        need_gradient_allreduce = False
+        for input_name in backward_op.desc.input_names():
+            for varname in backward_op.desc.input(input_name):
+                if "@GRAD" not in varname and not is_parameter_related(
+                    varname, main_block
+                ):
+                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
+                    mesh_shape = process_mesh.shape
+                    batch_size_axis = var_dim_mapping[0]
+                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+                        need_gradient_allreduce = True
+                        break
+
+        if need_gradient_allreduce:
+            for input_name in backward_op.desc.input_names():
+                for varname in backward_op.desc.input(input_name):
+                    if "@GRAD" not in varname and is_parameter_related(
+                        varname, main_block
+                    ):
+                        var_dim_mapping = dist_attr.get_input_dims_mapping(
+                            varname
+                        )
+                        mesh_shape = process_mesh.shape
+                        parallel_axis = batch_size_axis
+                        attrs = {"use_calc_stream": True}
+                        var_names = [varname + "@GRAD"]
+                        build_dp_costs(
+                            res,
+                            dist_op,
+                            ctx,
+                            var_names,
+                            attrs,
+                            parallel_axis,
+                            cluster,
+                        )
+        return res
+
     def is_output_compatible(self, dist_op):
         return True
......
@@ -127,6 +127,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
     py_test_modules(test_dist_saver MODULES test_dist_saver)
     py_test_modules(test_engine_save_load MODULES test_engine_save_load)
+    py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner)
     # End of unittests WITH single card WITHOUT timeout
 endif()
@@ -178,6 +178,7 @@ class TestDistOpCost(unittest.TestCase):
                     [None, None],
                 )
                 tmp_out = paddle.matmul(out1, tmp_param)
+                tmp_out = paddle.scale(tmp_out, 0.5)
                 out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]
                 out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]
@@ -286,6 +287,7 @@ class TestDistOpCost(unittest.TestCase):
                 )
                 tmp_out = paddle.matmul(out1, tmp_param)
+                tmp_out = paddle.scale(tmp_out, 0.5)
                 out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]
                 out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]
......
@@ -119,9 +119,10 @@ class TestGroupOperators(unittest.TestCase):
             RuleBasedTuner,
         )

-        dist_context = DistributedContext()
+        dist_context = DistributedContext(train_program)
+        dist_context.initialize()
         tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
+        layers = tuner.cluster_operators()
         op_types = []
         for layer in layers:
             tmp = []
......
@@ -112,18 +112,11 @@ class TestGroupOperatorsAndPatterns(unittest.TestCase):
             sequence_len,
             vocab_size,
         )
-        from paddle.distributed.auto_parallel.dist_context import (
-            DistributedContext,
-        )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             _PATTERNS,
             GraphUtil,
-            RuleBasedTuner,
         )

-        dist_context = DistributedContext()
-        tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
         graph = GraphUtil.convert_to_graph(train_program.global_block())
         print("graph: ", graph)
         print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import numpy as np

import paddle
import paddle.static as static

sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
    GPTForPretraining,
    GPTModel,
    GPTPretrainingCriterion,
)


def get_gpt_model(
    train_program, start_program, place, batch_size, sequence_len, vocab_size
):
    with static.program_guard(train_program, start_program):
        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = paddle.static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        attention_mask = paddle.static.data(
            name="attention_mask",
            shape=[batch_size, 1, sequence_len, sequence_len],
            dtype='float32',
        )
        labels = paddle.static.data(
            name="labels", shape=[batch_size, sequence_len], dtype='int64'
        )
        loss_mask = paddle.static.data(
            name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
        )

        gpt = GPTModel(
            vocab_size=1000,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=8,
            intermediate_size=256,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            pad_token_id=0,
            eos_token_id=7,
            bos_token_id=0,
            eol_token_id=3,
        )

        model = GPTForPretraining(
            gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
        )
        preds = model(tokens, position_ids, attention_mask)
        criterion = GPTPretrainingCriterion()
        loss = criterion(preds, labels, loss_mask)

    def gen_data():
        np.random.seed(2021)
        tokens = []
        position_ids = []
        attention_mask = []
        labels = []
        loss_mask = []
        for _ in range(batch_size):
            tokens.append(np.random.randint(vocab_size, size=sequence_len))
            position_ids.append(np.arange(sequence_len))
            attention_mask.append([np.tril(np.ones(sequence_len))])
            labels.append(np.random.randint(vocab_size, size=sequence_len))
            loss_mask.append(np.ones(sequence_len))

        return tokens, position_ids, attention_mask, labels, loss_mask

    return train_program, start_program, loss, gen_data


class TestRuleBasedTuner(unittest.TestCase):
    def test_gpt(self):
        modeling.init_global()
        train_program = static.Program()
        start_program = static.Program()
        place = paddle.set_device("gpu")
        batch_size = 8
        sequence_len = 512
        vocab_size = 1000
        train_program, start_program, loss, gen_data = get_gpt_model(
            train_program,
            start_program,
            place,
            batch_size,
            sequence_len,
            vocab_size,
        )
        from paddle.distributed.auto_parallel.dist_context import (
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            RuleBasedTuner,
        )

        clip = paddle.nn.ClipGradByGlobalNorm(0.2)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        dist_context = DistributedContext(
            serial_main_prog=train_program,
            serial_startup_prog=start_program,
            serial_optimizer=opt,
            serial_loss=loss,
        )
        dist_context.initialize()
        tuner = RuleBasedTuner(dist_context)
        tuner.cluster_operators()
        tuner.gen_full_program()
        tuner.match_program(tuner._dist_context.serial_main_program)
        process_mesh = ProcessMesh([0, 1])
        tuner.gen_fwd_sub_programs_by_clone()
        tuner.complete_sub_fwd_programs(process_mesh)
        tuner.complete_sub_bwd_programs()


if __name__ == "__main__":
    unittest.main()
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy

 _FLOPS_COMPUTE_FUNC_MAP = {}

@@ -244,8 +245,12 @@ def _matmul_flops(input_shapes, attrs):
     equation: flops = 2 * numel(output) * dim_n
     """
-    x_shape = input_shapes.get("X", input_shapes.get("x", [[0]]))[0]
-    y_shape = input_shapes.get("Y", input_shapes.get("y", [[0]]))[0]
+    x_shape = copy.deepcopy(
+        input_shapes.get("X", input_shapes.get("x", [[0]]))[0]
+    )
+    y_shape = copy.deepcopy(
+        input_shapes.get("Y", input_shapes.get("y", [[0]]))[0]
+    )
     if attrs.get('transpose_X') or attrs.get('transpose_x'):
         x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
@@ -276,11 +281,11 @@ def _matmul_v2_flops(input_shapes, attrs):
     shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m]
     equation: flops = 2 * numel(outputs) * dim_n
     """
-    x_shape = input_shapes.get('X')[0]
+    x_shape = copy.deepcopy(input_shapes.get('X')[0])
-    y_shape = input_shapes.get('Y')[0]
+    y_shape = copy.deepcopy(input_shapes.get('Y')[0])
-    if attrs.get('trans_x') is not None:
+    if attrs.get('trans_x'):
         x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
-    if attrs.get('trans_y') is not None:
+    if attrs.get('trans_y'):
         y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
     dim_x = len(x_shape)
     dim_y = len(y_shape)
......
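Two things change in the matmul flops helpers above: the input shape lists are deep-copied before the transpose swap, so a flops query no longer mutates the caller's shape lists, and trans_x/trans_y are now tested for truthiness rather than mere presence (previously trans_x=False still triggered a swap). A minimal sketch of the counting rule from the docstring, flops = 2 * numel(output) * dim_n, reading dim_n as the contracted dimension and restricting to plain 2-D matmul (matmul_flops_2d is an illustrative helper, not the library function):

import copy

def matmul_flops_2d(x_shape, y_shape, trans_x=False, trans_y=False):
    x, y = copy.deepcopy(x_shape), copy.deepcopy(y_shape)   # keep callers' lists intact
    if trans_x:                     # only swap when the flag is actually true
        x[-1], x[-2] = x[-2], x[-1]
    if trans_y:
        y[-1], y[-2] = y[-2], y[-1]
    out_numel = x[-2] * y[-1]       # output shape is [M, N]
    return 2 * out_numel * x[-1]    # 2 * M * N * K

print(matmul_flops_2d([8, 16], [16, 32]))   # 2 * (8 * 32) * 16 = 8192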