提交 464f6123 编写于 作者: C ceci3

update multi search space

上级 9fdb6054
......@@ -14,6 +14,8 @@
import mobilenetv2
from .mobilenetv2 import *
import resnet
from .resnet import *
import search_space_registry
from search_space_registry import *
import search_space_factory
......@@ -26,3 +28,4 @@ __all__ += mobilenetv2.__all__
__all__ += search_space_registry.__all__
__all__ += search_space_factory.__all__
__all__ += search_space_base.__all__
......@@ -59,5 +59,7 @@ def conv_bn_layer(input,
moving_variance_name=bn_name + '_variance')
if act == 'relu6':
return fluid.layers.relu6(bn)
elif act == 'sigmoid':
return fluid.layers.sigmoid(bn)
else:
return bn
......@@ -86,9 +86,5 @@ class CombineSearchSpace(object):
for space, token in zip(self.spaces, self.token):
model_archs.append(space.token2arch(token))
def net_arch(input):
for model_arch in model_archs:
input = model_arch(input)
return input
return model_archs
return net_arch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
__all__ = ["ResNetSpace"]
@SEARCHSPACE.register
class ResNetSpace(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000):
super(ResNetSpace, self).__init__(input_size, output_size, block_num)
pass
def init_tokens(self):
return [0,0,0,0,0,0]
def range_table(self):
return [3,3,3,3,3,3]
def token2arch(self,tokens=None):
if tokens is None:
self.init_tokens()
def net_arch(input):
input = conv_bn_layer(
input,
num_filters=32,
filter_size=3,
stride=2,
padding='SAME',
act='sigmoid',
name='resnet_conv1_1')
return input
return net_arch
......@@ -24,7 +24,7 @@ class TestSearchSpaceFactory(unittest.TestCase):
config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
space = SearchSpaceFactory()
my_space = space.get_search_space('MobileNetV2Space', config)
my_space = space.get_search_space([('MobileNetV2Space', config)])
model_arch = my_space.token2arch()
train_prog = fluid.Program()
......@@ -36,10 +36,26 @@ class TestSearchSpaceFactory(unittest.TestCase):
shape=[1, 3, input_size, input_size],
dtype='float32',
append_batch_size=False)
print('input shape', model_input.shape)
predict = model_arch(model_input)
print('output shape', predict.shape)
predict = model_arch[0](model_input)
self.assertTrue(predict.shape[2] == config['output_size'])
class TestMultiSearchSpace(unittest.TestCase):
space = SearchSpaceFactory()
config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
my_space = space.get_search_space([('MobileNetV2Space', config0), ('ResNetSpace', config1)])
model_archs = my_space.token2arch()
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_size= config0['input_size']
model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False)
for model_arch in model_archs:
predict = model_arch(model_input)
model_input = predict
print(predict)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册