提交 88786bc8 编写于 作者: 农夫山泉2号's avatar 农夫山泉2号 提交者: chenjun2hao

updata the decoder model for indefinite width image recognition

Signed-off-by: Nchenjun2hao <chenjun01@ebupt.com>
上级 45e3e471
# coding:utf-8
import utils
import torch
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import models.crnn_lang as crnn
class attention_ocr():
'''使用attention_ocr进行字符识别
返回:
ocr读数,置信度
'''
def __init__(self):
encoder_path = './expr/attentioncnn/encoder_600.pth'
decoder_path = './expr/attentioncnn/decoder_600.pth'
self.alphabet = '0123456789'
self.max_length = 7 # 最长字符串的长度
self.EOS_TOKEN = 1
self.use_gpu = True
self.max_width = 220
self.converter = utils.strLabelConverterForAttention(self.alphabet)
self.transform = transforms.ToTensor()
nclass = len(self.alphabet) + 3
encoder = crnn.CNN(32, 1, 256) # 编码器
decoder = crnn.decoder(256, nclass) # seq to seq的解码器, nclass在decoder中还加了2
if encoder_path and decoder_path:
print('loading pretrained models ......')
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))
if torch.cuda.is_available() and self.use_gpu:
encoder = encoder.cuda()
decoder = decoder.cuda()
self.encoder = encoder.eval()
self.decoder = decoder.eval()
def constant_pad(self, img_crop):
'''把图片等比例缩放到高度:32,再或resize填充到220宽度
img_crop:
cv2图片,rgb顺序
返回:
img tensor
'''
h, w, c = img_crop.shape
ratio = h / 32
new_w = int(w / ratio)
new_img = cv2.resize(img_crop,(new_w, 32))
container = np.ones((32, self.max_width, 3), dtype=np.uint8) * new_img[-3,-3,:]
if new_w <= self.max_width:
container[:,:new_w,:] = new_img
elif new_w > self.max_width:
container = cv2.resize(new_img, (self.max_width, 32))
img = Image.fromarray(container.astype('uint8')).convert('L')
img = self.transform(img)
img.sub_(0.5).div_(0.5)
if self.use_gpu:
img = img.cuda()
return img.unsqueeze(0)
def predict(self, img_crop):
'''attention ocr 做文字识别
img_crop:
cv2图片,rgb顺序
返回:
ocr读数,prob置信度
'''
img_tensor = self.constant_pad(img_crop)
encoder_out = self.encoder(img_tensor)
decoded_words = []
prob = 1.0
decoder_input = torch.zeros(1).long() # 初始化decoder的开始,从0开始输出
decoder_hidden = self.decoder.initHidden(1)
if torch.cuda.is_available() and self.use_gpu:
decoder_input = decoder_input.cuda()
decoder_hidden = decoder_hidden.cuda()
# 预测的时候采用非强制策略,将前一次的输出,作为下一次的输入,直到标签为EOS_TOKEN时停止
for di in range(self.max_length): # 最大字符串的长度
decoder_output, decoder_hidden, decoder_attention = self.decoder(
decoder_input, decoder_hidden, encoder_out)
probs = torch.exp(decoder_output)
# decoder_attentions[di] = decoder_attention.data
topv, topi = decoder_output.data.topk(1)
ni = topi.squeeze(1)
decoder_input = ni
prob *= probs[:, ni]
if ni == self.EOS_TOKEN:
# decoded_words.append('<EOS>')
break
else:
decoded_words.append(self.converter.decode(ni))
words = ''.join(decoded_words)
prob = prob.item()
return words, prob
if __name__ == '__main__':
path = './test_img/00027_299021_27.jpg'
img = cv2.imread(path)
attention = attention_ocr()
res = attention.predict(img)
print(res)
\ No newline at end of file
#!/usr/bin/python
# encoding: utf-8
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
# import lmdb
import six
import sys
from PIL import Image
import numpy as np
class listDataset(Dataset):
def __init__(self, list_file=None, transform=None, target_transform=None):
self.list_file = list_file
with open(list_file) as fp:
self.lines = fp.readlines()
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
line_splits = self.lines[index].strip().split(' ')
imgpath = line_splits[0]
try:
if 'train' in self.list_file:
img = Image.open(imgpath).convert('L')
else:
img = Image.open(imgpath).convert('L')
except IOError:
print('Corrupted image for %d' % index)
return self[index + 1]
if self.transform is not None:
img = self.transform(img)
label = line_splits[1].decode('utf-8')
if self.target_transform is not None:
label = self.target_transform(label)
return (img, label)
class resizeNormalize(object):
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
self.toTensor = transforms.ToTensor()
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img = self.toTensor(img)
img.sub_(0.5).div_(0.5)
return img
class randomSequentialSampler(sampler.Sampler):
def __init__(self, data_source, batch_size):
self.num_samples = len(data_source)
self.batch_size = batch_size
def __iter__(self):
n_batch = len(self) // self.batch_size
tail = len(self) % self.batch_size
index = torch.LongTensor(len(self)).fill_(0)
for i in range(n_batch):
random_start = random.randint(0, len(self) - self.batch_size)
batch_index = random_start + torch.arange(0, self.batch_size)
index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
# deal with tail
if tail:
random_start = random.randint(0, len(self) - self.batch_size)
tail_index = random_start + torch.arange(0, tail)
index[(i + 1) * self.batch_size:] = tail_index
return iter(index)
def __len__(self):
return self.num_samples
class alignCollate(object):
def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
def __call__(self, batch):
images, labels = zip(*batch)
imgH = self.imgH
imgW = self.imgW
if self.keep_ratio:
ratios = []
for image in images:
w, h = image.size
ratios.append(w / float(h))
ratios.sort()
max_ratio = ratios[-1]
imgW = int(np.floor(max_ratio * imgH))
imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW
transform = resizeNormalize((imgW, imgH))
images = [transform(image) for image in images]
images = torch.cat([t.unsqueeze(0) for t in images], 0)
return images, labels
#!/usr/bin/python
# encoding: utf-8
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import collections
from PIL import Image, ImageFilter
import math
import random
import numpy as np
import cv2
with open('./data/char_std_5990.txt') as f:
data = f.readlines()
alphabet = [x.rstrip() for x in data]
alphabet = ''.join(alphabet).decode('utf-8') # python2不加decode的时候会乱码
class strLabelConverterForAttention(object):
"""Convert between str and label.
NOTE:
Insert `EOS` to the alphabet for attention.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet):
self.alphabet = alphabet
self.dict = {}
self.dict['SOS'] = 0 # 开始
self.dict['EOS'] = 1 # 结束
self.dict['$'] = 2 # blank标识符
for i, item in enumerate(self.alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[item] = i + 3 # 从3开始编码
def encode(self, text):
"""对target_label做编码和对齐
对target txt每个字符串的开始加上GO,最后加上EOS,并用最长的字符串做对齐
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor targets:max_length × batch_size
"""
if isinstance(text, unicode):
text = [self.dict[item] for item in text]
elif isinstance(text, collections.Iterable):
text = [self.encode(s) for s in text] # 编码
max_length = max([len(x) for x in text]) # 对齐
nb = len(text)
targets = torch.ones(nb, max_length + 2) * 2 # use ‘blank’ for pading
for i in range(nb):
targets[i][0] = 0 # 开始
targets[i][1:len(text[i]) + 1] = text[i]
targets[i][len(text[i]) + 1] = 1
text = targets.transpose(0, 1).contiguous()
text = text.long()
return torch.LongTensor(text)
def decode(self, t):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
texts = list(self.dict.keys())[list(self.dict.values()).index(t)]
return texts
class strLabelConverterForCTC(object):
"""Convert between str and label.
NOTE:
Insert `blank` to the alphabet for CTC.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet, sep):
self.sep = sep
self.alphabet = alphabet.split(sep)
self.alphabet.append('-') # for `-1` index
self.dict = {}
for i, item in enumerate(self.alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[item] = i + 1
def encode(self, text):
"""Support batch or single str.
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
if isinstance(text, str):
text = text.split(self.sep)
text = [self.dict[item] for item in text]
length = [len(text)]
elif isinstance(text, collections.Iterable):
length = [len(s.split(self.sep)) for s in text]
text = self.sep.join(text)
text, _ = self.encode(text)
return (torch.IntTensor(text), torch.IntTensor(length))
def decode(self, t, length, raw=False):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
char_list = []
for i in range(length):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
char_list.append(self.alphabet[t[i] - 1])
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()):
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.IntTensor([l]), raw=raw))
index += l
return texts
class averager(object):
"""Compute average for `torch.Variable` and `torch.Tensor`. """
def __init__(self):
self.reset()
def add(self, v):
if isinstance(v, Variable):
count = v.data.numel()
v = v.data.sum()
elif isinstance(v, torch.Tensor):
count = v.numel()
v = v.sum()
self.n_count += count
self.sum += v
def reset(self):
self.n_count = 0
self.sum = 0
def val(self):
res = 0
if self.n_count != 0:
res = self.sum / float(self.n_count)
return res
def oneHot(v, v_length, nc):
batchSize = v_length.size(0)
maxLength = v_length.max()
v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
acc = 0
for i in range(batchSize):
length = v_length[i]
label = v[acc:acc + length].view(-1, 1).long()
v_onehot[i, :length].scatter_(1, label, 1.0)
acc += length
return v_onehot
def loadData(v, data):
v.data.resize_(data.size()).copy_(data)
def prettyPrint(v):
print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0],
v.mean().data[0]))
def assureRatio(img):
"""Ensure imgH <= imgW."""
b, c, h, w = img.size()
if h > w:
main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
img = main(img)
return img
class halo():
'''
u:高斯分布的均值
sigma:方差
nums:在一张图片中随机添加几个光点
prob:使用halo的概率
'''
def __init__(self, nums, u=0, sigma=0.2, prob=0.5):
self.u = u # 均值μ
self.sig = math.sqrt(sigma) # 标准差δ
self.nums = nums
self.prob = prob
def create_kernel(self, maxh=32, maxw=50):
height_scope = [10, maxh] # 高度范围 随机生成高斯
weight_scope = [20, maxw] # 宽度范围
x = np.linspace(self.u - 3 * self.sig, self.u + 3 * self.sig, random.randint(*height_scope))
y = np.linspace(self.u - 3 * self.sig, self.u + 3 * self.sig, random.randint(*weight_scope))
Gauss_map = np.zeros((len(x), len(y)))
for i in range(len(x)):
for j in range(len(y)):
Gauss_map[i, j] = np.exp(-((x[i] - self.u) ** 2 + (y[j] - self.u) ** 2) / (2 * self.sig ** 2)) / (
math.sqrt(2 * math.pi) * self.sig)
return Gauss_map
def __call__(self, img):
if random.random() < self.prob:
Gauss_map = self.create_kernel(32, 60) # 初始化一个高斯核,32为高度方向的最大值,60为w方向
img1 = np.asarray(img)
img1.flags.writeable = True # 将数组改为读写模式
nums = random.randint(1, self.nums) # 随机生成nums个光点
img1 = img1.astype(np.float)
# print(nums)
for i in range(nums):
img_h, img_w = img1.shape
pointx = random.randint(0, img_h - 10) # 在原图中随机找一个点
pointy = random.randint(0, img_w - 10)
h, w = Gauss_map.shape # 判断是否超限
endx = pointx + h
endy = pointy + w
if pointx + h > img_h:
endx = img_h
Gauss_map = Gauss_map[1:img_h - pointx + 1, :]
if img_w < pointy + w:
endy = img_w
Gauss_map = Gauss_map[:, 1:img_w - pointy + 1]
# 加上不均匀光照
img1[pointx:endx, pointy:endy] = img1[pointx:endx, pointy:endy] + Gauss_map * 255.0
img1[img1 > 255.0] = 255.0 # 进行限幅,不然uint8会从0开始重新计数
img = img1
return Image.fromarray(np.uint8(img))
class MyGaussianBlur(ImageFilter.Filter):
name = "GaussianBlur"
def __init__(self, radius=2, bounds=None):
self.radius = radius
self.bounds = bounds
def filter(self, image):
if self.bounds:
clips = image.crop(self.bounds).gaussian_blur(self.radius)
image.paste(clips, self.bounds)
return image
else:
return image.gaussian_blur(self.radius)
class GBlur(object):
def __init__(self, radius=2, prob=0.5):
radius = random.randint(0, radius)
self.blur = MyGaussianBlur(radius=radius)
self.prob = prob
def __call__(self, img):
if random.random() < self.prob:
img = img.filter(self.blur)
return img
class RandomBrightness(object):
"""随机改变亮度
pil:pil格式的图片
"""
def __init__(self, prob=1.5):
self.prob = prob
def __call__(self, pil):
rgb = np.asarray(pil)
if random.random() < self.prob:
hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
h, s, v = cv2.split(hsv)
adjust = random.choice([0.5, 0.7, 0.9, 1.2, 1.5, 1.7]) # 随机选择一个
# adjust = random.choice([1.2, 1.5, 1.7, 2.0]) # 随机选择一个
v = v * adjust
v = np.clip(v, 0, 255).astype(hsv.dtype)
hsv = cv2.merge((h, s, v))
rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
return Image.fromarray(np.uint8(rgb)).convert('L')
class randapply(object):
"""随机决定是否应用光晕、模糊或者二者都用
Args:
transforms (list or tuple): list of transformations
"""
def __init__(self, transforms):
assert isinstance(transforms, (list, tuple))
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += '\n p={}'.format(self.p)
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
def weights_init(model):
# Official init from torch repo.
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册