# estimate_cost.py
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License


class CostEstimator:
    """Estimate the execution cost of a program.

    The estimator stores a program together with an optional cluster
    description and distributed context, and validates the requested
    estimation mode. The cost-query methods (``get_op_cost``,
    ``get_tensor_cost``, ``get_global_cost``, ``get_local_cost``) are
    currently placeholders that always return 0.

    Args:
        program: The program whose cost is to be estimated (presumably a
            Paddle ``Program`` -- TODO confirm against callers). Stored as-is.
        cluster: Optional cluster description. Stored as-is.
        dist_context: Optional distributed context. Stored as-is.
        mode: Estimation mode; must be ``"modeling"`` or ``"profiling"``.
            Defaults to ``"modeling"``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    # Modes accepted by ``_check_mode``; kept in one place so the check and
    # any future additions stay in sync.
    _SUPPORTED_MODES = ("modeling", "profiling")

    def __init__(self,
                 program,
                 cluster=None,
                 dist_context=None,
                 mode="modeling"):
        self._program = program
        self._cluster = cluster
        self._dist_context = dist_context
        # Validate before storing so an invalid mode never lands on the
        # instance.
        self._check_mode(mode)
        self._mode = mode
        # Not computed anywhere in this class yet; exposed read-only below.
        self._global_cost = None
        # Fresh dict per instance, so estimators never share local costs.
        # Presumably keyed by rank (see ``get_local_cost``) -- TODO confirm.
        self._local_cost = {}

    @property
    def program(self):
        """The program passed to the constructor."""
        return self._program

    @property
    def dist_context(self):
        """The distributed context passed to the constructor (may be None)."""
        return self._dist_context

    @property
    def cluster(self):
        """The cluster description passed to the constructor (may be None)."""
        return self._cluster

    @property
    def mode(self):
        """The validated estimation mode: "modeling" or "profiling"."""
        return self._mode

    @property
    def global_cost(self):
        """The stored global cost; None until something sets it."""
        return self._global_cost

    @property
    def local_cost(self):
        """The per-instance local-cost dict (empty after construction)."""
        return self._local_cost

    def get_op_cost(self):
        """Return the operator cost. Placeholder: always returns 0."""
        return 0

    def get_tensor_cost(self):
        """Return the tensor cost. Placeholder: always returns 0."""
        return 0

    def get_global_cost(self):
        """Return the global cost. Placeholder: always returns 0."""
        return 0

    def get_local_cost(self, rank=None):
        """Return the cost local to ``rank``.

        Placeholder: always returns 0; ``rank`` is accepted for interface
        compatibility but is not used yet.
        """
        return 0

    def _check_mode(self, mode):
        """Raise ValueError unless ``mode`` is a supported estimation mode."""
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                "Just support modeling and profiling, but got {}".format(mode))