未验证 提交 46212b80 编写于 作者: C caozhou 提交者: GitHub

add update func of auto search (#37867)

* add update func of auto search

* update unitest
上级 6b48dfe9
...@@ -1036,3 +1036,139 @@ def set_grad_var_shape(program, dist_context): ...@@ -1036,3 +1036,139 @@ def set_grad_var_shape(program, dist_context):
if list(grad_var.shape) != ref_shape: if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape) grad_var.desc.set_shape(ref_shape)
def update_op_dims_mapping_by_default_dist_impl(dist_op):
changed = False
op_dist_attr = dist_op.dist_attr
op_desc = dist_op.serial_op.desc
# The following statement will be replaced by a more elegent way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
else:
assert dims_mapping[0] == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
.format(op_desc.type(), mapping)
if len(dims_mapping) > 2:
for idx, mapping in enumerate(dims_mapping[2:]):
assert mapping == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
changed = False
op_dist_attr = dist_op.dist_attr
op_desc = dist_op.serial_op.desc
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
...@@ -92,6 +92,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) ...@@ -92,6 +92,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers) list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
...@@ -257,6 +258,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -257,6 +258,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy) LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
...@@ -643,6 +645,7 @@ if(WITH_DISTRIBUTE) ...@@ -643,6 +645,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
# import os
# import copy
# import json
import unittest
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
# from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
# from paddle.distributed.auto_parallel.cluster import Cluster
# from paddle.distributed.auto_parallel.utils import SerialProgramInfo
# from paddle.distributed.auto_parallel.searcher import Checker, Enumerater
from paddle.distributed.auto_parallel.dist_context import DistributedContext
# from paddle.distributed.auto_parallel.utils import get_all_distributed_main_program
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_default_dist_impl
from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_elementwise_like_dist_impl
paddle.enable_static()
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = paddle.unsqueeze(out, axis=0)
out = paddle.reshape(out, [4, 1024])
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = loss_func(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def set_default_dist_attr(program, dist_context, process_mesh):
ops = program.global_block().ops
vars = program.global_block().vars
for op in ops:
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = process_mesh
for var_name in op.input_arg_names:
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
dist_context.set_tensor_dist_attr_for_program(vars[var_name],
tensor_dist_attr)
op_dist_attr.set_input_dims_mapping(var_name,
tensor_dist_attr.dims_mapping)
for var_name in op.output_arg_names:
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
dist_context.set_tensor_dist_attr_for_program(vars[var_name],
tensor_dist_attr)
op_dist_attr.set_output_dims_mapping(var_name,
tensor_dist_attr.dims_mapping)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
dist_context.add_process_mesh(process_mesh)
class TestMLPSearcher(unittest.TestCase):
def test_update(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
_, train_program, startup_program = mlp_forward(train_program,
startup_program)
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_context = DistributedContext()
set_default_dist_attr(train_program, dist_context, global_process_mesh)
ops = train_program.global_block().ops
vars = train_program.global_block().vars
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
for op in ops:
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
if dist_op_impl_container is None:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_op = DistributedOperator(op, op_dist_attr)
if is_elementwise_like_op(op.type):
changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_op)
self.assertFalse(changed)
dist_op.dist_attr.set_output_dims_mapping(
op.output_arg_names[0], [0] + [
-1
for i in range(
1, len(vars[op.output_arg_names[0]].shape))
])
try:
changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_op)
except:
continue
self.assertTrue(changed)
else:
changed = update_op_dims_mapping_by_default_dist_impl(
dist_op)
self.assertFalse(changed)
dist_op.dist_attr.set_output_dims_mapping(
op.output_arg_names[0], [0] + [
-1
for i in range(
1, len(vars[op.output_arg_names[0]].shape))
])
try:
changed = update_op_dims_mapping_by_default_dist_impl(
dist_op)
except:
continue
self.assertTrue(changed)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册