未验证 提交 14ff5175 编写于 作者: T Taehoon Lee 提交者: GitHub

Fix NASNet (#10209)

* Fix NASNet

* Update weight files
上级 b8ac7e07
......@@ -60,10 +60,10 @@ from ..applications.inception_v3 import preprocess_input
from ..applications.imagenet_utils import decode_predictions
from .. import backend as K
NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile.h5'
NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile-no-top.h5'
NASNET_LARGE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large.h5'
NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large-no-top.h5'
NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-mobile.h5'
NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-mobile-no-top.h5'
NASNET_LARGE_WEIGHT_PATH = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-large.h5'
NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-large-no-top.h5'
def NASNet(input_shape=None,
......@@ -102,7 +102,7 @@ def NASNet(input_shape=None,
- P is the number of penultimate filters
stem_block_filters: Number of filters in the initial stem block
skip_reduction: Whether to skip the reduction step at the tail
end of the network. Set to `False` for CIFAR models.
end of the network.
filter_multiplier: Controls the width of the network.
- If `filter_multiplier` < 1.0, proportionally decreases the number
of filters in each layer.
......@@ -210,24 +210,18 @@ def NASNet(input_shape=None,
channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
filters = penultimate_filters // 24
if not skip_reduction:
x = Conv2D(stem_block_filters, (3, 3), strides=(2, 2), padding='valid',
use_bias=False, name='stem_conv1',
kernel_initializer='he_normal')(img_input)
else:
x = Conv2D(stem_block_filters, (3, 3), strides=(1, 1), padding='same',
use_bias=False, name='stem_conv1',
kernel_initializer='he_normal')(img_input)
x = Conv2D(stem_block_filters, (3, 3), strides=(2, 2), padding='valid',
use_bias=False, name='stem_conv1',
kernel_initializer='he_normal')(img_input)
x = BatchNormalization(axis=channel_dim, momentum=0.9997,
epsilon=1e-3, name='stem_bn1')(x)
p = None
if not skip_reduction: # imagenet / mobile mode
x, p = _reduction_a_cell(x, p, filters // (filter_multiplier ** 2),
block_id='stem_1')
x, p = _reduction_a_cell(x, p, filters // filter_multiplier,
block_id='stem_2')
x, p = _reduction_a_cell(x, p, filters // (filter_multiplier ** 2),
block_id='stem_1')
x, p = _reduction_a_cell(x, p, filters // filter_multiplier,
block_id='stem_2')
for i in range(num_blocks):
x, p = _normal_a_cell(x, p, filters, block_id='%d' % (i))
......@@ -274,27 +268,32 @@ def NASNet(input_shape=None,
if weights == 'imagenet':
if default_size == 224: # mobile version
if include_top:
weight_path = NASNET_MOBILE_WEIGHT_PATH
model_name = 'nasnet_mobile.h5'
weights_path = get_file(
'nasnet_mobile.h5',
NASNET_MOBILE_WEIGHT_PATH,
cache_subdir='models',
file_hash='020fb642bf7360b370c678b08e0adf61')
else:
weight_path = NASNET_MOBILE_WEIGHT_PATH_NO_TOP
model_name = 'nasnet_mobile_no_top.h5'
weights_file = get_file(model_name, weight_path,
cache_subdir='models')
model.load_weights(weights_file)
weights_path = get_file(
'nasnet_mobile_no_top.h5',
NASNET_MOBILE_WEIGHT_PATH_NO_TOP,
cache_subdir='models',
file_hash='1ed92395b5b598bdda52abe5c0dbfd63')
model.load_weights(weights_path)
elif default_size == 331: # large version
if include_top:
weight_path = NASNET_LARGE_WEIGHT_PATH
model_name = 'nasnet_large.h5'
weights_path = get_file(
'nasnet_large.h5',
NASNET_LARGE_WEIGHT_PATH,
cache_subdir='models',
file_hash='11577c9a518f0070763c2b964a382f17')
else:
weight_path = NASNET_LARGE_WEIGHT_PATH_NO_TOP
model_name = 'nasnet_large_no_top.h5'
weights_file = get_file(model_name, weight_path,
cache_subdir='models')
model.load_weights(weights_file)
weights_path = get_file(
'nasnet_large_no_top.h5',
NASNET_LARGE_WEIGHT_PATH_NO_TOP,
cache_subdir='models',
file_hash='d81d89dc07e6e56530c4e77faddd61b5')
model.load_weights(weights_path)
else:
raise ValueError(
'ImageNet weights can only be loaded with NASNetLarge'
......@@ -364,7 +363,7 @@ def NASNetLarge(input_shape=None,
penultimate_filters=4032,
num_blocks=6,
stem_block_filters=96,
skip_reduction=False,
skip_reduction=True,
filter_multiplier=2,
include_top=include_top,
weights=weights,
......@@ -630,7 +629,7 @@ def _reduction_a_cell(ip, p, filters, block_id=None):
x1_1 = _separable_conv_block(h, filters, (5, 5), strides=(2, 2),
block_id='reduction_left1_%s' % block_id)
x1_2 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2),
block_id='reduction_1_%s' % block_id)
block_id='reduction_right1_%s' % block_id)
x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % block_id)
with K.name_scope('block_2'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册