# test_new_cost_model.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import json
16
import os
17
import tempfile
18 19 20
import unittest

from test_cluster import cluster_json
21 22 23

import paddle
import paddle.distributed.auto_parallel.cost as cost_model
24 25
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CommContext
26 27 28 29 30
from paddle.distributed.auto_parallel.cost.base_cost import (
    build_comp_desc_from_op,
    build_comp_desc_str_for_predict,
    calc_time_by_modeling,
)
31 32 33 34 35 36 37 38 39 40 41

# Module-level side effect: switch Paddle to static-graph mode on import.
# The tests below build their inputs and programs through the paddle.static API.
paddle.enable_static()


def check_cost(cost):
    """Return True when the cost's memory, flops and time are all non-negative.

    Args:
        cost: Any object exposing ``memory``, ``flops`` and ``time``
            attributes that can be compared against ``0``.

    Returns:
        bool: ``True`` if every field is ``>= 0``, otherwise ``False``.
    """
    # The generator keeps the original left-to-right, short-circuit order:
    # a later attribute is only read if every earlier one passed the check.
    return all(
        getattr(cost, field) >= 0 for field in ("memory", "flops", "time")
    )


class TestCost(unittest.TestCase):
    """Smoke tests for ``paddle.distributed.auto_parallel.cost``.

    Covers the base ``Cost`` container, a computation op cost (matmul_v2),
    a communication op cost (c_allreduce_sum) and ``CostEstimator``
    construction.
    """

    def setUp(self):
        # Every test gets its own scratch directory for the cluster JSON file.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # Removes the directory together with any file a test wrote into it,
        # so the tests do not need to delete the cluster file themselves.
        self.temp_dir.cleanup()

    def _build_cluster(self):
        """Dump ``cluster_json`` into the temp dir and build a Cluster from it.

        Returns:
            Cluster: A cluster built from the written JSON file.
        """
        cluster_json_path = os.path.join(
            self.temp_dir.name, "auto_parallel_cluster.json"
        )
        cluster_json_object = json.loads(cluster_json)
        with open(cluster_json_path, "w") as cluster_json_file:
            json.dump(cluster_json_object, cluster_json_file)
        cluster = Cluster()
        cluster.build_from_file(cluster_json_path)
        return cluster

    def test_base_cost(self):
        cost = cost_model.Cost(memory=100, flops=200, time=0.5)
        self.assertTrue(check_cost(cost))

    def test_comp_cost(self):
        x = paddle.static.data(name="x", shape=[20, 20], dtype='float32')
        y = paddle.static.data(name="y", shape=[20, 20], dtype='float32')

        # Called only for its effect on the default main program: the op it
        # records there is looked up by type right below.
        paddle.matmul(x, y)
        ops = paddle.static.default_main_program().global_block().ops
        matmul_v2_op = next((op for op in ops if op.type == "matmul_v2"), None)
        # Fail here with a clear message instead of passing None downstream.
        self.assertIsNotNone(
            matmul_v2_op, "no matmul_v2 op found in the default main program"
        )

        matmul_v2_cost = cost_model._g_op_cost_factory["matmul_v2"](
            op=matmul_v2_op
        )
        desc = build_comp_desc_from_op(op=matmul_v2_op)
        desc_str = build_comp_desc_str_for_predict(desc)
        self.assertIsNotNone(desc_str)
        self.assertTrue(check_cost(matmul_v2_cost.cost))

        modeled_time = calc_time_by_modeling(op=matmul_v2_op)
        self.assertEqual(modeled_time, matmul_v2_cost.cost.time)

        tensor_cost = cost_model.TensorCost(tensor=x)
        # check memory: 20 * 20 float32 elements * 4 bytes each
        self.assertEqual(tensor_cost.cost.memory, 1600)

    def test_comm_cost(self):
        # Build cluster
        cluster = self._build_cluster()

        # Build CommContext. The class-level cache attributes are cleared
        # first so the context is created from this test's cluster.
        CommContext._has_instance = None
        CommContext._instance = None
        comm_context = CommContext(cluster)

        desc = {
            "op": "c_allreduce_sum",
            "inputs": {"X": [(paddle.float32, [100, 200])]},
            "group_ranks": [0, 1],
        }
        allreduce_cost = cost_model._g_op_cost_factory["c_allreduce_sum"](
            op_desc=desc, comm_context=comm_context
        )
        self.assertTrue(check_cost(allreduce_cost.cost))

    def test_cost_estimator(self):
        # Build cluster
        cluster = self._build_cluster()

        train_program = paddle.static.Program()
        cost_estimator = cost_model.CostEstimator(
            train_program, cluster=cluster
        )
        self.assertIsNotNone(cost_estimator)

125 126 127

# Run the test cases through unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()