From e379455a1ea6f5264dcb8326dadebc381540439b Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Tue, 12 Jul 2022 16:32:13 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Auto=20Parallel=E3=80=91update=20base?= =?UTF-8?q?=20cost=20(#44095)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update base cost * update unittest of cost model * add unittest --- .../auto_parallel/cost/base_cost.py | 365 +++++++++++++++--- .../unittests/auto_parallel/CMakeLists.txt | 1 + .../unittests/auto_parallel/test_base_cost.py | 234 +++++++++++ .../unittests/auto_parallel/test_cluster.py | 4 + .../unittests/auto_parallel/test_comm_cost.py | 4 + .../auto_parallel/test_new_cost_model.py | 28 +- 6 files changed, 586 insertions(+), 50 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py diff --git a/python/paddle/distributed/auto_parallel/cost/base_cost.py b/python/paddle/distributed/auto_parallel/cost/base_cost.py index 4455d6f6648..deac76e45a8 100644 --- a/python/paddle/distributed/auto_parallel/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/base_cost.py @@ -17,8 +17,12 @@ from functools import reduce import paddle -from ..cluster import LinkType +from ..utils import _get_comm_group, _get_corresponding_rank from ..process_group import get_process_group +from ..cluster import LinkType +from ..dist_tensor import DistributedTensor +from ..utils import _get_idx_in_axis +from ..dist_tensor import DistributedTensor COMM_OP_TYPE = [ "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum", @@ -28,33 +32,22 @@ NON_COMP_TYPE = ["while"] + COMM_OP_TYPE _g_op_cost_factory = {} -def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None): - desc = {} - desc["op"] = op_type - desc["group_ranks"] = group_ranks - desc["inputs"] = {"X": [(dtype, shape)]} - if attrs is not None: - desc["attrs"] = attrs - return desc - +def build_comp_desc_from_op(op): + """Build the description of computation op.""" + # NOTE: The desc is for serial op. + from ..reshard import get_var_with_recursion -def _parse_op_to_desc(op, dist_context=None): desc = {} - desc["op"] = op.type + # The desc of concat op is {"op": "concat", "inputs": {"X": [(paddle.float32, [20, 20]), (paddle.float32, [20, 20])]}, "outputs": {"Out": [(paddle.float32, [20, 40])], "attrs": {"axis": -1}}} vars = op.block.vars + desc["op"] = op.type input_desc = OrderedDict() for input_name in op.input_names: var_name_list = op.input(input_name) var_desc = [] for var_name in var_name_list: - var = vars[var_name] - shape = None - if dist_context is not None: - dist_tensor = dist_context.get_dist_tensor_for_program(var) - shape = dist_tensor.local_sizes() - else: - shape = var.shape - assert shape is not None + var = get_var_with_recursion(var_name, op.block, op.block.program) + shape = var.shape var_desc.append((var.dtype, shape)) input_desc[input_name] = var_desc desc["inputs"] = input_desc @@ -64,14 +57,8 @@ def _parse_op_to_desc(op, dist_context=None): var_name_list = op.output(out_name) var_desc = [] for var_name in var_name_list: - var = vars[var_name] - shape = None - if dist_context is not None: - dist_tensor = dist_context.get_dist_tensor_for_program(var) - shape = dist_tensor.local_sizes() - else: - shape = var.shape - assert shape is not None + var = get_var_with_recursion(var_name, op.block, op.block.program) + shape = var.shape var_desc.append((var.dtype, shape)) output_desc[out_name] = var_desc desc["outputs"] = output_desc @@ -82,19 +69,101 @@ def _parse_op_to_desc(op, dist_context=None): return desc -def parse_to_desc(op=None, dist_op=None, dist_context=None): - desc = None - if op is None and dist_op is not None and dist_context is not None: - desc = _parse_op_to_desc(op=dist_op.serial_op, - dist_context=dist_context) - elif op is not None and dist_op is None and dist_context is None: - desc = _parse_op_to_desc(op) - - return desc - - -def parse_desc_to_str(desc): - +def build_comp_desc_from_dist_op(dist_op, dist_context): + """Build descriptions of computation op distributed on the processes.""" + from ..reshard import get_var_with_recursion + + op_descs = {} + op = dist_op.serial_op + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + assert process_mesh, "Process mesh must not be None." + processes = process_mesh.processes + for process in processes: + desc = {} + desc["op"] = op.type + attr_desc = op.all_attrs() + # NOTE: The attrs of desc is replica of serial op, there may be a bug if shape need to be partitioned involved in attrs. + desc["attrs"] = attr_desc + input_desc = OrderedDict() + output_desc = OrderedDict() + + # Get partitioned shape of input + for input_name in op.input_names: + var_name_list = op.input(input_name) + var_desc = [] + for var_name in var_name_list: + var = get_var_with_recursion(var_name, op.block, + op.block.program) + # Use op input_dims_mapping + dims_mapping = dist_attr.get_input_dims_mapping(var_name) + global_sizes = var.shape + # NOTE: When support uneven partition, the shard_sizes will be got from dist_attr. + shard_sizes = None + topology = process_mesh.topology + shape = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, process, + shard_sizes) + var_desc.append((var.dtype, shape)) + + # For special op such as embedding and its grad op + if op.type == "c_embedding" or op.type == "lookup_table_v2" or op.type == "c_embedding_grad" or op.type == "lookup_table_v2_grad": + if input_name == "W": + embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( + op.input(input_name)[0])[0] + relative_idx = _get_idx_in_axis( + processes, dist_attr.process_mesh.topology, + embedding_row_dim_mapping, process) + per_part_size = shape[0] + relative_idx = relative_idx * per_part_size + desc["attrs"]["start_index"] = relative_idx + + input_desc[input_name] = var_desc + desc["inputs"] = input_desc + + for out_name in op.output_names: + var_name_list = op.output(out_name) + var_desc = [] + for var_name in var_name_list: + # Use op output_dims_mapping + var = get_var_with_recursion(var_name, op.block, + op.block.program) + dist_attr = dist_op.dist_attr + dims_mapping = dist_attr.get_output_dims_mapping(var_name) + process_mesh = dist_attr.process_mesh + global_sizes = var.shape + shard_sizes = None + processes = process_mesh.processes + topology = process_mesh.topology + shape = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, process, + shard_sizes) + var_desc.append((var.dtype, shape)) + + # For special op such as fill_constant_batch_size_like + if op.type == "fill_constant_batch_size_like": + # Modify shape attr according to how output are partitioned + out_name = var_name_list[0] + dims_mapping = dist_attr.get_output_dims_mapping(out_name) + process_mesh_shape = dist_attr.process_mesh.topology + shape_list = op.attr("shape") + # Modify target shape + for idx, axis in enumerate(dims_mapping): + if axis >= 0: + shape_list[idx] = shape_list[ + idx] // process_mesh_shape[axis] + desc["attrs"]["shape"] = shape_list + output_desc[out_name] = var_desc + + desc["outputs"] = output_desc + + op_descs[process] = desc + + return op_descs + + +def build_comp_desc_str_for_predict(desc): + # NOTE: The description format may change in the future. def _parse_dtype(dtype): dtype_str = "" if dtype == paddle.float32: @@ -135,8 +204,208 @@ def parse_desc_to_str(desc): shape_str = "[" + ",".join(shape_list) + "]" desc_str_list += [dtype_str, dims_str, shape_str] desc_str = "_".join(desc_str_list) + attrs = desc["attrs"] + parse_result = (desc_str, attrs) + return parse_result + + +def build_comm_desc_from_dist_op(op_type, + dist_op, + ctx, + var_names, + attrs=None, + parallel_axis=None, + group_ranks=None): + """Build descriptions of communication op distributed on the processes.""" + from ..reshard import get_var_with_recursion + + specific_op_type = [] + dist_attr = dist_op.dist_attr + assert dist_attr, "Dist attr must not be None." + process_mesh = dist_attr.process_mesh + assert process_mesh, "Process mesh must not be None." + + processes = process_mesh.processes + op_descs = {} + for process in processes: + rank_id = process + desc = {} + desc["op"] = op_type + op_attrs = None + comm_group_ranks = None + + if op_type not in specific_op_type: + serial_op = dist_op.serial_op + input_list = [] + # The var_names usually contain just one item. + for var_name in var_names: + dist_attr = dist_op.dist_attr + has_found = False + # Find var_name in serial op input or output + for name in dist_op.serial_op.input_arg_names: + # If a tensor is the input of multi ops, sum the grad of all ops, so the name will be varname@RENAME@block@0 and so on. + if var_name in name: + var_name = name + has_found = True + break + + if not has_found: + for name in dist_op.serial_op.output_arg_names: + if var_name in name: + var_name = name + has_found = True + break + assert has_found + var = get_var_with_recursion(var_name, serial_op.block, + serial_op.block.program) + + dims_mapping = dist_attr.get_input_dims_mapping( + var_name + ) if var_name in dist_op.serial_op.input_arg_names else dist_attr.get_output_dims_mapping( + var_name) + global_sizes = var.shape + shard_sizes = None + topology = process_mesh.topology + shape = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, process, + shard_sizes) + input_list.append((var.dtype, shape)) + + # NOTE: The input_name of comm ops used usually is X. + desc["inputs"] = {"X": input_list} + + # Get comm group by parallel_axis or the given group_ranks. + if parallel_axis is not None: + process_mesh_shape = process_mesh.topology + process_mesh_group = process_mesh.processes + comm_group_ranks = _get_comm_group(process_mesh_group, + process_mesh_shape, + parallel_axis, rank_id) + elif group_ranks is not None: + comm_group_ranks = group_ranks + else: + raise ValueError( + "The parallel_axis and group_ranks can not be None in the same." + ) + + if attrs is not None: + assert isinstance(attrs, dict) + op_attrs = attrs + else: + op_attrs = {} + + desc["attrs"] = op_attrs + desc["group_ranks"] = comm_group_ranks + + op_descs[rank_id] = desc + + return op_descs + + +def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None): + """Build a comm desc directly.""" + desc = {} + desc["op"] = op_type + desc["group_ranks"] = group_ranks + desc["inputs"] = {"X": [(dtype, shape)]} + desc["attrs"] = attrs + return desc + - return desc_str +def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster): + """Build comm costs by descriptions""" + comm_context = CommContext(cluster) + group_ranks_list = [] + comm_op_cost_list = [] + for process in processes: + desc = descs[process] + group_ranks = desc["group_ranks"] + if group_ranks not in group_ranks_list: + group_ranks_list.append(group_ranks) + comm_op_cost = op_cost_class(op_desc=desc, + comm_context=comm_context) + comm_op_cost_list.append(comm_op_cost) + return comm_op_cost_list + + +def build_comp_costs_from_descs(op_cost_class, ctx, processes, descs, cluster): + """Build comp costs by descriptions.""" + costs = {} + for process in processes: + costs[process] = op_cost_class(op_desc=descs[process], cluster=cluster) + return costs + + +def build_dp_costs(result, dist_op, ctx, var_names, attrs, parallel_axis, + cluster): + """DP cost contains a allreduce_sum op cost and a scale op cost""" + # The costs will be appended in the given result. + from ..reshard import get_var_with_recursion + + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + assert len(var_names) == 1 + vars = dist_op.serial_op.block.vars + var_name = var_names[0] + has_found = False + for name in dist_op.serial_op.input_arg_names: + if var_name in name: + var_name = name + has_found = True + break + + if not has_found: + for name in dist_op.serial_op.output_arg_names: + if var_name in name: + var_name = name + has_found = True + break + if not has_found: + return + + c_allreduce_sum_descs = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + comm_cost_list = build_comm_costs_from_descs( + _g_op_cost_factory["c_allreduce_sum"], ctx, processes, + c_allreduce_sum_descs, cluster) + result.append(comm_cost_list) + + # The scale op just on the group_ranks + for comm_cost in comm_cost_list: + group_ranks = comm_cost.group_ranks + dp_degree = len(group_ranks) + scale_costs = {} + op_type = "scale" + for rank in group_ranks: + desc = {} + desc["op"] = op_type + desc["inputs"] = {} + dims_mapping = dist_attr.get_input_dims_mapping( + var_name) if dist_attr.get_input_dims_mapping( + var_name + ) is not None else dist_attr.get_output_dims_mapping(var_name) + var = get_var_with_recursion(var_name, dist_op.serial_op.block, + dist_op.serial_op.block.program) + global_sizes = var.shape + shard_sizes = None + topology = process_mesh.topology + shape = DistributedTensor.get_local_sizes(global_sizes, + dims_mapping, topology, + processes, rank, + shard_sizes) + desc["inputs"]["X"] = [(var.dtype, shape)] + attrs = {"scale": 1.0 / dp_degree} + desc["attrs"] = attrs + scale_op_cost = _g_op_cost_factory["scale"](op_desc=desc, + cluster=cluster) + scale_costs[rank] = scale_op_cost + result.append(scale_costs) class CommContext: @@ -174,6 +443,8 @@ class CommContext: # set default self.base_ring = 8.4 self.base_tree = 0. + # self.base_inter_ring = 9.6 + # self.base_inter_tree = 28 # NVL in default self.intra_ring = 3.4 self.intra_tree = 28 @@ -441,6 +712,8 @@ class CommOpCost(OpCost): @property def comm_count(self): + from ..reshard import get_var_with_recursion + if self._comm_count is None: dtype = None shape = None @@ -448,7 +721,8 @@ class CommOpCost(OpCost): vars = self.op.block.vars # NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overrided var_name = self.op.input("X")[0] - var = vars[var_name] + var = get_var_with_recursion(var_name, self.op.block, + self.program) dtype = var.dtype shape = var.shape elif self.op_desc is not None: @@ -464,9 +738,10 @@ class CommOpCost(OpCost): factor = 1 elif dtype == paddle.float16: factor = 2 + elif dtype == paddle.bool: + factor = 8 else: - raise TypeError( - "This dtype {} is not supported now".format(dtype)) + raise ValueError("Unsupported comm dtype {}".format(dtype)) comm_count = reduce(lambda x, y: x * y, shape) * factor self._comm_count = comm_count diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 5738412dd52..6c51ce1fffa 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -51,6 +51,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS}) py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS}) py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS}) + py_test_modules(test_base_cost MODULES test_base_cost ENVS ${dist_ENVS}) py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS}) py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS}) py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py new file mode 100644 index 00000000000..0fbe4f5bd3d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py @@ -0,0 +1,234 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import os +import json +import tempfile + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.reshard import Resharder +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr +from paddle.distributed.auto_parallel.cluster import Cluster +from paddle.distributed.auto_parallel.cost import CommContext +from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_dist_op +from paddle.distributed.auto_parallel.cost.base_cost import build_comm_desc_from_dist_op +from paddle.distributed.auto_parallel.cost.base_cost import build_comm_costs_from_descs +from paddle.distributed.auto_parallel.cost.base_cost import build_comp_costs_from_descs +from paddle.distributed.auto_parallel.cost.base_cost import build_dp_costs +from paddle.distributed.auto_parallel.cost import AllreduceSumOpCost +from paddle.distributed.auto_parallel.cost import _g_op_cost_factory +from test_cluster import cluster_json + +paddle.enable_static() +_global_parallel_strategy = "dp_mp_pp" +_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) + + +class MLPLayer(nn.Layer): + + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear(d_model, + dim_feedforward, + weight_attr, + bias_attr=bias_attr) + self.linear1 = nn.Linear(dim_feedforward, + d_model, + weight_attr, + bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, input): + auto.shard_tensor(self.linear0.weight, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, 1] + }) + auto.shard_tensor(self.linear1.weight, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [1, -1] + }) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data(name="input", + shape=[batch_size, hidden_size], + dtype='float32') + label = static.data(name="label", + shape=[batch_size, 1], + dtype='float32') + + fill_constant_out = paddle.fluid.layers.fill_constant_batch_size_like( + input=input, shape=[batch_size], value=1, dtype="int32") + embedding = paddle.nn.Embedding(10, hidden_size, sparse=True) + embedding_out = embedding(fill_constant_out) + + auto.shard_tensor(input, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [0, -1] + }) + auto.shard_tensor(label, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [0, -1] + }) + + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + + predict = mlp(embedding_out) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + + return loss, train_program, start_program + + +def get_prog(train_program, startup_program, dist_context, rank_id): + global _global_process_mesh + dist_context.process_mesh = _global_process_mesh + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context + + # serial forward & backward completion + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) + dist_context.block_state.parse_forward_blocks(complete_train_program) + params_grads = parallelizer._generate_backward(complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + return train_program, startup_program, params_grads + + +class TestBaseCost(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_base_cost(self): + # Build cluster + cluster_json_path = os.path.join(self.temp_dir.name, + "auto_parallel_cluster.json") + cluster_json_object = json.loads(cluster_json) + with open(cluster_json_path, "w") as cluster_json_file: + json.dump(cluster_json_object, cluster_json_file) + cluster = Cluster() + cluster.build_from_file(cluster_json_path) + + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 2 + train_program, startup_program, params_grads = get_prog( + train_program, startup_program, dist_context, rank_id) + + for op in train_program.global_block().ops: + dist_op = dist_context.get_dist_op_for_program(op) + if dist_op: + processes = dist_op.dist_attr.process_mesh.processes + comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context) + self.assertTrue(isinstance(comp_descs, dict) and comp_descs) + var_names = None + if op.input_arg_names: + var_names = op.input_arg_names[0] + comm_descs = build_comm_desc_from_dist_op("c_allreduce_sum", + dist_op, + dist_context, + var_names, + attrs=None, + parallel_axis=0, + group_ranks=None) + self.assertTrue(isinstance(comm_descs, dict) and comm_descs) + comm_descs = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + dist_context, + var_names, + attrs=None, + parallel_axis=None, + group_ranks=processes) + self.assertTrue(isinstance(comm_descs, dict) and comm_descs) + + comm_costs = build_comm_costs_from_descs( + AllreduceSumOpCost, dist_context, processes, comm_descs, + cluster) + self.assertTrue(comm_costs) + + comp_costs = build_comp_costs_from_descs( + _g_op_cost_factory[op.type], dist_context, processes, + comp_descs, cluster) + self.assertTrue(comp_costs) + + result = [] + build_dp_costs(result, dist_op, dist_context, var_names[0], + None, 0, cluster) + self.assertTrue(result) + + # Remove unnecessary files + if os.path.exists(cluster_json_path): + os.remove(cluster_json_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster.py index dd9b0110dbe..641ca38b649 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster.py @@ -2018,6 +2018,10 @@ class TestCluster(unittest.TestCase): self.assertTrue(devices == [5, 6, 7, 10]) self.assertTrue(involved_machine_count == 2) + # Remove unnecessary files + if os.path.exists(cluster_json_path): + os.remove(cluster_json_path) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py index 21538578788..5744cf6d392 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py @@ -154,6 +154,10 @@ class TestCommOpCost(unittest.TestCase): comm_context=comm_context) self.assertTrue(recv_op_cost.time > 0) + # Remove unnecessary files + if os.path.exists(cluster_json_path): + os.remove(cluster_json_path) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py index fe461312257..6b0db61b984 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py @@ -19,8 +19,8 @@ import tempfile import paddle import paddle.distributed.auto_parallel.cost as cost_model -from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc -from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str +from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_op +from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_str_for_predict from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.cost import CommContext @@ -60,8 +60,8 @@ class TestCost(unittest.TestCase): break matmul_v2_cost = cost_model._g_op_cost_factory["matmul_v2"]( op=matmul_v2_op) - desc = parse_to_desc(op=matmul_v2_op) - desc_str = parse_desc_to_str(desc) + desc = build_comp_desc_from_op(op=matmul_v2_op) + desc_str = build_comp_desc_str_for_predict(desc) self.assertIsNotNone(desc_str) self.assertTrue(check_cost(matmul_v2_cost.cost)) time = calc_time_by_modeling(op=matmul_v2_op) @@ -92,11 +92,29 @@ class TestCost(unittest.TestCase): op_desc=desc, comm_context=CommContext(cluster)) self.assertTrue(check_cost(allreduce_cost.cost)) + # Remove unnecessary files + if os.path.exists(cluster_json_path): + os.remove(cluster_json_path) + def test_cost_estimator(self): + # Build cluster + cluster_json_path = os.path.join(self.temp_dir.name, + "auto_parallel_cluster.json") + cluster_json_object = json.loads(cluster_json) + with open(cluster_json_path, "w") as cluster_json_file: + json.dump(cluster_json_object, cluster_json_file) + cluster = Cluster() + cluster.build_from_file(cluster_json_path) + train_program = paddle.static.Program() - cost_estimator = cost_model.CostEstimator(train_program) + cost_estimator = cost_model.CostEstimator(train_program, + cluster=cluster) self.assertIsNotNone(cost_estimator) + # Remove unnecessary files + if os.path.exists(cluster_json_path): + os.remove(cluster_json_path) + if __name__ == "__main__": unittest.main() -- GitLab