# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.cost_model import estimate_cost
from paddle.distributed.auto_parallel.static.dist_context import (
    DistributedContext,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
    AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto
from paddle.fluid import core

paddle.enable_static()

_global_parallel_strategy = "dp_mp_pp"
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
NUM_RANKS = 8
STAGE_0_CNT = 5
STAGE_1_CNT = 10
pp_cfg = [[0, 1, 4, 5], [2, 3, 6, 7]]

device = "gpu" if core.is_compiled_with_cuda() else "cpu"


class MLPLayer(nn.Layer):
    def __init__(
        self,
        hidden_size=256,
        intermediate_size=4 * 256,
        initializer_range=0.02,
        is_distributed=True,
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.is_distributed = is_distributed

    def forward(self, input):
        if self.is_distributed:
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


def get_single_node_data():
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()

    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program, is_distributed=False
    )

    cost_model = core.CostModel()
    cost_data = cost_model.profile_measure(
        train_program, startup_program, device, ["time"]
    )

    op_name2cost = [{}, {}]
    for idx, op in enumerate(train_program.blocks[0].ops):
        if idx <= STAGE_0_CNT:
            op_name2cost[0][op.type] = cost_data.get_op_time_ms(idx)
        elif idx <= STAGE_1_CNT:
            op_name2cost[1][op.type] = cost_data.get_op_time_ms(idx)
    return op_name2cost


def mlp_forward(train_program, start_program, is_distributed=True):
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 256
        sequence_len = 128
        if is_distributed:
            input = static.data(
                name="input", shape=[batch_size, hidden_size], dtype='float32'
            )
            label = static.data(
                name="label", shape=[batch_size, 1], dtype='float32'
            )
        else:
            input = paddle.ones(
                name="input", shape=[batch_size, hidden_size], dtype='float32'
            )
            label = paddle.ones(
                name="label", shape=[batch_size, 1], dtype='float32'
            )

        if is_distributed:
            auto.shard_tensor(input, PP_MESH_0, ["x", None])
            auto.shard_tensor(label, PP_MESH_1, ["x", None])

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            initializer_range=0.02,
            is_distributed=is_distributed,
        )

        predict = mlp(input)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program
    )

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.optimizer.Adam()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program
    )
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    )

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    (
        auto_parallel_main_prog,
        auto_parallel_startup_prog,
        dist_params_grads,
    ) = partitioner.partition(
        complete_train_program, startup_program, params_grads
    )

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
    )

    return (
        auto_parallel_main_prog,
        auto_parallel_startup_prog,
        dist_params_grads,
    )


def check_runtime_estimation(cost):
    return cost.runtime > 0


def check_memory_estimation(cost):
    for i in range(NUM_RANKS):
        if cost.static_mem[i] <= 0 or cost.peak_mem[i] <= 0:
            return False
        if cost.static_mem[i] > cost.peak_mem[i]:
            return False
    return True


def check_empty_program_runtime(cost):
    return cost.runtime == 0


def check_empty_program_memory(cost):
    for mem in cost.peak_mem:
        if mem > 1:
            return False
    for mem in cost.static_mem:
        if mem > 1:
            return False
    return True


class TestCostModel(unittest.TestCase):
    def test_empty_program_cost_model(self):
        empty_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        standalone_cost_data = [{}]
        empty_pp_cfg = None
        cluster = None
        cost = estimate_cost(
            [empty_program],
            cluster=cluster,
            pipeline_config=empty_pp_cfg,
            standalone_cost_data=standalone_cost_data,
            batch_size=1,
        )

        self.assertTrue(check_empty_program_runtime(cost))
        self.assertTrue(check_empty_program_memory(cost))

    def test_auto_parallel_cost_model(self):
        standalone_cost_data = get_single_node_data()
        dist_program = []
        for rank_id in range(NUM_RANKS):
            train_program = paddle.static.Program()
            startup_program = paddle.static.Program()
            dist_context = DistributedContext()
            (
                distributed_program,
                dist_startup_prog,
                dist_params_grads,
            ) = get_dist_prog(
                train_program, startup_program, dist_context, rank_id
            )
            resharder = Resharder(
                distributed_program,
                dist_startup_prog,
                rank_id,
                dist_context,
                dist_params_grads,
            )
            resharder.reshard()
            dist_program.append(distributed_program)
        cluster = None
        cost = estimate_cost(
            dist_program,
            cluster=cluster,
            pipeline_config=pp_cfg,
            standalone_cost_data=standalone_cost_data,
            batch_size=4,
        )
        self.assertTrue(check_runtime_estimation(cost))
        self.assertTrue(check_memory_estimation(cost))


if __name__ == "__main__":
    unittest.main()