From 5378c16371c5ec3821a2f70af5f71ff166000dbe Mon Sep 17 00:00:00 2001
From: ceci3 <592712189@qq.com>
Date: Mon, 21 Sep 2020 12:48:23 +0000
Subject: [PATCH] add evo_search

---
 paddleslim/nas/__init__.py        |   2 +
 paddleslim/nas/common/__init__.py |  15 ++++
 paddleslim/nas/common/predict.py  |  50 +++++++++++++
 paddleslim/nas/ofa/search.py      | 120 ++++++++++++++++++++++++++++++
 4 files changed, 187 insertions(+)
 create mode 100644 paddleslim/nas/common/__init__.py
 create mode 100644 paddleslim/nas/common/predict.py
 create mode 100644 paddleslim/nas/ofa/search.py

diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py
index d438e54c..0df651c5 100644
--- a/paddleslim/nas/__init__.py
+++ b/paddleslim/nas/__init__.py
@@ -19,6 +19,8 @@ from .sa_nas import *
 from .rl_nas import *
 from ..nas import darts
 from .darts import *
+from .ofa import *
+from .common import *
 
 __all__ = []
 __all__ += sa_nas.__all__
diff --git a/paddleslim/nas/common/__init__.py b/paddleslim/nas/common/__init__.py
new file mode 100644
index 00000000..2c5fd315
--- /dev/null
+++ b/paddleslim/nas/common/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predict import AccuracyEvaluator
diff --git a/paddleslim/nas/common/predict.py b/paddleslim/nas/common/predict.py
new file mode 100644
index 00000000..3256715c
--- /dev/null
+++ b/paddleslim/nas/common/predict.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.fluid as fluid
+
+
+class AccuracyEvaluator:
+    def __init__(self, model=None, input_dim=128):
+        if model is None:
+            self.model = DefaultModel(input_dim=input_dim)
+        else:
+            assert isinstance(model, fluid.dygraph.Layer)
+            self.model = model
+
+    @fluid.dygraph.no_grad
+    def predict_accuracy(self, net_arch):
+        pred = self.model(net_arch)
+        return pred.numpy()
+
+    def convert_net_to_onehot(self, net):
+        pass
+
+    def convert_onehot_to_net(self, net_onehot):
+        pass
+
+
+class DefaultModel(fluid.dygraph.Layer):
+    def __init__(self, input_dim):
+        super(DefaultModel, self).__init__()
+        self.model = nn.Sequential(
+            nn.Linear(input_dim, 400),
+            nn.ReLU(),
+            nn.Linear(400, 400),
+            nn.ReLU(), nn.Linear(400, 400), nn.ReLU(), nn.Linear(400, 1))
+
+    def forward(self, inputs):
+        return self.model(inputs)
diff --git a/paddleslim/nas/ofa/search.py b/paddleslim/nas/ofa/search.py
new file mode 100644
index 00000000..e7128e17
--- /dev/null
+++ b/paddleslim/nas/ofa/search.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from collections import namedtuple
+from ...analysis.flops import dygraph_flops
+from ...analysis.latency import TableLatencyEvaluator
+from ..common import AccuracyEvaluator
+
+# Optional search constraints; unset fields default to None.
+ConstraintConfig = namedtuple(
+    'ConstraintConfig',
+    ['acc_constraint', 'latency_constraint', 'flops_constraint'])
+ConstraintConfig.__new__.__defaults__ = (None, ) * len(ConstraintConfig._fields)
+
+
+class BaseNetConfig:
+    def __init__(self):
+        raise NotImplementedError("NotImplemented")
+
+    def random_choice(self):
+        raise NotImplementedError("NotImplemented")
+
+
+class EvolutionSearch:
+    def __init__(self, net_config, constraint, strategy='EVO', **kwargs):
+        assert isinstance(
+            constraint,
+            ConstraintConfig), "constraint must be instance of ConstraintConfig"
+        assert isinstance(net_config, BaseNetConfig)
+
+        self.net_config = net_config
+
+        if strategy == 'EVO':
+            self.strategy = Evolution(**kwargs)
+        else:
+            raise NotImplementedError("strategy not implemented")
+
+        # Copy acc/latency/flops constraints onto the searcher as attributes.
+        for key, value in constraint._asdict().items():
+            setattr(self, key, value)
+
+        self.population_size = kwargs.get('population_size', 100)
+        self.mutate_prob = kwargs.get('mutate_prob', 0.1)
+        self.evo_iter = kwargs.get('evo_iter', 500)
+        self.parent_ratio = kwargs.get('parent_ratio', 0.25)
+        self.mutation_ratio = kwargs.get('mutation_ratio', 0.5)
+
+        if self.acc_constraint is not None:
+            input_dim = getattr(self.acc_constraint, 'input_dim', 128)
+            pred_model = getattr(self.acc_constraint, 'pred_model', None)
+            self.acc_predicter = AccuracyEvaluator(pred_model, input_dim)
+            self.min_acc = getattr(self.acc_constraint, 'min_acc', 1.0)
+
+        if self.latency_constraint is not None:
+            table_file = getattr(self.latency_constraint, 'table_file', None)
+            assert table_file is not None
+            self.latency_predicter = TableLatencyEvaluator(table_file)
+            self.max_latency = getattr(self.latency_constraint, 'max_latency',
+                                       -1)
+
+        if self.flops_constraint is not None:
+            self.flops_predicter = dygraph_flops
+
+    def start_search(self):
+        mutation_size = int(round(self.population_size * self.mutation_ratio))
+        parents_size = int(round(self.population_size * self.parent_ratio))
+        best_valid = [-100]
+
+        population = self.random_sample(self.population_size)
+        for i in range(self.evo_iter):
+            pass
+
+    def satify_constraint(self, sample):
+        status = {}
+        if self.acc_constraint is not None:
+            cur_acc = self.acc_predicter.predict_accuracy(sample)
+            if cur_acc < self.min_acc:
+                return False, None
+            status['acc'] = cur_acc
+
+        if self.latency_constraint is not None:
+            net = self.convert_onehot_to_net(sample)
+            cur_latency = self.latency_predicter.latency(net)
+            if cur_latency > self.max_latency:
+                return False, None
+            status['latency'] = cur_latency
+
+        if self.flops_constraint is not None:
+            net = self.convert_onehot_to_net(sample)
+            cur_flops = self.flops_predicter(net)
+            if cur_flops > self.flops_constraint:
+                return False, None
+            status['flops'] = cur_flops
+
+        return True, status
+
+    def random_sample(self, sample_size=1):
+        # Rejection-sample architectures until enough of them satisfy every
+        # configured constraint.
+        population = []
+        while len(population) < sample_size:
+            sample = self.net_config.random_choice()
+            satify, constraint_status = self.satify_constraint(sample)
+            if satify:
+                population.append((sample, constraint_status))
+
+        return population
+
+    def mutate_sample(self, sample):
+        pass
+
+    def crossover_sample(self, sample1, sample2):
+        pass
--
GitLab
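
A minimal usage sketch (illustrative, not part of the patch): it assumes the patch is applied to PaddleSlim and executed under Paddle's dygraph mode; the random 128-dim encoding and the printed score are placeholder values, since the default predictor MLP is untrained.

    import numpy as np
    import paddle.fluid as fluid
    from paddleslim.nas.common import AccuracyEvaluator

    with fluid.dygraph.guard():
        # Falls back to the DefaultModel MLP when no predictor model is passed.
        evaluator = AccuracyEvaluator(input_dim=128)
        # Stand-in architecture encoding; a real encoding would come from
        # convert_net_to_onehot once that helper is implemented.
        arch = fluid.dygraph.to_variable(
            np.random.rand(1, 128).astype('float32'))
        print(evaluator.predict_accuracy(arch))  # (1, 1) array from the untrained MLP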