diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index a3505eae876ef2aae662aa82dbd8f4bac97d1aee..92918c834a5da7cafecb237130b7e5bee3990220 100755
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1036,3 +1036,139 @@ def set_grad_var_shape(program, dist_context):
 
             if list(grad_var.shape) != ref_shape:
                 grad_var.desc.set_shape(ref_shape)
+
+
+def update_op_dims_mapping_by_default_dist_impl(dist_op):
+    changed = False
+    op_dist_attr = dist_op.dist_attr
+    op_desc = dist_op.serial_op.desc
+    # The following statement will be replaced by a more elegant way
+    if op_desc.type() == "shape" or op_desc.type() == "slice":
+        return False
+    output_names = op_desc.output_names()
+    xshape_arg_names = []
+    if "XShape" in output_names:
+        xshape_arg_names = op_desc.output("XShape")
+    batch_dim_mappings = []
+    for arg_name in op_desc.input_arg_names():
+        serial_tensor = dist_op.get_serial_input(arg_name)
+        if serial_tensor.is_parameter:
+            continue
+        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
+        if len(dims_mapping) > 1:
+            for idx, mapping in enumerate(dims_mapping[1:]):
+                assert mapping == -1, \
+                    "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
+                    .format(op_desc.type(), idx + 1, mapping)
+        batch_dim_mappings.append(dims_mapping[0])
+    for arg_name in op_desc.output_arg_names():
+        serial_tensor = dist_op.get_serial_output(arg_name)
+        if serial_tensor.is_parameter:
+            continue
+        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+        if arg_name not in xshape_arg_names:
+            if len(dims_mapping) > 1:
+                for idx, mapping in enumerate(dims_mapping[1:]):
+                    assert mapping == -1, \
+                        "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
+                        .format(op_desc.type(), idx + 1, mapping)
+            batch_dim_mappings.append(dims_mapping[0])
+        else:
+            assert dims_mapping[0] == -1, \
+                "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
+                .format(op_desc.type(), dims_mapping[0])
+            if len(dims_mapping) > 2:
+                for idx, mapping in enumerate(dims_mapping[2:]):
+                    assert mapping == -1, \
+                        "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
+                        .format(op_desc.type(), idx + 2, mapping)
+            batch_dim_mappings.append(dims_mapping[1])
+
+    compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
+    assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
+    for arg_name in op_desc.input_arg_names():
+        serial_tensor = dist_op.get_serial_input(arg_name)
+        if serial_tensor.is_parameter:
+            continue
+        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
+        if compatible_dim_mapping != dims_mapping[0]:
+            dims_mapping[0] = compatible_dim_mapping
+            changed = True
+    for arg_name in op_desc.output_arg_names():
+        serial_tensor = dist_op.get_serial_output(arg_name)
+        if serial_tensor.is_parameter:
+            continue
+        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+        if arg_name not in xshape_arg_names:
+            if compatible_dim_mapping != dims_mapping[0]:
+                dims_mapping[0] = compatible_dim_mapping
+                changed = True
+        else:
+            if compatible_dim_mapping != dims_mapping[1]:
+                dims_mapping[1] = compatible_dim_mapping
+                changed = True
+
+    return changed
+
+
+def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
+    changed = False
+    op_dist_attr = dist_op.dist_attr
+    op_desc = dist_op.serial_op.desc
+    input_arg_names = op_desc.input_arg_names()
+    input_dims_mapping_dict = {}
+    input_dims_mapping_lens = {}
+    max_dims_mapping_len = -1
+    for arg_name in input_arg_names:
+        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
+        if max_dims_mapping_len < len(dims_mapping):
+            max_dims_mapping_len = len(dims_mapping)
+        input_dims_mapping_dict[arg_name] = dims_mapping
+        input_dims_mapping_lens[arg_name] = len(dims_mapping)
+
+    dims_mapping_list = []
+    for arg_name in input_arg_names:
+        if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
+            new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
+            for i in range(input_dims_mapping_lens[arg_name]):
+                new_idx = (max_dims_mapping_len -
+                           input_dims_mapping_lens[arg_name]) + i
+                new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
+            dims_mapping_list.append(new_dims_mapping)
+        else:
+            dims_mapping_list.append(input_dims_mapping_dict[arg_name])
+    output_arg_names = op_desc.output_arg_names()
+    for arg_name in output_arg_names:
+        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+        assert len(dims_mapping) == max_dims_mapping_len
+        dims_mapping_list.append(dims_mapping)
+
+    compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
+    assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
+
+    for arg_name in input_arg_names:
+        if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
+            new_dims_mapping = [
+                -1 for _ in range(input_dims_mapping_lens[arg_name])
+            ]
+            for i in range(input_dims_mapping_lens[arg_name]):
+                new_idx = (max_dims_mapping_len -
+                           input_dims_mapping_lens[arg_name]) + i
+                new_dims_mapping[i] = compatible_dims_mapping[new_idx]
+            if new_dims_mapping != input_dims_mapping_dict[arg_name]:
+                op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
+                changed = True
+        else:
+            if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
+                op_dist_attr.set_input_dims_mapping(arg_name,
+                                                    compatible_dims_mapping)
+                changed = True
+
+    for arg_name in output_arg_names:
+        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+        if compatible_dims_mapping != dims_mapping:
+            op_dist_attr.set_output_dims_mapping(arg_name,
+                                                 compatible_dims_mapping)
+            changed = True
+
+    return changed
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 4162f697d27eacbd427bd6fb987e87a8437e86d3..5fdcd6d0a9d385cf4cd3238fda551a4da9beffdb 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -92,6 +92,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
+list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
@@ -257,6 +258,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
+    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
@@ -643,6 +645,7 @@ if(WITH_DISTRIBUTE)
     py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
+    py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..665a16c862c8481fd50ee04ceaf8e069a60bfbbf
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+# import os
+# import copy
+# import json
+import unittest
+
+import paddle
+import paddle.nn as nn
+import paddle.static as static
+import paddle.nn.functional as F
+import paddle.utils as utils
+# from paddle.distributed import fleet
+import paddle.distributed.auto_parallel as auto
+# from paddle.distributed.auto_parallel.cluster import Cluster
+# from paddle.distributed.auto_parallel.utils import SerialProgramInfo
+# from paddle.distributed.auto_parallel.searcher import Checker, Enumerater
+from paddle.distributed.auto_parallel.dist_context import DistributedContext
+# from paddle.distributed.auto_parallel.utils import get_all_distributed_main_program
+from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
+from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
+from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_default_dist_impl
+from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_elementwise_like_dist_impl
+
+paddle.enable_static()
+
+
+class MLPLayer(nn.Layer):
+    def __init__(self,
+                 hidden_size=1024,
+                 intermediate_size=4 * 1024,
+                 initializer_range=0.02):
+        super(MLPLayer, self).__init__()
+        d_model = hidden_size
+        dim_feedforward = intermediate_size
+        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
+            mean=0.0, std=initializer_range))
+        bias_attr = None
+
+        self.linear0 = nn.Linear(
+            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
+        self.linear1 = nn.Linear(
+            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
+        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
+
+    def forward(self, input):
+        out = self.norm(input)
+        out = self.linear0(out)
+        out = F.gelu(out, approximate=True)
+        out = self.linear1(out)
+        out = paddle.unsqueeze(out, axis=0)
+        out = paddle.reshape(out, [4, 1024])
+        return out
+
+
+def mlp_forward(train_program, start_program):
+    with static.program_guard(train_program,
+                              start_program), utils.unique_name.guard():
+        batch_size = 4
+        hidden_size = 1024
+        sequence_len = 512
+        input = static.data(
+            name="input", shape=[batch_size, hidden_size], dtype='float32')
+        label = static.data(
+            name="label", shape=[batch_size, 1], dtype='float32')
+        loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
+        mlp = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            initializer_range=0.02)
+
+        predict = mlp(input)
+        error_cost = loss_func(predict, label)
+        loss = paddle.mean(error_cost)
+
+    return loss, train_program, start_program
+
+
+def set_default_dist_attr(program, dist_context, process_mesh):
+    ops = program.global_block().ops
+    vars = program.global_block().vars
+    for op in ops:
+        op_dist_attr = OperatorDistributedAttribute()
+        op_dist_attr.process_mesh = process_mesh
+        for var_name in op.input_arg_names:
+            tensor_dist_attr = TensorDistributedAttribute()
+            tensor_dist_attr.process_mesh = process_mesh
+            tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
+            dist_context.set_tensor_dist_attr_for_program(vars[var_name],
+                                                          tensor_dist_attr)
+            op_dist_attr.set_input_dims_mapping(var_name,
+                                                tensor_dist_attr.dims_mapping)
+
+        for var_name in op.output_arg_names:
+            tensor_dist_attr = TensorDistributedAttribute()
+            tensor_dist_attr.process_mesh = process_mesh
+            tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
+            dist_context.set_tensor_dist_attr_for_program(vars[var_name],
+                                                          tensor_dist_attr)
+            op_dist_attr.set_output_dims_mapping(var_name,
+                                                 tensor_dist_attr.dims_mapping)
+        dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
+
+    dist_context.add_process_mesh(process_mesh)
+
+
+class TestMLPSearcher(unittest.TestCase):
+    def test_update(self):
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        _, train_program, startup_program = mlp_forward(train_program,
+                                                        startup_program)
+        global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        dist_context = DistributedContext()
+        set_default_dist_attr(train_program, dist_context, global_process_mesh)
+        ops = train_program.global_block().ops
+        vars = train_program.global_block().vars
+        from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
+        from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
+        from paddle.distributed.auto_parallel.dist_op import DistributedOperator
+
+        for op in ops:
+            dist_op_impl_container = get_distributed_operator_impl_container(
+                op.type)
+            if dist_op_impl_container is None:
+                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
+                dist_op = DistributedOperator(op, op_dist_attr)
+                if is_elementwise_like_op(op.type):
+                    changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
+                        dist_op)
+                    self.assertFalse(changed)
+
+                    dist_op.dist_attr.set_output_dims_mapping(
+                        op.output_arg_names[0], [0] + [
+                            -1
+                            for i in range(
+                                1, len(vars[op.output_arg_names[0]].shape))
+                        ])
+                    try:
+                        changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
+                            dist_op)
+                    except:
+                        continue
+                    self.assertTrue(changed)
+                else:
+                    changed = update_op_dims_mapping_by_default_dist_impl(
+                        dist_op)
+                    self.assertFalse(changed)
+
+                    dist_op.dist_attr.set_output_dims_mapping(
+                        op.output_arg_names[0], [0] + [
+                            -1
+                            for i in range(
+                                1, len(vars[op.output_arg_names[0]].shape))
+                        ])
+                    try:
+                        changed = update_op_dims_mapping_by_default_dist_impl(
+                            dist_op)
+                    except:
+                        continue
+                    self.assertTrue(changed)
+
+
+if __name__ == "__main__":
+    unittest.main()
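
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the broadcasting-style alignment that update_op_dims_mapping_by_elementwise_like_dist_impl performs before calling compute_compatible_dims_mapping. The helper compatible_dims_mapping here is hypothetical; it only mirrors the assumed compatibility rule (at each position every mapping must be -1 or agree on a single mesh dimension) and does not import Paddle.

# Standalone sketch of the dims-mapping alignment used for elementwise-like ops.
# Assumption: compute_compatible_dims_mapping returns None on conflicting shards
# and otherwise keeps the unique non -1 mesh dimension per position.


def compatible_dims_mapping(dims_mapping_list):
    # All mappings are expected to have the same (already aligned) length.
    assert all(
        len(d) == len(dims_mapping_list[0]) for d in dims_mapping_list)
    result = []
    for position in zip(*dims_mapping_list):
        non_replicated = {m for m in position if m != -1}
        if len(non_replicated) > 1:
            return None  # conflicting shards, no compatible mapping
        result.append(non_replicated.pop() if non_replicated else -1)
    return result


def right_align(dims_mapping, target_len):
    # Shorter (broadcast) inputs are padded with -1 on the left so that
    # trailing dimensions line up, as in NumPy-style broadcasting.
    return [-1] * (target_len - len(dims_mapping)) + dims_mapping


if __name__ == "__main__":
    x = [0, -1, -1]  # batch dim sharded on mesh dim 0, rest replicated
    bias = [-1]      # lower-rank broadcast input
    aligned = [x, right_align(bias, len(x))]
    print(compatible_dims_mapping(aligned))  # prints [0, -1, -1]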