提交 5378c163 编写于 作者: C ceci3

add evo_search

上级 c67e3f88
...@@ -19,6 +19,8 @@ from .sa_nas import * ...@@ -19,6 +19,8 @@ from .sa_nas import *
from .rl_nas import * from .rl_nas import *
from ..nas import darts from ..nas import darts
from .darts import * from .darts import *
from .ofa import *
from .common import *
__all__ = [] __all__ = []
__all__ += sa_nas.__all__ __all__ += sa_nas.__all__
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .predict import AccuracyEvaluator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
class AccuracyEvaluator:
def __init__(self, model=None, input_dim=128):
if model == None:
self.model = DefaultModel(input_dim=input_dim)
else:
assert isinstance(model, fluid.dygraph.Layer)
self.model = model
@fluid.dygraph.no_grad
def predict_accuracy(self, net_arch):
pred = self.model(net_arch)
return pred.numpy()
def convert_net_to_onehot(self, net):
pass
def convert_onehot_to_net(self, net_onehot):
pass
class DefaultModel(fluid.dygraph.Layer):
def __init__(self, input_dim):
super(Model, self).__init__()
self.models = nn.Sequential(
nn.Linear(input_dim, 400),
nn.ReLU(),
nn.Linear(400, 400),
nn.ReLU(), nn.Linear(400, 400), nn.ReLU(), nn.Linear(400, 1))
def forward(self, *inputs, **kwargs):
return self.model(inputs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from collections import namedtuple
from ...analysis.flops import dygraph_flops
from ...analysis.latency import TableLatencyEvaluator
ConstraintConfig = namedtuple(
'ConstraintConfig',
['acc_constraint', 'latency_constraint', 'flops_constraint'])
ConstraintConfig.__new__.__defaults__ = (None, ) * len(ConstraintConfig._fields)
class BaseNetConfig:
def __init__(self):
raise NotImplementedError("NotImplemented")
def random_choice(self):
raise NotImplementedError("NotImplemented")
class EvolutionSearch:
def __init__(self, net_config, constraint, strategy='EVO', **kwargs):
assert isinstance(
constraint,
ContraintConfig), "constraint must be instance of ContraintConfig"
assert issubclass(net_config, BaseNetConfig)
self.net_config = net_config
if strategy == 'EVO':
self.strategy = Evolution(**kwargs)
else:
raise NotImplementedError("strategy not Implement")
for key, value in constraint.items():
setattr(self, key, value)
self.population_size = getattr(kwargs, 'population_size', 100)
self.mutate_prob = getattr(kwargs, 'mutate_prob', 0.1)
self.evo_iter = getattr(kwargs, 'evo_iter', 500)
self.parent_ratio = getattr(kwargs, 'parent_ratio', 0.25)
self.mutation_ratio = getattr(kwargs, 'mutation_ratio', 0.5)
if self.acc_constraint != None:
input_dim = getattr(self.acc_constraint, 'input_dim', 128)
pred_model = getattr(self.acc_constraint, 'pred_model', None)
self.acc_predicter = AccuracyEvaluator(pred_model, input_dim)
self.min_acc = getattr(self.acc_constraint, 'min_acc', 1.0)
if self.latency_constraint != None:
table_file = getattr(self.latency_constraint, 'table_file', None)
assert table_file != None
self.latency_predicter = TableLatencyEvaluator(table_file)
self.max_latency = getattr(self.latency_constraint, 'max_latency',
-1)
if self.flops_constraint != None:
self.flops_predicter = dygraph_flops
def start_search(self):
mutation_size = int(round(self.population_size * self.mutation_ratio))
parents_size = int(round(self.population_size * self.parent_ratio))
best_valid = [-100]
population = self.random_sample(self.population_size)
for i in range(self.evo_iter):
pass
def satify_constraint(self, sample):
status = {}
if self.acc_constraint != None:
cur_acc = self.acc_predicter(sample)
if cur_acc < self.min_acc:
return False, None
status['acc'] = cur_acc
if self.latency_constraint != None:
net = self.convert_onehot_to_net(sample)
cur_latency = self.latency_predicter.latency(net)
if cur_latency < self.max_latency:
return False, None
status['latency'] = cur_latency
if self.flops_constraint != None:
net = self.convert_onehot_to_net(sample)
cur_flops = self.flops_predicter(net)
if cur_flops > self.flops_constraint:
return False, None
status['flops'] = cur_flops
return True, status
def random_sample(self, sample_size=1):
population = []
while len(population) < sample_size:
sample = self.net_config.random_choice()
satify, constraint_status = self.satify_constraint(sample)
if satify:
population.append((sample, constraint_status))
return population
def mutate_sample(self, sample):
pass
def crossover_sample(self, sample1, sample2):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册