Unverified commit eeef62b3, authored by littletomatodonkey, committed by GitHub

fix PREN export and infer (#7833)

Parent 077196f3
@@ -21,124 +21,165 @@ from __future__ import division
from __future__ import print_function
import math
from collections import namedtuple
import re
import collections
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['EfficientNetb3']
GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
'drop_connect_rate', 'image_size'
])
class EffB3Params:
BlockArgs = collections.namedtuple('BlockArgs', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'stride', 'se_ratio'
])
class BlockDecoder:
@staticmethod
def get_global_params():
"""
The following are EfficientNet-B3's architecture hyperparameters; to fit the
scene text recognition task, the resolution (image_size) here is changed
from 300 to 64.
"""
GlobalParams = namedtuple('GlobalParams', [
'drop_connect_rate', 'width_coefficient', 'depth_coefficient',
'depth_divisor', 'image_size'
])
global_params = GlobalParams(
drop_connect_rate=0.3,
width_coefficient=1.2,
depth_coefficient=1.4,
depth_divisor=8,
image_size=64)
return global_params
def _decode_block_string(block_string):
assert isinstance(block_string, str)
ops = block_string.split('_')
options = {}
for op in ops:
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
assert (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
return BlockArgs(
kernel_size=int(options['k']),
num_repeat=int(options['r']),
input_filters=int(options['i']),
output_filters=int(options['o']),
expand_ratio=int(options['e']),
id_skip=('noskip' not in block_string),
se_ratio=float(options['se']) if 'se' in options else None,
stride=[int(options['s'][0])])
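# Worked example (illustrative sketch, not part of this commit): decoding one
# encoded block string with the same regex split used above.
import re

options = {}
for op in 'r2_k3_s22_e6_i16_o24_se0.25'.split('_'):
    splits = re.split(r'(\d.*)', op)
    if len(splits) >= 2:
        key, value = splits[:2]
        options[key] = value
print(options)
# {'r': '2', 'k': '3', 's': '22', 'e': '6', 'i': '16', 'o': '24', 'se': '0.25'}
# i.e. num_repeat=2, kernel_size=3, stride=2, expand_ratio=6,
# input_filters=16, output_filters=24, se_ratio=0.25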
@staticmethod
def get_block_params():
BlockParams = namedtuple('BlockParams', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'se_ratio', 'stride'
])
block_params = [
BlockParams(3, 1, 32, 16, 1, True, 0.25, 1),
BlockParams(3, 2, 16, 24, 6, True, 0.25, 2),
BlockParams(5, 2, 24, 40, 6, True, 0.25, 2),
BlockParams(3, 3, 40, 80, 6, True, 0.25, 2),
BlockParams(5, 3, 80, 112, 6, True, 0.25, 1),
BlockParams(5, 4, 112, 192, 6, True, 0.25, 2),
BlockParams(3, 1, 192, 320, 6, True, 0.25, 1)
]
return block_params
def decode(string_list):
assert isinstance(string_list, list)
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
return blocks_args
def efficientnet(width_coefficient=None,
depth_coefficient=None,
dropout_rate=0.2,
drop_connect_rate=0.2,
image_size=None,
num_classes=1000):
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25',
'r2_k3_s22_e6_i16_o24_se0.25',
'r2_k5_s22_e6_i24_o40_se0.25',
'r3_k3_s22_e6_i40_o80_se0.25',
'r3_k5_s11_e6_i80_o112_se0.25',
'r4_k5_s22_e6_i112_o192_se0.25',
'r1_k3_s11_e6_i192_o320_se0.25',
]
blocks_args = BlockDecoder.decode(blocks_args)
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
dropout_rate=dropout_rate,
drop_connect_rate=drop_connect_rate,
num_classes=num_classes,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None,
image_size=image_size, )
return blocks_args, global_params
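# Usage sketch (illustrative): the coefficients PREN uses for EfficientNet-B3,
# with the input resolution reduced from 300 to 64 for text recognition.
blocks_args, global_params = efficientnet(
    width_coefficient=1.2,
    depth_coefficient=1.4,
    dropout_rate=0.3,
    image_size=64)
print(len(blocks_args))             # 7 stage definitions
print(global_params.depth_divisor)  # 8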
class EffUtils:
@staticmethod
def round_filters(filters, global_params):
"""Calculate and round number of filters based on depth multiplier."""
""" Calculate and round number of filters based on depth multiplier. """
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
new_filters = int(filters + divisor / 2) // divisor * divisor
min_depth = min_depth or divisor
new_filters = max(min_depth,
int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters:
new_filters += divisor
return int(new_filters)
@staticmethod
def round_repeats(repeats, global_params):
"""Round number of filters based on depth multiplier."""
""" Round number of filters based on depth multiplier. """
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
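# Worked arithmetic (illustrative) with the B3 coefficients used by PREN
# (width_coefficient=1.2, depth_coefficient=1.4, depth_divisor=8):
#   round_filters(32): 32 * 1.2 = 38.4 -> int(38.4 + 4) // 8 * 8 = 40
#   round_filters(16): 16 * 1.2 = 19.2 -> 16, but 16 < 0.9 * 19.2, so 16 + 8 = 24
#   round_repeats(2):  ceil(2 * 1.4) = 3
import math
print(int(32 * 1.2 + 8 / 2) // 8 * 8)  # 40
print(math.ceil(1.4 * 2))              # 3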
class ConvBlock(nn.Layer):
def __init__(self, block_params):
super(ConvBlock, self).__init__()
self.block_args = block_params
self.has_se = (self.block_args.se_ratio is not None) and \
(0 < self.block_args.se_ratio <= 1)
self.id_skip = block_params.id_skip
class MbConvBlock(nn.Layer):
def __init__(self, block_args):
super(MbConvBlock, self).__init__()
self._block_args = block_args
self.has_se = (self._block_args.se_ratio is not None) and \
(0 < self._block_args.se_ratio <= 1)
self.id_skip = block_args.id_skip
# expansion phase
self.input_filters = self.block_args.input_filters
output_filters = \
self.block_args.input_filters * self.block_args.expand_ratio
if self.block_args.expand_ratio != 1:
self.expand_conv = nn.Conv2D(
self.input_filters, output_filters, 1, bias_attr=False)
self.bn0 = nn.BatchNorm(output_filters)
self.inp = self._block_args.input_filters
oup = self._block_args.input_filters * self._block_args.expand_ratio
if self._block_args.expand_ratio != 1:
self._expand_conv = nn.Conv2D(self.inp, oup, 1, bias_attr=False)
self._bn0 = nn.BatchNorm(oup)
# depthwise conv phase
k = self.block_args.kernel_size
s = self.block_args.stride
self.depthwise_conv = nn.Conv2D(
output_filters,
output_filters,
groups=output_filters,
k = self._block_args.kernel_size
s = self._block_args.stride
if isinstance(s, list):
s = s[0]
self._depthwise_conv = nn.Conv2D(
oup,
oup,
groups=oup,
kernel_size=k,
stride=s,
padding='same',
bias_attr=False)
self.bn1 = nn.BatchNorm(output_filters)
self._bn1 = nn.BatchNorm(oup)
# squeeze and excitation layer, if desired
if self.has_se:
num_squeezed_channels = max(1,
int(self.block_args.input_filters *
self.block_args.se_ratio))
self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1)
self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1)
# output phase
self.final_oup = self.block_args.output_filters
self.project_conv = nn.Conv2D(
output_filters, self.final_oup, 1, bias_attr=False)
self.bn2 = nn.BatchNorm(self.final_oup)
self.swish = nn.Swish()
def drop_connect(self, inputs, p, training):
int(self._block_args.input_filters *
self._block_args.se_ratio))
self._se_reduce = nn.Conv2D(oup, num_squeezed_channels, 1)
self._se_expand = nn.Conv2D(num_squeezed_channels, oup, 1)
# output phase and utility layers
self.final_oup = self._block_args.output_filters
self._project_conv = nn.Conv2D(oup, self.final_oup, 1, bias_attr=False)
self._bn2 = nn.BatchNorm(self.final_oup)
self._swish = nn.Swish()
def _drop_connect(self, inputs, p, training):
if not training:
return inputs
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
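# Sketch of the drop-connect (stochastic depth) body whose remaining lines are
# elided in this hunk. It follows the standard formulation and is an
# assumption for illustration, not the committed code.
import paddle

def drop_connect_sketch(inputs, p, training):
    if not training:
        return inputs
    keep_prob = 1 - p
    # one Bernoulli sample per example, broadcast over C, H, W
    random_tensor = keep_prob + paddle.rand([inputs.shape[0], 1, 1, 1])
    binary_mask = paddle.floor(random_tensor)
    return inputs / keep_prob * binary_mask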
@@ -151,22 +192,23 @@ class ConvBlock(nn.Layer):
def forward(self, inputs, drop_connect_rate=None):
# expansion and depthwise conv
x = inputs
if self.block_args.expand_ratio != 1:
x = self.swish(self.bn0(self.expand_conv(inputs)))
x = self.swish(self.bn1(self.depthwise_conv(x)))
if self._block_args.expand_ratio != 1:
x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self._swish(self._bn1(self._depthwise_conv(x)))
# squeeze and excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed)))
x_squeezed = self._se_expand(
self._swish(self._se_reduce(x_squeezed)))
x = F.sigmoid(x_squeezed) * x
x = self.bn2(self.project_conv(x))
x = self._bn2(self._project_conv(x))
# skip connection and drop connect
if self.id_skip and self.block_args.stride == 1 and \
self.input_filters == self.final_oup:
if self.id_skip and self._block_args.stride == 1 and \
self.inp == self.final_oup:
if drop_connect_rate:
x = self.drop_connect(
x = self._drop_connect(
x, p=drop_connect_rate, training=self.training)
x = x + inputs
return x
@@ -175,54 +217,63 @@ class ConvBlock(nn.Layer):
class EfficientNetb3_PREN(nn.Layer):
def __init__(self, in_channels):
super(EfficientNetb3_PREN, self).__init__()
self.blocks_params = EffB3Params.get_block_params()
self.global_params = EffB3Params.get_global_params()
"""
The following are EfficientNet-B3's hyperparameters: the network's width,
depth, resolution, and dropout, respectively. To fit the text recognition
task, the resolution here is changed from 300 to 64.
"""
w, d, s, p = 1.2, 1.4, 64, 0.3
self._blocks_args, self._global_params = efficientnet(
width_coefficient=w,
depth_coefficient=d,
dropout_rate=p,
image_size=s)
self.out_channels = []
# stem
stem_channels = EffUtils.round_filters(32, self.global_params)
self.conv_stem = nn.Conv2D(
in_channels, stem_channels, 3, 2, padding='same', bias_attr=False)
self.bn0 = nn.BatchNorm(stem_channels)
out_channels = EffUtils.round_filters(32, self._global_params)
self._conv_stem = nn.Conv2D(
in_channels, out_channels, 3, 2, padding='same', bias_attr=False)
self._bn0 = nn.BatchNorm(out_channels)
self.blocks = []
# build blocks
self._blocks = []
# indices of the blocks whose outputs are taken as the three FPN feature maps
self.concerned_block_idxes = [7, 17, 25]
concerned_idx = 0
for i, block_params in enumerate(self.blocks_params):
block_params = block_params._replace(
input_filters=EffUtils.round_filters(block_params.input_filters,
self.global_params),
output_filters=EffUtils.round_filters(
block_params.output_filters, self.global_params),
num_repeat=EffUtils.round_repeats(block_params.num_repeat,
self.global_params))
self.blocks.append(
self.add_sublayer("{}-0".format(i), ConvBlock(block_params)))
concerned_idx += 1
if concerned_idx in self.concerned_block_idxes:
self.out_channels.append(block_params.output_filters)
if block_params.num_repeat > 1:
block_params = block_params._replace(
input_filters=block_params.output_filters, stride=1)
for j in range(block_params.num_repeat - 1):
self.blocks.append(
self.add_sublayer('{}-{}'.format(i, j + 1),
ConvBlock(block_params)))
concerned_idx += 1
if concerned_idx in self.concerned_block_idxes:
self.out_channels.append(block_params.output_filters)
self.swish = nn.Swish()
self._concerned_block_idxes = [7, 17, 25]
_concerned_idx = 0
for i, block_args in enumerate(self._blocks_args):
block_args = block_args._replace(
input_filters=EffUtils.round_filters(block_args.input_filters,
self._global_params),
output_filters=EffUtils.round_filters(block_args.output_filters,
self._global_params),
num_repeat=EffUtils.round_repeats(block_args.num_repeat,
self._global_params))
self._blocks.append(
self.add_sublayer(f"{i}-0", MbConvBlock(block_args)))
_concerned_idx += 1
if _concerned_idx in self._concerned_block_idxes:
self.out_channels.append(block_args.output_filters)
if block_args.num_repeat > 1:
block_args = block_args._replace(
input_filters=block_args.output_filters, stride=1)
for j in range(block_args.num_repeat - 1):
self._blocks.append(
self.add_sublayer(f'{i}-{j+1}', MbConvBlock(block_args)))
_concerned_idx += 1
if _concerned_idx in self._concerned_block_idxes:
self.out_channels.append(block_args.output_filters)
self._swish = nn.Swish()
def forward(self, inputs):
outs = []
x = self.swish(self.bn0(self.conv_stem(inputs)))
for idx, block in enumerate(self.blocks):
drop_connect_rate = self.global_params.drop_connect_rate
x = self._swish(self._bn0(self._conv_stem(inputs)))
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self.blocks)
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
if idx in self.concerned_block_idxes:
if idx in self._concerned_block_idxes:
outs.append(x)
return outs
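# Usage sketch (illustrative; assumes PaddlePaddle is installed): the backbone
# returns three feature maps for the FPN neck and records their channel
# counts in .out_channels ([48, 136, 384] with the B3 coefficients above).
import paddle

backbone = EfficientNetb3_PREN(in_channels=3)
feats = backbone(paddle.randn([1, 3, 64, 256]))  # the PREN input shape used below
print(backbone.out_channels, [f.shape for f in feats])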
@@ -562,7 +562,8 @@ class PRENLabelDecode(BaseRecLabelDecode):
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
preds = preds.numpy()
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob)
......
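# Why the new isinstance guard matters (illustrative sketch): in training and
# eval the head returns a paddle.Tensor, but the exported inference model
# hands the decoder a plain numpy.ndarray, which has no .numpy() method.
import numpy as np
import paddle

for preds in (paddle.rand([1, 4, 10]), np.random.rand(1, 4, 10)):
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    print(type(preds), preds.argmax(axis=2).shape)  # both paths decode the same way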
@@ -77,7 +77,7 @@ def export_single_model(model,
elif arch_config["algorithm"] == "PREN":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 64, 512], dtype="float32"),
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["model_type"] == "sr":
......
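# Hedged sketch of what the updated PREN branch exports: a static graph whose
# input is pinned to the new [N, 3, 64, 256] recognition shape. A stand-in
# layer and an illustrative save path are used here; in the real flow `model`
# is the trained PREN recognizer.
import paddle
import paddle.nn as nn
from paddle.jit import to_static

model = nn.Conv2D(3, 8, 3, padding=1)  # stand-in for the PREN model (assumption)
input_spec = [paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32")]
static_model = to_static(model, input_spec=input_spec)
paddle.jit.save(static_model, "./inference/rec_pren/inference")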
@@ -100,6 +100,8 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char,
"rm_symbol": True
}
elif self.rec_algorithm == "PREN":
postprocess_params = {'name': 'PRENLabelDecode'}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -384,7 +386,7 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "VisionLAN":
elif self.rec_algorithm in ["VisionLAN", "PREN"]:
norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
......