search.py 4.3 KB
Newer Older
C
ceci3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from collections import namedtuple
from ...analysis.flops import dygraph_flops
from ...analysis.latency import TableLatencyEvaluator

# Search constraints understood by EvolutionSearch. Every field is optional;
# a field left as None means that constraint is not applied.
_CONSTRAINT_FIELDS = ('acc_constraint', 'latency_constraint',
                      'flops_constraint')
ConstraintConfig = namedtuple('ConstraintConfig', _CONSTRAINT_FIELDS)
# Give every positional field a default of None so callers may pass any subset.
ConstraintConfig.__new__.__defaults__ = tuple(None for _ in _CONSTRAINT_FIELDS)


class BaseNetConfig(object):
    """Interface for a searchable network-configuration space.

    Both the constructor and ``random_choice`` are abstract: the base
    implementations only raise ``NotImplementedError``, so concrete
    configurations must override them.
    """

    def __init__(self):
        # Abstract: a concrete configuration supplies its own constructor.
        raise NotImplementedError("NotImplemented")

    def random_choice(self):
        """Draw a candidate sample from the search space (abstract).

        ``EvolutionSearch.random_sample`` in this module calls this and
        passes the result to its constraint check; presumably it returns
        one random candidate encoding. TODO confirm with implementers.
        """
        raise NotImplementedError("NotImplemented")


class EvolutionSearch:
    """Constraint-aware evolutionary search over a network-config space.

    Candidates are drawn from ``net_config.random_choice()``. A candidate is
    kept only if it satisfies every enabled constraint in ``constraint`` (a
    ``ConstraintConfig``). A field left as ``None`` disables that constraint.

    NOTE(review): the evolution step itself (``start_search`` loop,
    ``mutate_sample``, ``crossover_sample``) is still a stub.
    """

    def __init__(self, net_config, constraint, strategy='EVO', **kwargs):
        """Validate inputs and build one predictor per enabled constraint.

        Args:
            net_config: a ``BaseNetConfig`` subclass or instance. Its
                ``random_choice()`` is used to draw candidates.
            constraint (ConstraintConfig): search constraints. Each field is
                copied onto ``self`` under the same name.
            strategy (str): search strategy. Only ``'EVO'`` is supported.
            **kwargs: optional hyper-parameters ``population_size`` (100),
                ``mutate_prob`` (0.1), ``evo_iter`` (500), ``parent_ratio``
                (0.25) and ``mutation_ratio`` (0.5).

        Raises:
            AssertionError: if ``constraint`` or ``net_config`` has the wrong
                type, or a latency constraint lacks ``table_file``.
            NotImplementedError: if ``strategy`` is not ``'EVO'``.
        """
        assert isinstance(
            constraint,
            ConstraintConfig), "constraint must be instance of ConstraintConfig"
        # Accept a BaseNetConfig subclass (the original contract) or an
        # instance of one. A bare ``issubclass`` raises TypeError on instances.
        is_config_cls = (isinstance(net_config, type) and
                         issubclass(net_config, BaseNetConfig))
        assert is_config_cls or isinstance(net_config, BaseNetConfig), \
            "net_config must be a BaseNetConfig subclass or instance"

        self.net_config = net_config

        if strategy != 'EVO':
            raise NotImplementedError("strategy not Implement")
        # The original code built ``Evolution(**kwargs)``, a name that is
        # neither defined nor imported in this module, so construction always
        # failed. That object was never used, so only the name is kept.
        self.strategy = strategy

        # A namedtuple has no ``.items()``; ``_asdict()`` gives field -> value.
        for key, value in constraint._asdict().items():
            setattr(self, key, value)

        # kwargs is a dict: ``getattr`` on it would always return the default.
        self.population_size = kwargs.get('population_size', 100)
        self.mutate_prob = kwargs.get('mutate_prob', 0.1)
        self.evo_iter = kwargs.get('evo_iter', 500)
        self.parent_ratio = kwargs.get('parent_ratio', 0.25)
        self.mutation_ratio = kwargs.get('mutation_ratio', 0.5)

        if self.acc_constraint is not None:
            input_dim = self._get_option(self.acc_constraint, 'input_dim', 128)
            pred_model = self._get_option(self.acc_constraint, 'pred_model',
                                          None)
            # NOTE(review): AccuracyEvaluator is not imported in this module.
            # Confirm where it should come from.
            self.acc_predicter = AccuracyEvaluator(pred_model, input_dim)
            # NOTE(review): with the default of 1.0 nearly every sample is
            # rejected unless ``min_acc`` is given. Confirm the intended default.
            self.min_acc = self._get_option(self.acc_constraint, 'min_acc', 1.0)

        if self.latency_constraint is not None:
            table_file = self._get_option(self.latency_constraint,
                                          'table_file', None)
            assert table_file is not None, \
                "latency_constraint requires a 'table_file'"
            self.latency_predicter = TableLatencyEvaluator(table_file)
            # A negative max_latency (the default -1) means "no latency limit".
            self.max_latency = self._get_option(self.latency_constraint,
                                                'max_latency', -1)

        if self.flops_constraint is not None:
            # NOTE(review): dygraph_flops is called below with a single
            # argument. Verify that matches its signature.
            self.flops_predicter = dygraph_flops

    @staticmethod
    def _get_option(config, key, default):
        """Read ``key`` from a dict-style or attribute-style sub-constraint."""
        if isinstance(config, dict):
            return config.get(key, default)
        return getattr(config, key, default)

    def start_search(self):
        """Run the evolutionary search (work in progress).

        Currently this only draws an initial population of
        ``population_size`` constraint-satisfying samples. The ``evo_iter``
        evolution steps are not implemented yet.

        Returns:
            list: the population as ``(sample, constraint_status)`` pairs.
        """
        # Placeholders for the not-yet-implemented evolution step.
        mutation_size = int(round(self.population_size * self.mutation_ratio))
        parents_size = int(round(self.population_size * self.parent_ratio))
        best_valid = [-100]

        population = self.random_sample(self.population_size)
        for _ in range(self.evo_iter):
            # TODO: select parents, mutate / crossover, refill the population.
            pass
        return population

    def satify_constraint(self, sample):
        """Check ``sample`` against every enabled constraint.

        Args:
            sample: a candidate as returned by ``net_config.random_choice()``.

        Returns:
            tuple: ``(True, status)`` where ``status`` maps ``'acc'``,
            ``'latency'`` and/or ``'flops'`` to the predicted values, or
            ``(False, None)`` as soon as one constraint is violated.
        """
        status = {}
        if self.acc_constraint is not None:
            cur_acc = self.acc_predicter(sample)
            if cur_acc < self.min_acc:
                return False, None
            status['acc'] = cur_acc

        # Convert lazily and at most once; latency and FLOPs share the net.
        # NOTE(review): convert_onehot_to_net is not defined on this class.
        net = None
        if self.latency_constraint is not None:
            net = self.convert_onehot_to_net(sample)
            cur_latency = self.latency_predicter.latency(net)
            # Reject samples slower than the limit. A negative limit disables
            # the check (the original comparison was inverted).
            if 0 <= self.max_latency < cur_latency:
                return False, None
            status['latency'] = cur_latency

        if self.flops_constraint is not None:
            if net is None:
                net = self.convert_onehot_to_net(sample)
            cur_flops = self.flops_predicter(net)
            if cur_flops > self.flops_constraint:
                return False, None
            status['flops'] = cur_flops

        return True, status

    def random_sample(self, sample_size=1, max_trials=None):
        """Draw random candidates until ``sample_size`` satisfy the constraints.

        Args:
            sample_size (int): number of accepted samples to return.
            max_trials (int|None): optional cap on the number of draws. With
                the default ``None`` the loop is unbounded, as before.

        Returns:
            list: ``(sample, constraint_status)`` pairs.

        Raises:
            RuntimeError: if ``max_trials`` draws did not yield enough samples.
        """
        population = []
        trials = 0
        while len(population) < sample_size:
            if max_trials is not None and trials >= max_trials:
                raise RuntimeError(
                    "only {} of {} samples satisfied the constraints after {} "
                    "trials".format(len(population), sample_size, trials))
            trials += 1
            sample = self.net_config.random_choice()
            satisfied, constraint_status = self.satify_constraint(sample)
            if satisfied:
                population.append((sample, constraint_status))

        return population

    def mutate_sample(self, sample):
        """Mutate ``sample`` (not implemented yet; returns None)."""
        pass

    def crossover_sample(self, sample1, sample2):
        """Cross over two samples (not implemented yet; returns None)."""
        pass