提交 9b519f0d 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add effnet_small_eval

上级 f4430af8
...@@ -22,7 +22,6 @@ from __future__ import unicode_literals ...@@ -22,7 +22,6 @@ from __future__ import unicode_literals
import six import six
import math import math
import random import random
import functools
import cv2 import cv2
import numpy as np import numpy as np
...@@ -38,8 +37,8 @@ class DecodeImage(object): ...@@ -38,8 +37,8 @@ class DecodeImage(object):
def __init__(self, to_rgb=True, to_np=False, channel_first=False): def __init__(self, to_rgb=True, to_np=False, channel_first=False):
self.to_rgb = to_rgb self.to_rgb = to_rgb
self.to_np = to_np #to numpy self.to_np = to_np # to numpy
self.channel_first = channel_first #only enabled when to_np is True self.channel_first = channel_first # only enabled when to_np is True
def __call__(self, img): def __call__(self, img):
if six.PY2: if six.PY2:
...@@ -64,7 +63,8 @@ class DecodeImage(object): ...@@ -64,7 +63,8 @@ class DecodeImage(object):
class ResizeImage(object): class ResizeImage(object):
""" resize image """ """ resize image """
def __init__(self, size=None, resize_short=None): def __init__(self, size=None, resize_short=None, interpolation=-1):
self.interpolation = interpolation if interpolation >= 0 else None
if resize_short is not None and resize_short > 0: if resize_short is not None and resize_short > 0:
self.resize_short = resize_short self.resize_short = resize_short
self.w = None self.w = None
...@@ -86,8 +86,10 @@ class ResizeImage(object): ...@@ -86,8 +86,10 @@ class ResizeImage(object):
else: else:
w = self.w w = self.w
h = self.h h = self.h
if self.interpolation is None:
return cv2.resize(img, (w, h)) return cv2.resize(img, (w, h))
else:
return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object): class CropImage(object):
...@@ -138,8 +140,7 @@ class RandCropImage(object): ...@@ -138,8 +140,7 @@ class RandCropImage(object):
scale_max = min(scale[1], bound) scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound) scale_min = min(scale[0], bound)
target_area = img_w * img_h * random.uniform(\ target_area = img_w * img_h * random.uniform(scale_min, scale_max)
scale_min, scale_max)
target_size = math.sqrt(target_area) target_size = math.sqrt(target_area)
w = int(target_size * w) w = int(target_size * w)
h = int(target_size * h) h = int(target_size * h)
...@@ -176,7 +177,8 @@ class NormalizeImage(object): ...@@ -176,7 +177,8 @@ class NormalizeImage(object):
""" """
def __init__(self, scale=None, mean=None, std=None, order='chw'): def __init__(self, scale=None, mean=None, std=None, order='chw'):
if isinstance(scale, str): scale = eval(scale) if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406] mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225] std = std if std is not None else [0.229, 0.224, 0.225]
......
...@@ -36,7 +36,7 @@ from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseN ...@@ -36,7 +36,7 @@ from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseN
from .squeezenet import SqueezeNet1_0, SqueezeNet1_1 from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
from .darknet import DarkNet53 from .darknet import DarkNet53
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7 from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB0_small, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from .res2net import Res2Net50_48w_2s, Res2Net50_26w_4s, Res2Net50_14w_8s, Res2Net50_26w_6s, Res2Net50_26w_8s, Res2Net101_26w_4s, Res2Net152_26w_4s from .res2net import Res2Net50_48w_2s, Res2Net50_26w_4s, Res2Net50_14w_8s, Res2Net50_26w_6s, Res2Net50_26w_8s, Res2Net101_26w_4s, Res2Net152_26w_4s
from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_14w_8s, Res2Net50_vd_26w_6s, Res2Net50_vd_26w_8s, Res2Net101_vd_26w_4s, Res2Net152_vd_26w_4s, Res2Net200_vd_26w_4s from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_14w_8s, Res2Net50_vd_26w_6s, Res2Net50_vd_26w_8s, Res2Net101_vd_26w_4s, Res2Net152_vd_26w_4s, Res2Net200_vd_26w_4s
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -192,15 +192,17 @@ class EfficientNet(): ...@@ -192,15 +192,17 @@ class EfficientNet():
if is_test: if is_test:
return inputs return inputs
keep_prob = 1.0 - prob keep_prob = 1.0 - prob
random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like( random_tensor = keep_prob + \
inputs, [-1, 1, 1, 1], min=0., max=1.) fluid.layers.uniform_random_batch_size_like(
inputs, [-1, 1, 1, 1], min=0., max=1.)
binary_tensor = fluid.layers.floor(random_tensor) binary_tensor = fluid.layers.floor(random_tensor)
output = inputs / keep_prob * binary_tensor output = inputs / keep_prob * binary_tensor
return output return output
def _expand_conv_norm(self, inputs, block_args, is_test, name=None): def _expand_conv_norm(self, inputs, block_args, is_test, name=None):
# Expansion phase # Expansion phase
oup = block_args.input_filters * block_args.expand_ratio # number of output channels oup = block_args.input_filters * \
block_args.expand_ratio # number of output channels
if block_args.expand_ratio != 1: if block_args.expand_ratio != 1:
conv = self.conv_bn_layer( conv = self.conv_bn_layer(
...@@ -222,7 +224,8 @@ class EfficientNet(): ...@@ -222,7 +224,8 @@ class EfficientNet():
s = block_args.stride s = block_args.stride
if isinstance(s, list) or isinstance(s, tuple): if isinstance(s, list) or isinstance(s, tuple):
s = s[0] s = s[0]
oup = block_args.input_filters * block_args.expand_ratio # number of output channels oup = block_args.input_filters * \
block_args.expand_ratio # number of output channels
conv = self.conv_bn_layer( conv = self.conv_bn_layer(
inputs, inputs,
...@@ -285,7 +288,7 @@ class EfficientNet(): ...@@ -285,7 +288,7 @@ class EfficientNet():
name=conv_name, name=conv_name,
use_bias=use_bias) use_bias=use_bias)
if use_bn == False: if use_bn is False:
return conv return conv
else: else:
bn_name = name + bn_name bn_name = name + bn_name
...@@ -325,7 +328,8 @@ class EfficientNet(): ...@@ -325,7 +328,8 @@ class EfficientNet():
drop_connect_rate=None, drop_connect_rate=None,
name=None): name=None):
# Expansion and Depthwise Convolution # Expansion and Depthwise Convolution
oup = block_args.input_filters * block_args.expand_ratio # number of output channels oup = block_args.input_filters * \
block_args.expand_ratio # number of output channels
has_se = self.use_se and (block_args.se_ratio is not None) and ( has_se = self.use_se and (block_args.se_ratio is not None) and (
0 < block_args.se_ratio <= 1) 0 < block_args.se_ratio <= 1)
id_skip = block_args.id_skip # skip connection and drop connect id_skip = block_args.id_skip # skip connection and drop connect
...@@ -346,8 +350,11 @@ class EfficientNet(): ...@@ -346,8 +350,11 @@ class EfficientNet():
conv = self._project_conv_norm(conv, block_args, is_test, name) conv = self._project_conv_norm(conv, block_args, is_test, name)
# Skip connection and drop connect # Skip connection and drop connect
input_filters, output_filters = block_args.input_filters, block_args.output_filters input_filters = block_args.input_filters
if id_skip and block_args.stride == 1 and input_filters == output_filters: output_filters = block_args.output_filters
if id_skip and \
block_args.stride == 1 and \
input_filters == output_filters:
if drop_connect_rate: if drop_connect_rate:
conv = self._drop_connect(conv, drop_connect_rate, conv = self._drop_connect(conv, drop_connect_rate,
self.is_test) self.is_test)
...@@ -412,7 +419,8 @@ class EfficientNet(): ...@@ -412,7 +419,8 @@ class EfficientNet():
num_repeat=round_repeats(block_args.num_repeat, num_repeat=round_repeats(block_args.num_repeat,
self._global_params)) self._global_params))
# The first block needs to take care of stride and filter size increase. # The first block needs to take care of stride,
# and filter size increase.
drop_connect_rate = self._global_params.drop_connect_rate drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate: if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size drop_connect_rate *= float(idx) / block_size
...@@ -440,7 +448,9 @@ class EfficientNet(): ...@@ -440,7 +448,9 @@ class EfficientNet():
class BlockDecoder(object): class BlockDecoder(object):
""" Block Decoder for readability, straight from the official TensorFlow repository """ """
Block Decoder, straight from the official TensorFlow repository.
"""
@staticmethod @staticmethod
def _decode_block_string(block_string): def _decode_block_string(block_string):
...@@ -456,9 +466,10 @@ class BlockDecoder(object): ...@@ -456,9 +466,10 @@ class BlockDecoder(object):
options[key] = value options[key] = value
# Check stride # Check stride
assert ( cond_1 = ('s' in options and len(options['s']) == 1)
('s' in options and len(options['s']) == 1) or cond_2 = ((len(options['s']) == 2)
(len(options['s']) == 2 and options['s'][0] == options['s'][1])) and (options['s'][0] == options['s'][1]))
assert (cond_1 or cond_2)
return BlockArgs( return BlockArgs(
kernel_size=int(options['k']), kernel_size=int(options['k']),
...@@ -487,10 +498,11 @@ class BlockDecoder(object): ...@@ -487,10 +498,11 @@ class BlockDecoder(object):
@staticmethod @staticmethod
def decode(string_list): def decode(string_list):
""" """
Decodes a list of string notations to specify blocks inside the network. Decode a list of string notations to specify blocks in the network.
:param string_list: a list of strings, each string is a notation of block string_list: list of strings, each string is a notation of block
:return: a list of BlockArgs namedtuples of block args return
list of BlockArgs namedtuples of block args
""" """
assert isinstance(string_list, list) assert isinstance(string_list, list)
blocks_args = [] blocks_args = []
...@@ -525,6 +537,19 @@ def EfficientNetB0(is_test=False, ...@@ -525,6 +537,19 @@ def EfficientNetB0(is_test=False,
return model return model
def EfficientNetB0_small(is_test=False,
padding_type='DYNAMIC',
override_params=None,
use_se=False):
model = EfficientNet(
name='b0',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se)
return model
def EfficientNetB1(is_test=False, def EfficientNetB1(is_test=False,
padding_type='SAME', padding_type='SAME',
override_params=None, override_params=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册