Commit 45e3e471 authored by 农夫山泉2号, committed by chenjun2hao

Update the decoder model for variable-width image recognition

Signed-off-by: chenjun2hao <chenjun01@ebupt.com>
Parent f5b397d7
attention-ocr.pytorch: Encoder + Decoder + attention model
======================================
This repository implements an encoder-decoder model with attention for OCR: the encoder is a CNN + Bi-LSTM and the decoder is a GRU. It is modified from https://github.com/meijieru/crnn.pytorch. There are many attention-OCR repositories written in TensorFlow, but they do not provide an **inference.py**, and I am not familiar enough with TensorFlow to write one myself.
An earlier open-source version of this project could only recognize images of a fixed width. I have recently modified the model to support variable-width image recognition; the functionality is the same as CRNN. Due to time constraints there is no pretrained model for this release; one will be added later.
# Requirements
pytorch 0.4.1
@@ -13,30 +13,9 @@ pip install -r requirements.txt
```
# Test
1. Download the pretrained model from [attentioncnn.zip](https://pan.baidu.com/s/1h5d7rtWqfZaKHtm52ZWVFw) and put the .pth files into the folder expr/attentioncnn/ (a pretrained model for the new variable-width version is coming soon)
2. Change the test image's name in **demo.py**; the test images are in the test_img folder
```bash
python demo.py
```
3. Results
![](./test_img/20441531_4212871437.jpg)
```
>>>predict_str:比赛,尽管很尽心尽力 => prob:0.6112725138664246
```
4. Some examples
![results](./test_img/md_img/attention结果.png)
| picture | predict reading | confidence |
| ------ | ------ | ------ |
| ![](./test_img/20436312_1683447152.jpg) | 美国人不愿意与制鲜 | 0.33920 |
| ![](./test_img/20437109_1639581473.jpg) | 现之间的一个分享和 | 0.81095 |
| ![](./test_img/20437421_2143654630.jpg) | 中国通信学会主办、《 | 0.90660 |
| ![](./test_img/20437531_1514396900.jpg) | 此在战术上应大胆、勇 | 0.57111 |
| ![](./test_img/20439281_953270478.jpg) | 年同期俱点83.50 | 0.14481 |
| ![](./test_img/20439906_2889507409.jpg) | 。留言无恶意心态成 | 0.31054 |
5. Accuracy on the test data stays around 88%; there is still much to do
# Train
1. I chose a small dataset from [Synthetic_Chinese_String_Dataset](https://github.com/chenjun2hao/caffe_ocr): about 270,000 images for training and 20,000 for testing.
Download the image data from [Baidu](https://pan.baidu.com/s/1hIurFJ73XbzL-QG4V-oe0w)
@@ -53,7 +32,8 @@ cd Attention_ocr.pytorch
python train.py --trainlist ./data/ch_train.txt --vallist ./data/ch_test.txt
```
Then you will see output in the terminal like the following:
![attentionocr](./test_img/md_img/attentionV2.png)
This version uses the decoderV2 model as the decoder.
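In code, the switch is just the decoder construction; everything else stays the same (a minimal sketch based on the changed lines in train.py, where `opt.nh`, `nc`, and `nclass` come from the training script):
```python
import models.crnn_lang as crnn

encoder = crnn.CNN(opt.imgH, nc, opt.nh)  # CNN + BiLSTM encoder, unchanged
# old fixed-width decoder: attention is sized to max_length = imgW / 4
# decoder = crnn.decoder(opt.nh, nclass, dropout_p=0.1, max_length=opt.max_width)
decoder = crnn.decoderV2(opt.nh, nclass, dropout_p=0.1)  # variable-width decoder
```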
# Reference
1. [crnn.pytorch](https://github.com/meijieru/crnn.pytorch)
@@ -63,5 +43,5 @@ Then you will see output in the terminal like the following:
# TO DO
- [ ] change the LSTM to Conv1D; it could greatly accelerate inference
- [ ] support images of different widths
- [ ] replace the CNN backbone with Inception Net or DenseNet
- [ ] implement the decoder with a Transformer model
# coding:utf-8
import utils
import torch
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import models.crnn_lang as crnn
class attention_ocr():
    '''Character recognition with attention_ocr.
    Returns:
        OCR reading, confidence
    '''
def __init__(self):
encoder_path = './expr/attentioncnn/encoder_600.pth'
decoder_path = './expr/attentioncnn/decoder_600.pth'
self.alphabet = '0123456789'
        self.max_length = 7  # maximum string length
self.EOS_TOKEN = 1
self.use_gpu = True
self.max_width = 220
self.converter = utils.strLabelConverterForAttention(self.alphabet)
self.transform = transforms.ToTensor()
nclass = len(self.alphabet) + 3
        encoder = crnn.CNN(32, 1, 256)       # encoder
        decoder = crnn.decoder(256, nclass)  # seq-to-seq decoder; decoder adds another 2 to nclass internally
if encoder_path and decoder_path:
print('loading pretrained models ......')
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))
if torch.cuda.is_available() and self.use_gpu:
encoder = encoder.cuda()
decoder = decoder.cuda()
self.encoder = encoder.eval()
self.decoder = decoder.eval()
def constant_pad(self, img_crop):
        '''Scale the image proportionally to height 32, then pad (or resize) it to width 220.
        img_crop:
            cv2 image, RGB channel order
        Returns:
            img tensor
        '''
h, w, c = img_crop.shape
ratio = h / 32
new_w = int(w / ratio)
new_img = cv2.resize(img_crop,(new_w, 32))
container = np.ones((32, self.max_width, 3), dtype=np.uint8) * new_img[-3,-3,:]
if new_w <= self.max_width:
container[:,:new_w,:] = new_img
elif new_w > self.max_width:
container = cv2.resize(new_img, (self.max_width, 32))
img = Image.fromarray(container.astype('uint8')).convert('L')
img = self.transform(img)
img.sub_(0.5).div_(0.5)
if self.use_gpu:
img = img.cuda()
return img.unsqueeze(0)
def predict(self, img_crop):
        '''Text recognition with attention OCR.
        img_crop:
            cv2 image, RGB channel order
        Returns:
            OCR reading, prob confidence
        '''
img_tensor = self.constant_pad(img_crop)
encoder_out = self.encoder(img_tensor)
decoded_words = []
prob = 1.0
        decoder_input = torch.zeros(1).long()  # initialize the decoder start; decoding begins from token 0 (SOS)
decoder_hidden = self.decoder.initHidden(1)
if torch.cuda.is_available() and self.use_gpu:
decoder_input = decoder_input.cuda()
decoder_hidden = decoder_hidden.cuda()
        # at inference time, use non-teacher-forcing: feed the previous output
        # as the next input until the EOS token is produced
        for di in range(self.max_length):  # maximum string length
decoder_output, decoder_hidden, decoder_attention = self.decoder(
decoder_input, decoder_hidden, encoder_out)
probs = torch.exp(decoder_output)
# decoder_attentions[di] = decoder_attention.data
topv, topi = decoder_output.data.topk(1)
ni = topi.squeeze(1)
decoder_input = ni
prob *= probs[:, ni]
if ni == self.EOS_TOKEN:
# decoded_words.append('<EOS>')
break
else:
decoded_words.append(self.converter.decode(ni))
words = ''.join(decoded_words)
prob = prob.item()
return words, prob
if __name__ == '__main__':
path = './test_img/00027_299021_27.jpg'
img = cv2.imread(path)
attention = attention_ocr()
res = attention.predict(img)
print(res)
\ No newline at end of file
#!/usr/bin/python
# encoding: utf-8
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
# import lmdb
import six
import sys
from PIL import Image
import numpy as np
class listDataset(Dataset):
def __init__(self, list_file=None, transform=None, target_transform=None):
self.list_file = list_file
with open(list_file) as fp:
self.lines = fp.readlines()
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
line_splits = self.lines[index].strip().split(' ')
imgpath = line_splits[0]
        try:
            img = Image.open(imgpath).convert('L')
        except IOError:
            print('Corrupted image for %d' % index)
            # fall back to the next sample, wrapping around at the end
            return self[(index + 1) % len(self)]
if self.transform is not None:
img = self.transform(img)
label = line_splits[1]
if self.target_transform is not None:
label = self.target_transform(label)
return (img, label)
class resizeNormalize(object):
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
self.toTensor = transforms.ToTensor()
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img = self.toTensor(img)
img.sub_(0.5).div_(0.5)
return img
class randomSequentialSampler(sampler.Sampler):
def __init__(self, data_source, batch_size):
self.num_samples = len(data_source)
self.batch_size = batch_size
def __iter__(self):
n_batch = len(self) // self.batch_size
tail = len(self) % self.batch_size
index = torch.LongTensor(len(self)).fill_(0)
for i in range(n_batch):
random_start = random.randint(0, len(self) - self.batch_size)
batch_index = random_start + torch.arange(0, self.batch_size)
index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
# deal with tail
if tail:
random_start = random.randint(0, len(self) - self.batch_size)
tail_index = random_start + torch.arange(0, tail)
index[(i + 1) * self.batch_size:] = tail_index
return iter(index)
def __len__(self):
return self.num_samples
class alignCollate(object):
def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
def __call__(self, batch):
images, labels = zip(*batch)
imgH = self.imgH
imgW = self.imgW
if self.keep_ratio:
ratios = []
for image in images:
w, h = image.size
ratios.append(w / float(h))
ratios.sort()
max_ratio = ratios[-1]
imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # ensure imgW >= imgH * min_ratio
transform = resizeNormalize((imgW, imgH))
images = [transform(image) for image in images]
images = torch.cat([t.unsqueeze(0) for t in images], 0)
return images, labels
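# Usage sketch (assumed wiring, mirroring how train.py consumes this module):
# alignCollate resizes every image in a batch to a shared (imgH, imgW) before
# stacking; with keep_ratio=True the batch width follows the widest sample.
#
# train_dataset = listDataset(list_file='./data/ch_train.txt')
# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=4, shuffle=True, num_workers=2,
#     collate_fn=alignCollate(imgH=32, imgW=280, keep_ratio=True))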
# coding:utf-8
'''
March 2019 by Chen Jun
https://github.com/chenjun2hao/Attention_ocr.pytorch
'''
import torch
from torch.autograd import Variable
import utils
@@ -10,14 +17,16 @@ import models.crnn_lang as crnn
use_gpu = True
encoder_path = './expr/attentioncnn/encoder_5.pth'
decoder_path = './expr/attentioncnn/decoder_5.pth'
# decoder_path = './expr/attentioncnn/decoder_5.pth'
img_path = './test_img/20441531_4212871437.jpg'
max_length = 15  # maximum string length
EOS_TOKEN = 1
nclass = len(alphabet) + 3
encoder = crnn.CNN(32, 1, 256)         # encoder
# decoder = crnn.decoder(256, nclass)  # seq-to-seq decoder; decoder adds another 2 to nclass internally
decoder = crnn.decoderV2(256, nclass)
if encoder_path and decoder_path:
print('loading pretrained models ......')
@@ -109,21 +109,22 @@ class Attentiondecoder(nn.Module):
self.out = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden, encoder_outputs):
# calculate the attention weight and weight * encoder_output feature
        embedded = self.embedding(input)  # embed the previous output token
embedded = self.dropout(embedded)
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1)  # weights from the previous output and the hidden state; a single linear layer maps 512 dims to 71, so only a fixed-width sequence can be handled
        attn_applied = torch.matmul(attn_weights.unsqueeze(1),
                                    encoder_outputs.permute((1, 0, 2)))  # batch matmul: (8x1x56) @ (8x56x256) = 8x1x256
        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)  # fuse the previous output with the attention feature, then apply a linear layer
output = self.attn_combine(output).unsqueeze(0)
output = F.relu(output)
        output, hidden = self.gru(output, hidden)  # just as in a sequence-to-sequence decoder
        output = F.log_softmax(self.out(output[0]), dim=1)  # log_softmax pairs with NLLLoss
return output, hidden, attn_weights
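    # Shape note (illustrative, assuming hidden_size=256): torch.cat((embedded,
    # hidden[0]), 1) is batch x 512 and self.attn projects 512 -> 71, so the
    # softmax always yields exactly 71 attention positions. The encoder sequence
    # length (roughly image width / 4 after CNN downsampling) must therefore be
    # fixed; AttentiondecoderV2 below removes this limitation.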
def initHidden(self, batch_size):
@@ -131,46 +132,6 @@ class Attentiondecoder(nn.Module):
return result
# def __init__(self, input_size, hidden_size, num_classes, num_embeddings=128):
# super(Attention, self).__init__()
# self.attention_cell = AttentionCell(input_size, hidden_size, num_embeddings)
# self.input_size = input_size
# self.hidden_size = hidden_size
# self.generator = nn.Linear(hidden_size, num_classes)
# self.char_embeddings = Parameter(torch.randn(num_classes+1, num_embeddings))
# self.num_embeddings = num_embeddings
# self.processed_batches = 0
# # targets is nT * nB
# def forward(self, feats, text_length, text):
# # target_txt_decode
# targets =target_txt_decode(batch_size, text_length, text)
# output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
# hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data))
# max_locs = torch.zeros(num_steps, nB)
# max_vals = torch.zeros(num_steps, nB)
# for i in range(num_steps):
# cur_embeddings = self.char_embeddings.index_select(0, targets[i])
# hidden, alpha = self.attention_cell(hidden, feats, cur_embeddings)
# output_hiddens[i] = hidden
# if self.processed_batches % 500 == 0:
# max_val, max_loc = alpha.data.max(1)
# max_locs[i] = max_loc.cpu()
# max_vals[i] = max_val.cpu()
# if self.processed_batches % 500 == 0:
# print('max_locs', list(max_locs[0:text_length.data[0],0]))
# print('max_vals', list(max_vals[0:text_length.data[0],0]))
# new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
# b = 0
# start = 0
# for length in text_length.data:
# new_hiddens[start:start+length] = output_hiddens[0:length,b,:]
# start = start + length
# b = b + 1
# probs = self.generator(new_hiddens)
# return probs
def target_txt_decode(batch_size, text_length, text):
'''
......@@ -246,13 +207,70 @@ class decoder(nn.Module):
return result
    # target_variable = target_txt_decode(b, length, text)
    # decoder_input = torch.zeros(b).long().cuda()  # initialize the decoder start; decoding begins from token 0 (SOS)
    # decoder_hidden = self.decoder.initHidden(b).cuda()
    # if self.teach_forcing:
    #     # teacher forcing: feed the target label as the next input
    #     for di in range(target_variable.shape[0]):  # maximum string length
    #         decoder_output, decoder_hidden, decoder_attention = self.decoder(
    #             decoder_input, decoder_hidden, encoder_outputs)
    #         loss += criterion(decoder_output, target_variable[di])  # predict one character at a time
    #         decoder_input = target_variable[di]  # teacher forcing: use the previous target as the next input
\ No newline at end of file
class AttentiondecoderV2(nn.Module):
"""
采用seq to seq模型,修改注意力权重的计算方式
"""
def __init__(self, hidden_size, output_size, dropout_p=0.1):
super(AttentiondecoderV2, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size, self.hidden_size)
self.out = nn.Linear(self.hidden_size, self.output_size)
        # attention scoring layer: maps each fused feature vector to a scalar score
        self.vat = nn.Linear(hidden_size, 1)
def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input)  # embed the previous output token
embedded = self.dropout(embedded)
        # score every encoder time step against the current hidden state
        batch_size = encoder_outputs.shape[1]
        alpha = hidden + encoder_outputs  # additive feature fusion; + or concat would both work
        alpha = alpha.view(-1, alpha.shape[-1])
        attn_weights = self.vat(torch.tanh(alpha))  # encoder_outputs is seq x batch x features; reduce the feature dim to 1
        attn_weights = attn_weights.view(-1, 1, batch_size).permute((2, 1, 0))
        attn_weights = F.softmax(attn_weights, dim=2)
        # attn_weights = F.softmax(
        #     self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1)  # old fixed-width attention from the previous output and hidden state
        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs.permute((1, 0, 2)))  # batch matmul: (8x1x56) @ (8x56x256) = 8x1x256
        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)  # fuse the previous output with the attention feature, then a linear layer + GRU
output = self.attn_combine(output).unsqueeze(0)
output = F.relu(output)
output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)  # log_softmax pairs with NLLLoss
return output, hidden, attn_weights
def initHidden(self, batch_size):
result = Variable(torch.zeros(1, batch_size, self.hidden_size))
return result
class decoderV2(nn.Module):
'''
decoder from image features
'''
def __init__(self, nh=256, nclass=13, dropout_p=0.1):
super(decoderV2, self).__init__()
self.hidden_size = nh
self.decoder = AttentiondecoderV2(nh, nclass, dropout_p)
def forward(self, input, hidden, encoder_outputs):
return self.decoder(input, hidden, encoder_outputs)
def initHidden(self, batch_size):
result = Variable(torch.zeros(1, batch_size, self.hidden_size))
return result
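# Usage sketch (assumed shapes): because self.vat scores each encoder time step
# independently, encoder outputs of any sequence length T are accepted, unlike
# the fixed-width Attentiondecoder above.
#
# T, B = 40, 2                              # any T works
# encoder_outputs = torch.randn(T, B, 256)  # seq x batch x features
# dec = decoderV2(nh=256, nclass=13)
# hidden = dec.initHidden(B)
# inp = torch.zeros(B).long()               # SOS token
# output, hidden, attn = dec(inp, hidden, encoder_outputs)  # attn: B x 1 x T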
@@ -8,20 +8,20 @@ import torch.optim as optim
import torch.utils.data
import numpy as np
import os
import src.utils as utils
import src.dataset as dataset
import time
import torch.nn as nn
from src.utils import alphabet
from src.utils import weights_init
import models.crnn_lang as crnn
print(crnn.__name__)
parser = argparse.ArgumentParser()
parser.add_argument('--trainlist', default='./data/ch_train.txt')
parser.add_argument('--vallist', default='./data/ch_test.txt')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=280, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
@@ -85,21 +85,12 @@ nc = 1
converter = utils.strLabelConverterForAttention(alphabet)
# criterion = torch.nn.CrossEntropyLoss()
def weights_init(model):
# Official init from torch repo.
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
criterion = torch.nn.NLLLoss()  # the model's final output must be log_softmax
encoder = crnn.CNN(opt.imgH, nc, opt.nh)
# decoder = crnn.decoder(opt.nh, nclass, dropout_p=0.1, max_length=opt.max_width)  # max_length = w/4, the sequence length along the width after the encoder's feature extraction
decoder = crnn.decoderV2(opt.nh, nclass, dropout_p=0.1)  # for prediction of indefinitely long sequences
encoder.apply(weights_init)
decoder.apply(weights_init)
# continue training or use the pretrained model to initial the parameters of the encoder and decoder
@@ -222,7 +213,7 @@ def trainBatch(encoder, decoder, criterion, encoder_optimizer, decoder_optimizer
    encoder_outputs = encoder(image)  # CNN + BiLSTM feature extraction
target_variable = target_variable.cuda()
    decoder_input = target_variable[0].cuda()  # initialize the decoder start; decoding begins from token 0 (SOS)
decoder_hidden = decoder.initHidden(b).cuda()
loss = 0.0
teach_forcing = True if random.random() > teach_forcing_prob else False
#!/usr/bin/python
# encoding: utf-8
import torch
import torch.nn as nn
from torch.autograd import Variable
import collections
from PIL import Image, ImageFilter
import math
import random
import numpy as np
import cv2
with open('./data/char_std_5990.txt') as f:
data = f.readlines()
alphabet = [x.rstrip() for x in data]
alphabet = ''.join(alphabet)
class strLabelConverterForAttention(object):
"""Convert between str and label.
NOTE:
Insert `EOS` to the alphabet for attention.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet):
self.alphabet = alphabet
self.dict = {}
        self.dict['SOS'] = 0  # start-of-sequence token
        self.dict['EOS'] = 1  # end-of-sequence token
        self.dict['$'] = 2    # blank / padding symbol
for i, item in enumerate(self.alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[item] = i + 3  # characters are encoded starting from 3
def encode(self, text):
"""对target_label做编码和对齐
对target txt每个字符串的开始加上GO,最后加上EOS,并用最长的字符串做对齐
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor targets:max_length × batch_size
"""
if isinstance(text, str):
text = [self.dict[item] for item in text]
elif isinstance(text, collections.Iterable):
            text = [self.encode(s) for s in text]     # encode each string
            max_length = max([len(x) for x in text])  # align to the longest
            nb = len(text)
            targets = torch.ones(nb, max_length + 2) * 2  # use 'blank' for padding
            for i in range(nb):
                targets[i][0] = 0  # SOS
                targets[i][1:len(text[i]) + 1] = text[i]
                targets[i][len(text[i]) + 1] = 1  # EOS
text = targets.transpose(0, 1).contiguous()
text = text.long()
return torch.LongTensor(text)
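    # Example (illustrative): with alphabet '0123456789', encoding the batch
    # ['12', '345'] pads with blank (2) and returns a 5 x 2 LongTensor whose
    # columns are [SOS=0, '1'=4, '2'=5, EOS=1, blank=2] and
    # [SOS=0, '3'=6, '4'=7, '5'=8, EOS=1].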
def decode(self, t):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
texts = list(self.dict.keys())[list(self.dict.values()).index(t)]
return texts
class strLabelConverterForCTC(object):
"""Convert between str and label.
NOTE:
Insert `blank` to the alphabet for CTC.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet, sep):
self.sep = sep
self.alphabet = alphabet.split(sep)
self.alphabet.append('-') # for `-1` index
self.dict = {}
for i, item in enumerate(self.alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[item] = i + 1
def encode(self, text):
"""Support batch or single str.
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
if isinstance(text, str):
text = text.split(self.sep)
text = [self.dict[item] for item in text]
length = [len(text)]
elif isinstance(text, collections.Iterable):
length = [len(s.split(self.sep)) for s in text]
text = self.sep.join(text)
text, _ = self.encode(text)
return (torch.IntTensor(text), torch.IntTensor(length))
def decode(self, t, length, raw=False):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
char_list = []
for i in range(length):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
char_list.append(self.alphabet[t[i] - 1])
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()):
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.IntTensor([l]), raw=raw))
index += l
return texts
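    # Example (illustrative): with alphabet 'a-b-c' and sep='-', the raw path
    # t = [1, 1, 0, 2, 2] ('a', 'a', blank, 'b', 'b') decodes as
    #   raw=True  -> 'aa-bb'
    #   raw=False -> 'ab'   (repeats collapsed, then blanks removed)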
class averager(object):
"""Compute average for `torch.Variable` and `torch.Tensor`. """
def __init__(self):
self.reset()
def add(self, v):
if isinstance(v, Variable):
count = v.data.numel()
v = v.data.sum()
elif isinstance(v, torch.Tensor):
count = v.numel()
v = v.sum()
self.n_count += count
self.sum += v
def reset(self):
self.n_count = 0
self.sum = 0
def val(self):
res = 0
if self.n_count != 0:
res = self.sum / float(self.n_count)
return res
def oneHot(v, v_length, nc):
batchSize = v_length.size(0)
maxLength = v_length.max()
v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
acc = 0
for i in range(batchSize):
length = v_length[i]
label = v[acc:acc + length].view(-1, 1).long()
v_onehot[i, :length].scatter_(1, label, 1.0)
acc += length
return v_onehot
def loadData(v, data):
v.data.resize_(data.size()).copy_(data)
def prettyPrint(v):
print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0],
v.mean().data[0]))
def assureRatio(img):
"""Ensure imgH <= imgW."""
b, c, h, w = img.size()
if h > w:
main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
img = main(img)
return img
class halo():
    '''
    u: mean of the Gaussian distribution
    sigma: variance
    nums: number of light spots added at random positions in an image
    prob: probability of applying the halo
    '''
def __init__(self, nums, u=0, sigma=0.2, prob=0.5):
        self.u = u                   # mean μ
        self.sig = math.sqrt(sigma)  # standard deviation σ
self.nums = nums
self.prob = prob
def create_kernel(self, maxh=32, maxw=50):
        height_scope = [10, maxh]  # random height range for the generated Gaussian kernel
        weight_scope = [20, maxw]  # width range
x = np.linspace(self.u - 3 * self.sig, self.u + 3 * self.sig, random.randint(*height_scope))
y = np.linspace(self.u - 3 * self.sig, self.u + 3 * self.sig, random.randint(*weight_scope))
Gauss_map = np.zeros((len(x), len(y)))
for i in range(len(x)):
for j in range(len(y)):
Gauss_map[i, j] = np.exp(-((x[i] - self.u) ** 2 + (y[j] - self.u) ** 2) / (2 * self.sig ** 2)) / (
math.sqrt(2 * math.pi) * self.sig)
return Gauss_map
def __call__(self, img):
if random.random() < self.prob:
            Gauss_map = self.create_kernel(32, 60)  # build a Gaussian kernel; 32 is the max height, 60 the max width
img1 = np.asarray(img)
            img1.flags.writeable = True  # make the array writable
            nums = random.randint(1, self.nums)  # add a random number of light spots
            img1 = img1.astype(np.float)
# print(nums)
for i in range(nums):
img_h, img_w = img1.shape
                pointx = random.randint(0, img_h - 10)  # pick a random anchor point in the image
                pointy = random.randint(0, img_w - 10)
                h, w = Gauss_map.shape  # clip the kernel if it extends past the image border
endx = pointx + h
endy = pointy + w
if pointx + h > img_h:
endx = img_h
Gauss_map = Gauss_map[1:img_h - pointx + 1, :]
if img_w < pointy + w:
endy = img_w
Gauss_map = Gauss_map[:, 1:img_w - pointy + 1]
                # add the non-uniform illumination
                img1[pointx:endx, pointy:endy] = img1[pointx:endx, pointy:endy] + Gauss_map * 255.0
            img1[img1 > 255.0] = 255.0  # clamp, otherwise uint8 would wrap around from 0
img = img1
return Image.fromarray(np.uint8(img))
class MyGaussianBlur(ImageFilter.Filter):
name = "GaussianBlur"
def __init__(self, radius=2, bounds=None):
self.radius = radius
self.bounds = bounds
def filter(self, image):
if self.bounds:
clips = image.crop(self.bounds).gaussian_blur(self.radius)
image.paste(clips, self.bounds)
return image
else:
return image.gaussian_blur(self.radius)
class GBlur(object):
def __init__(self, radius=2, prob=0.5):
radius = random.randint(0, radius)
self.blur = MyGaussianBlur(radius=radius)
self.prob = prob
def __call__(self, img):
if random.random() < self.prob:
img = img.filter(self.blur)
return img
class RandomBrightness(object):
"""随机改变亮度
pil:pil格式的图片
"""
def __init__(self, prob=1.5):
self.prob = prob
def __call__(self, pil):
rgb = np.asarray(pil)
if random.random() < self.prob:
hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
h, s, v = cv2.split(hsv)
            adjust = random.choice([0.5, 0.7, 0.9, 1.2, 1.5, 1.7])  # pick one factor at random
            # adjust = random.choice([1.2, 1.5, 1.7, 2.0])
v = v * adjust
v = np.clip(v, 0, 255).astype(hsv.dtype)
hsv = cv2.merge((h, s, v))
rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
return Image.fromarray(np.uint8(rgb)).convert('L')
class randapply(object):
"""随机决定是否应用光晕、模糊或者二者都用
Args:
transforms (list or tuple): list of transformations
"""
def __init__(self, transforms):
assert isinstance(transforms, (list, tuple))
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
\ No newline at end of file
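# Usage sketch (assumed pipeline): the augmentations above can be chained with
# randapply before ToTensor. RandomBrightness expects RGB input and returns a
# grayscale ('L') image, while halo indexes the image as a 2-D array, so a
# plausible order is brightness first, then halo, then blur:
#
# aug = randapply([RandomBrightness(prob=0.5),  # RGB in, 'L' out
#                  halo(nums=2, prob=0.5),      # expects a grayscale image
#                  GBlur(radius=2, prob=0.5)])
# img = aug(Image.open(path).convert('RGB'))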