提交 69ffc623 编写于 作者: 农夫山泉2号's avatar 农夫山泉2号

attention ocr.pytorch

上级 daea81b9
......@@ -4,4 +4,6 @@
*.pyo
*.log
*.tmp
.idea/
.vscode/
expr/attention2/
Robust Scene Text Recognition with Automatic Rectification
attention-ocr.pytorch:Encoder+Decoder+attention model
======================================
This repository implements the Robust Scene Text Recognition with Automatic Rectification (SRN only) in pytorch, which is modified from https://github.com/meijieru/crnn.pytorch
This repository implements an encoder-decoder model with attention for OCR. It is modified from https://github.com/meijieru/crnn.pytorch
There are many attention OCR repositories implemented in TensorFlow, but they do not provide an **inference.py**. Besides, I am not good at TensorFlow, so I could not write the **inference.py** myself.
Train for VGG text data
--------------
[download dataset](http://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz)
1. create a link to mnt folder
2. python data/create_mnt_list.py
3. python main.py --trainlist data/train_list.txt --vallist data/test_list.txt --cuda --adam --lr=0.001
# requirements
```
pytorch 0.4.1
```
# Test
1. change the parameters in **demo.py**
2.
```bash
python demo.py
```
3. results
```
>>>predict_str:87635 => prob:0.8684815168380737
```
# Train Your Own Model
there are some details for attention
1. The image width must be the same for training and inference; in my project, I pad every image to a width of 220.
2. **decoder(opt.nh, nclass, dropout_p=0.1, max_length=56)**, 'max_length' is the feature's width from encoder(change with the imgW)
3. For batch training, I pad the target labels to the same length, and I encode the alphabet starting from 3: 0 is SOS, 1 is EOS, and 2 is $ (meaning "others").
4. The train_list.txt and test_list.txt files are created in the following form:
```
# path/to/image_name.jpg label
/media/chenjun/ed/18_MechanicalCrnn/data/mechanical/imgs/4667.jpg 99996
/media/chenjun/ed/18_MechanicalCrnn/data/mechanical/imgs/0985.jpg 81309
```
# Reference
1. [crnn.pytorch](https://github.com/meijieru/crnn.pytorch)
2. [Attention-OCR](https://github.com/da03/Attention-OCR)
3. [Seq2Seq-PyTorch](https://github.com/MaximumEntropy/Seq2Seq-PyTorch)
# TO DO
- [ ] change LSTM to Conv1D, it can greatly accelerate the inference
- [ ] to support images of different widths
# Other
I am now working in a company in chengdu, using deep learning to do image-related work. But the department is just established, no technical accumulation, the work is very difficult. So now I want to change a job. The place of work is either chengdu or chongqing. If there is a way, please help me push it internally. Thank you very much.
本人现在在成都的一家公司,职位:图像识别算法工程。但是部门刚成立,招的都是应届生,没有技术积累。所以现在想换一份工作,做computer vision方向的,工作地点在成都或者重庆都行,有途径也请帮忙内推一下。本人联系方式:778961303@qq.com.非常感谢。
# Build the train / test / val list files from the MJSynth (90kDICT32px)
# annotation files.  Each output line has the form
#     <image path> <c:h:a:r:s:$>
# where the label is the lower-cased word taken from the image file name,
# terminated with '$' and joined with ':'.

DATA_ROOT = 'data/mnt/ramdisk/max/90kDICT32px'


def format_entry(line):
    """Convert one annotation line into one list-file line (no newline).

    The annotation line is "<relative image path> <lexicon index>"; only the
    path is used.  The label is the second '_'-separated field of the image
    file name, e.g. ".../334_EFFLORESCENT_24742.jpg" -> "efflorescent".
    """
    imgpath = line.strip().split(' ')[0]
    word = imgpath.split('/')[-1].split('_')[1].lower()
    label = ':'.join(word + '$')
    return ' '.join(['%s/%s' % (DATA_ROOT, imgpath), label])


def write_list(annotation_path, list_path):
    """Write one formatted entry per non-blank line of ``annotation_path``.

    Both files are opened with ``with`` so they are closed even when a
    malformed line raises.  ``fp.write`` is used instead of the Python-2-only
    ``print >> fp`` statement so the script runs on Python 2 and 3.
    """
    with open(annotation_path) as src, open(list_path, 'w') as dst:
        for line in src:
            if not line.strip():
                # A blank line has no file name to parse; skip it.
                continue
            dst.write(format_entry(line) + '\n')


if __name__ == '__main__':
    write_list('%s/annotation_train.txt' % DATA_ROOT, 'data/train_list.txt')
    write_list('%s/annotation_test.txt' % DATA_ROOT, 'data/test_list.txt')
    # The validation list was previously generated from annotation_test.txt,
    # which made it a duplicate of the test list; MJSynth ships a separate
    # annotation_val.txt for this split.
    write_list('%s/annotation_val.txt' % DATA_ROOT, 'data/val_list.txt')
因为 它太大了无法显示 source diff 。你可以改为 查看blob
此差异已折叠。
......@@ -6,7 +6,7 @@ import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
import lmdb
# import lmdb
import six
import sys
from PIL import Image
......@@ -15,6 +15,7 @@ import numpy as np
class listDataset(Dataset):
def __init__(self, list_file=None, transform=None, target_transform=None):
self.list_file = list_file
with open(list_file) as fp:
self.lines = fp.readlines()
self.nSamples = len(self.lines)
......@@ -31,7 +32,10 @@ class listDataset(Dataset):
line_splits = self.lines[index].strip().split(' ')
imgpath = line_splits[0]
try:
img = Image.open(imgpath).convert('L')
if 'train' in self.list_file:
img = Image.open(imgpath)
else:
img = Image.open(imgpath).convert('L')
except IOError:
print('Corrupted image for %d' % index)
return self[index + 1]
......
# coding:utf-8
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
import models.crnn as crnn
import models.crnn_lang as crnn
use_gpu = True
model_path = './data/crnn.pth'
img_path = './data/demo.png'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
encoder_path = './expr/attentioncnn/encoder_600.pth'
decoder_path = './expr/attentioncnn/decoder_600.pth'
img_path = './test/0003.jpg'
alphabet = '0123456789'
max_length = 7 # 最长字符串的长度
EOS_TOKEN = 1
model = crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
model = model.cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))
nclass = len(alphabet) + 3
encoder = crnn.CNN(32, 1, 256) # 编码器
decoder = crnn.decoder(256, nclass) # seq to seq的解码器, nclass在decoder中还加了2
converter = utils.strLabelConverter(alphabet)
if encoder_path and decoder_path:
print('loading pretrained models ......')
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))
if torch.cuda.is_available() and use_gpu:
encoder = encoder.cuda()
decoder = decoder.cuda()
transformer = dataset.resizeNormalize((100, 32))
converter = utils.strLabelConverterForAttention(alphabet)
transformer = dataset.resizeNormalize((220, 32))
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available():
if torch.cuda.is_available() and use_gpu:
image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)
encoder.eval()
decoder.eval()
encoder_out = encoder(image)
_, preds = preds.max(2)
preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
decoded_words = []
prob = 1.0
decoder_attentions = torch.zeros(max_length, 56)
decoder_input = torch.zeros(1).long() # 初始化decoder的开始,从0开始输出
decoder_hidden = decoder.initHidden(1)
if torch.cuda.is_available() and use_gpu:
decoder_input = decoder_input.cuda()
decoder_hidden = decoder_hidden.cuda()
loss = 0.0
# 预测的时候采用非强制策略,将前一次的输出,作为下一次的输入,直到标签为EOS_TOKEN时停止
for di in range(max_length): # 最大字符串的长度
decoder_output, decoder_hidden, decoder_attention = decoder(
decoder_input, decoder_hidden, encoder_out)
probs = torch.exp(decoder_output)
decoder_attentions[di] = decoder_attention.data
topv, topi = decoder_output.data.topk(1)
ni = topi.squeeze(1)
decoder_input = ni
prob *= probs[:, ni]
if ni == EOS_TOKEN:
# decoded_words.append('<EOS>')
break
else:
decoded_words.append(converter.decode(ni))
preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
print('%-20s => %-20s' % (raw_pred, sim_pred))
words = ''.join(decoded_words)
prob = prob.item()
print('predict_str:%-20s => prob:%-20s' % (words, prob))
from __future__ import print_function
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
from warpctc_pytorch import CTCLoss
import os
import utils
import dataset
import time
import models.crnn_lang as crnn
print(crnn.__name__)
parser = argparse.ArgumentParser()
parser.add_argument('--trainlist', required=True, help='path to train_list')
parser.add_argument('--vallist', required=True, help='path to val_list')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for Critic, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
parser.add_argument('--alphabet', type=str, default='0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$')
parser.add_argument('--sep', type=str, default=':')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=10000, help='Interval to be displayed')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--lang', action='store_true', help='whether to use char language model')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)
if opt.experiment is None:
opt.experiment = 'expr'
os.system('mkdir {0}'.format(opt.experiment))
opt.manualSeed = random.randint(1, 10000) # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
train_dataset = dataset.listDataset(list_file =opt.trainlist)
assert train_dataset
if not opt.random_sample:
sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=opt.batchSize,
shuffle=False, sampler=sampler,
num_workers=int(opt.workers),
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
test_dataset = dataset.listDataset(list_file =opt.vallist, transform=dataset.resizeNormalize((100, 32)))
nclass = len(opt.alphabet.split(opt.sep))
nc = 1
converter = utils.strLabelConverterForAttention(opt.alphabet, opt.sep)
criterion = torch.nn.CrossEntropyLoss()
# custom weights initialization called on crnn
def weights_init(m):
    """Initialise one module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02).  Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias.  Everything else is left alone.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.crnn != '':
print('loading pretrained model from %s' % opt.crnn)
crnn.load_state_dict(torch.load(opt.crnn))
print(crnn)
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.LongTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)
if opt.cuda:
crnn.cuda()
crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
image = image.cuda()
text = text.cuda()
criterion = criterion.cuda()
image = Variable(image)
text = Variable(text)
length = Variable(length)
# loss averager
loss_avg = utils.averager()
# setup optimizer
if opt.adam:
optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
betas=(opt.beta1, 0.999))
elif opt.adadelta:
optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
else:
optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
# Evaluate on at most `max_iter` batches drawn from `dataset` and print the
# averaged loss and the accuracy.
# NOTE(review): the body reads the module-level `crnn` (parameter freezing and
# the forward pass) while only `net.eval()` uses the `net` argument; the two
# agree only when the caller passes `crnn` itself.
def val(net, dataset, criterion, max_iter=100):
print('Start val')
# Disable gradient tracking on the parameters for the duration of validation.
for p in crnn.parameters():
p.requires_grad = False
net.eval()
data_loader = torch.utils.data.DataLoader(
dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
val_iter = iter(data_loader)
i = 0
n_correct = 0
loss_avg = utils.averager()
# Never request more batches than the loader can provide.
max_iter = min(max_iter, len(data_loader))
for i in range(max_iter):
# NOTE(review): `.next()` is the Python-2-style iterator method; `next(val_iter)`
# is the portable spelling.
data = val_iter.next()
i += 1
cpu_images, cpu_texts = data
batch_size = cpu_images.size(0)
# Copy the batch into the pre-allocated module-level `image`, `text` and
# `length` tensors.
utils.loadData(image, cpu_images)
t, l = converter.encode(cpu_texts)
utils.loadData(text, t)
utils.loadData(length, l)
# With --lang the ground-truth text is passed to the model as well.
if opt.lang:
preds = crnn(image, length, text)
else:
preds = crnn(image, length)
cost = criterion(preds, text)
loss_avg.add(cost)
# Take the arg-max class per position and decode it back to strings.
_, preds = preds.max(1)
preds = preds.view(-1)
sim_preds = converter.decode(preds.data, length.data)
# A prediction counts as correct only when the whole string matches the
# target with its separators removed.
for pred, target in zip(sim_preds, cpu_texts):
target = ''.join(target.split(opt.sep))
if pred == target:
n_correct += 1
# Print predictions next to their ground truth.
for pred, gt in zip(sim_preds, cpu_texts):
gt = ''.join(gt.split(opt.sep))
print('%-20s, gt: %-20s' % (pred, gt))
# NOTE(review): the denominator assumes every batch holds exactly
# opt.batchSize samples; `batch_size` computed above is not used.
accuracy = n_correct / float(max_iter * opt.batchSize)
print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
def trainBatch(net, criterion, optimizer):
    """Run one optimisation step on the next batch and return its loss.

    The batch is pulled from the module-level ``train_iter`` and copied into
    the pre-allocated module-level ``image``, ``text`` and ``length`` tensors.

    Args:
        net: the model to train.  It is now actually used for the forward
            pass and for ``zero_grad``; previously the global ``crnn`` was
            used and this argument was ignored.
        criterion: loss function called as ``criterion(preds, text)``.
        optimizer: optimizer stepping ``net``'s parameters.

    Returns:
        The loss tensor produced by ``criterion`` for this batch.
    """
    # next() works with any iterator; the `.next()` method only exists on
    # Python-2-style iterators.
    cpu_images, cpu_texts = next(train_iter)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    # With --lang the ground-truth text is passed to the model as well.
    if opt.lang:
        preds = net(image, length, text)
    else:
        preds = net(image, length)
    cost = criterion(preds, text)
    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
t0 = time.time()
for epoch in range(opt.niter):
train_iter = iter(train_loader)
i = 0
while i < len(train_loader):
for p in crnn.parameters():
p.requires_grad = True
crnn.train()
cost = trainBatch(crnn, criterion, optimizer)
loss_avg.add(cost)
i += 1
if i % opt.displayInterval == 0:
print('[%d/%d][%d/%d] Loss: %f' %
(epoch, opt.niter, i, len(train_loader), loss_avg.val()))
loss_avg.reset()
t1 = time.time()
print('time elapsed %d' % (t1-t0))
t0 = time.time()
if i % opt.valInterval == 0:
val(crnn, test_dataset, criterion)
# do checkpointing
if i % opt.saveInterval == 0:
torch.save(
crnn.module.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
# coding:utf-8
from __future__ import print_function
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
import os
import utils
import dataset
import time
import torch.nn as nn
from utils import randapply, halo, RandomBrightness, GBlur
import models.crnn_lang as crnn
print(crnn.__name__)
parser = argparse.ArgumentParser()
parser.add_argument('--trainlist', default='./data/remote_train_list.txt')
parser.add_argument('--vallist', default='./data/remote_test_list.txt')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=8, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=220, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=601, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate for Critic, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda', default=True)
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--encoder', type=str, default='./expr/attentioncnn/encoder_600.pth', help="path to encoder (to continue training)")
parser.add_argument('--decoder', type=str, default='./expr/attentioncnn/decoder_600.pth', help='path to decoder (to continue training)')
parser.add_argument('--alphabet', type=str, default='0123456789', help='$ means blank for seq to seq')
parser.add_argument('--experiment', default='./expr/attentioncnn', help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=10, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=1, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=100, help='Interval to be displayed')
parser.add_argument('--adam', default=True, action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--lang', default=True, action='store_true', help='whether to use char language model')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
parser.add_argument('--teaching_forcing_prob', type=float, default=0.5, help='where to use teach forcing')
opt = parser.parse_args()
print(opt)
SOS_token = 0
EOS_TOKEN = 1 # 结束标志的标签
BLANK = 2 # blank for padding
if opt.experiment is None:
    opt.experiment = 'expr'
# Create the (possibly nested) output directory without going through a
# shell: `mkdir -p` is not available on every platform and the path comes
# straight from the command line.
if not os.path.isdir(opt.experiment):
    os.makedirs(opt.experiment)
opt.manualSeed = random.randint(1, 10000) # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
transform = randapply([ # 随机决定是否用预处理或者用哪种预处理
RandomBrightness(prob=0.5), # 增加图片的亮度
halo(nums=3), # 添加光晕
GBlur(radius=2) # 高斯模糊
])
train_dataset = dataset.listDataset(list_file =opt.trainlist, transform=transform)
assert train_dataset
if not opt.random_sample:
sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=opt.batchSize,
shuffle=False, sampler=sampler,
num_workers=int(opt.workers),
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
test_dataset = dataset.listDataset(list_file =opt.vallist, transform=dataset.resizeNormalize((220, 32)))
nclass = len(opt.alphabet) + 3 # decoder的时候,需要的类别数,3 for SOS,EOS和blank
nc = 1
converter = utils.strLabelConverterForAttention(opt.alphabet)
# criterion = torch.nn.CrossEntropyLoss()
criterion = torch.nn.NLLLoss()
def weights_init(model):
    """Initialise every Conv2d / BatchNorm2d / Linear found inside ``model``.

    * ``nn.Conv2d``: Kaiming-normal weights.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0.
    * ``nn.Linear``: bias set to 0 (weights keep their default init).

    Layers built without the affected parameter (``bias=False`` or
    ``affine=False``) expose it as ``None``; those are skipped instead of
    raising inside ``nn.init.constant_``.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
encoder = crnn.CNN(opt.imgH, nc, opt.nh)
decoder = crnn.decoder(opt.nh, nclass, dropout_p=0.1, max_length=56) # max_length:w/4,为encoder特征提取之后宽度方向上的序列长度
encoder.apply(weights_init)
decoder.apply(weights_init)
# continue training or use the pretrained model to initial the parameters of the encoder and decoder
if opt.encoder:
    print('loading pretrained encoder model from %s' % opt.encoder)
    encoder.load_state_dict(torch.load(opt.encoder))
if opt.decoder:
    # Restore the decoder from its own checkpoint.  This branch used to load
    # opt.encoder into the encoder a second time, so the decoder always
    # started from its fresh initialisation.
    print('loading pretrained decoder model from %s' % opt.decoder)
    decoder.load_state_dict(torch.load(opt.decoder))
print(encoder)
print(decoder)
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.LongTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)
if opt.cuda:
encoder.cuda()
decoder.cuda()
# encoder = torch.nn.DataParallel(encoder, device_ids=range(opt.ngpu))
# decoder = torch.nn.DataParallel(decoder, device_ids=range(opt.ngpu))
image = image.cuda()
text = text.cuda()
criterion = criterion.cuda()
# loss averager
loss_avg = utils.averager()
# setup optimizer
# Build one optimizer per network for every supported choice.  The training
# loop passes `encoder_optimizer` and `decoder_optimizer` to trainBatch, so
# both names must exist whichever branch is taken; the adadelta / rmsprop
# branches used to define only a single `optimizer` over the encoder.
if opt.adam:
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
elif opt.adadelta:
    encoder_optimizer = optim.Adadelta(encoder.parameters(), lr=opt.lr)
    decoder_optimizer = optim.Adadelta(decoder.parameters(), lr=opt.lr)
else:
    encoder_optimizer = optim.RMSprop(encoder.parameters(), lr=opt.lr)
    decoder_optimizer = optim.RMSprop(decoder.parameters(), lr=opt.lr)
def val(encoder, decoder, criterion, batchsize, dataset, teach_forcing=False, max_iter=100):
    """Greedy-decode up to ``max_iter`` validation batches; print loss and accuracy.

    Args:
        encoder, decoder: the two networks; both are switched to eval mode
            and all of their parameters get ``requires_grad = False``.
        criterion: per-step loss, called as ``criterion(output, target_step)``.
        batchsize: batch size of the validation loader.  The EOS check and the
            per-character comparison below treat the prediction as a single
            value, so this is expected to be 1 (the caller passes 1).
        dataset: dataset to draw validation samples from (shuffled).
        teach_forcing: when true no decoding is performed at all; only the
            free-running (``False``) path is implemented.
        max_iter: upper bound on the number of batches evaluated.

    Accuracy is per character: every target character plus the final EOS
    counts once towards the total.
    """
    print('Start val')

    # Freeze every parameter of both networks.  Zipping the two parameter
    # iterators stops at the shorter one and would leave the remaining
    # parameters of the larger network untouched.
    for net in (encoder, decoder):
        for p in net.parameters():
            p.requires_grad = False

    encoder.eval()
    decoder.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=batchsize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    n_correct = 0
    n_total = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    # 1-based counter so the periodic print fires on the 100th, 200th, ... batch.
    for i in range(1, max_iter + 1):
        cpu_images, cpu_texts = next(val_iter)
        b = cpu_images.size(0)
        utils.loadData(image, cpu_images)

        target_variable = converter.encode(cpu_texts)
        n_total += len(cpu_texts[0]) + 1  # the EOS stop token must be predicted too

        decoded_words = []
        decoded_label = []
        encoder_outputs = encoder(image)
        decoder_hidden = decoder.initHidden(b)
        # Follow the same device choice as the module-level `image` tensor.
        if opt.cuda:
            target_variable = target_variable.cuda()
            decoder_hidden = decoder_hidden.cuda()
        # Row 0 of the encoded targets is fed as the first decoder input.
        decoder_input = target_variable[0]
        loss = 0.0
        if not teach_forcing:
            # Free-running decoding: feed the previous prediction back in
            # until EOS_TOKEN is produced or the target length is reached.
            for di in range(1, target_variable.shape[0]):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                loss += criterion(decoder_output, target_variable[di])
                # NOTE(review): this records the running (cumulative) loss at
                # every step, as before; confirm that is the intended metric.
                loss_avg.add(loss)
                topv, topi = decoder_output.data.topk(1)
                ni = topi.squeeze(1)
                decoder_input = ni
                if ni == EOS_TOKEN:
                    decoded_words.append('<EOS>')
                    decoded_label.append(EOS_TOKEN)
                    break
                else:
                    decoded_words.append(converter.decode(ni))
                    decoded_label.append(ni)

        # Count correctly predicted positions (target row 0 is skipped).
        for pred, target in zip(decoded_label, target_variable[1:, :]):
            if pred == target:
                n_correct += 1

        if i % 100 == 0:  # print a sample every 100 batches
            texts = cpu_texts[0]
            print('pred:%-20s, gt: %-20s' % (decoded_words, texts))

    # An empty loader evaluates nothing; report 0 instead of dividing by zero.
    accuracy = n_correct / float(n_total) if n_total else 0.0
    print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
def trainBatch(encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, teach_forcing_prob=1):
    """Run one encoder/decoder optimisation step and return the summed loss.

    The batch is pulled from the module-level ``train_iter``; images are
    copied into the pre-allocated module-level ``image`` tensor.  Targets come
    from ``converter.encode`` (encoded and aligned so a whole batch can be
    decoded step by step); row 0 is the first decoder input and rows 1.. are
    the per-step targets.

    Args:
        encoder, decoder: the networks being trained.
        criterion: per-step loss, called as ``criterion(output, target_step)``.
        encoder_optimizer, decoder_optimizer: one optimizer per network.
        teach_forcing_prob: probability of using teacher forcing for this
            batch (1 = always, 0 = never).

    Returns:
        The loss summed over all decoding steps.
    """
    # next() works with any iterator; `.next()` is the Python-2-style method.
    cpu_images, cpu_texts = next(train_iter)
    b = cpu_images.size(0)
    target_variable = converter.encode(cpu_texts)
    utils.loadData(image, cpu_images)

    encoder_outputs = encoder(image)
    decoder_hidden = decoder.initHidden(b)
    # Follow the same device choice as the module-level `image` tensor.
    if opt.cuda:
        target_variable = target_variable.cuda()
        decoder_hidden = decoder_hidden.cuda()
    decoder_input = target_variable[0]
    loss = 0.0

    # Teacher forcing is applied WITH probability teach_forcing_prob.  The
    # comparison used to be `>`, which applied it with probability
    # 1 - teach_forcing_prob (so the default of 1 never used it).
    teach_forcing = random.random() < teach_forcing_prob
    for di in range(1, target_variable.shape[0]):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target_variable[di])  # one character per step
        if teach_forcing:
            # Feed the ground-truth label as the next input.
            decoder_input = target_variable[di]
        else:
            # Feed back the model's own best guess.  squeeze(1) keeps the
            # batch dimension even when the batch holds a single sample.
            topv, topi = decoder_output.data.topk(1)
            decoder_input = topi.squeeze(1)

    encoder.zero_grad()
    decoder.zero_grad()
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss
if __name__ == '__main__':
    # Training driver: iterate over epochs, log every opt.displayInterval
    # batches, and validate + checkpoint every opt.saveInterval epochs.
    t0 = time.time()
    for epoch in range(opt.niter):
        # trainBatch reads this module-level iterator.
        train_iter = iter(train_loader)
        i = 0
        # The final batch of each epoch is not consumed (loop stops one short).
        while i < len(train_loader)-1:
            # val() freezes the parameters, so re-enable gradients on ALL of
            # them.  Zipping the two parameter iterators stops at the shorter
            # one and would leave part of the larger network frozen.
            for net in (encoder, decoder):
                for p in net.parameters():
                    p.requires_grad = True
            encoder.train()
            decoder.train()
            cost = trainBatch(encoder, decoder, criterion, encoder_optimizer,
                              decoder_optimizer, teach_forcing_prob=opt.teaching_forcing_prob)
            loss_avg.add(cost)
            i += 1

            if i % opt.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, opt.niter, i, len(train_loader), loss_avg.val()), end=' ')
                loss_avg.reset()
                t1 = time.time()
                print('time elapsed %d' % (t1-t0))
                t0 = time.time()

        # do checkpointing
        if epoch % opt.saveInterval == 0:
            val(encoder, decoder, criterion, 1, dataset=test_dataset, teach_forcing=False)  # batchsize:1
            torch.save(
                encoder.state_dict(), '{0}/encoder_{1}.pth'.format(opt.experiment, epoch))
            torch.save(
                decoder.state_dict(), '{0}/decoder_{1}.pth'.format(opt.experiment, epoch))
\ No newline at end of file
......@@ -43,8 +43,8 @@ class AttentionCell(nn.Module):
feats_proj = self.i2h(feats.view(-1,nC))
prev_hidden_proj = self.h2h(prev_hidden).view(1,nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)
emition = self.score(F.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT,nB).transpose(0,1)
alpha = F.softmax(emition) # nB * nT
emition = self.score(torch.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT,nB).transpose(0,1)
alpha = F.softmax(emition, dim=1) # nB * nT
if self.processed_batches % 10000 == 0:
print('emition ', list(emition.data[0]))
......
# coding:utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
GO = 0
EOS_TOKEN = 1 # 结束标志的标签
class BidirectionalLSTM(nn.Module):
......@@ -22,6 +25,7 @@ class BidirectionalLSTM(nn.Module):
output = output.view(T, b, -1)
return output
class AttentionCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embeddings=128):
......@@ -33,7 +37,7 @@ class AttentionCell(nn.Module):
self.hidden_size = hidden_size
self.input_size = input_size
self.num_embeddings = num_embeddings
self.processed_batches = 0
self.processed_batches = 0
def forward(self, prev_hidden, feats, cur_embeddings):
nT = feats.size(0)
......@@ -44,7 +48,7 @@ class AttentionCell(nn.Module):
feats_proj = self.i2h(feats.view(-1,nC))
prev_hidden_proj = self.h2h(prev_hidden).view(1,nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)
emition = self.score(F.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT,nB).transpose(0,1)
emition = self.score(torch.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT,nB).transpose(0,1)
self.processed_batches = self.processed_batches + 1
if self.processed_batches % 10000 == 0:
......@@ -54,71 +58,149 @@ class AttentionCell(nn.Module):
if self.processed_batches % 10000 == 0:
print('emition ', list(emition.data[0]))
print('alpha ', list(alpha.data[0]))
context = (feats * alpha.transpose(0,1).contiguous().view(nT,nB,1).expand(nT, nB, nC)).sum(0).squeeze(0) # nB * nC
context = (feats * alpha.transpose(0,1).contiguous().view(nT,nB,1).expand(nT, nB, nC)).sum(0).squeeze(0) # nB * nC//感觉不应该sum,输出4×256
context = torch.cat([context, cur_embeddings], 1)
cur_hidden = self.rnn(context, prev_hidden)
return cur_hidden, alpha
class Attention(nn.Module):
def __init__(self, input_size, hidden_size, num_classes, num_embeddings=128):
super(Attention, self).__init__()
self.attention_cell = AttentionCell(input_size, hidden_size, num_embeddings)
self.input_size = input_size
class DecoderRNN(nn.Module):
"""
采用RNN进行解码
"""
def __init__(self, hidden_size, output_size):
super(DecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.generator = nn.Linear(hidden_size, num_classes)
self.char_embeddings = Parameter(torch.randn(num_classes+1, num_embeddings))
self.num_embeddings = num_embeddings
self.processed_batches = 0
# targets is nT * nB
def forward(self, feats, text_length, text):
self.processed_batches = self.processed_batches + 1
nT = feats.size(0)
nB = feats.size(1)
nC = feats.size(2)
hidden_size = self.hidden_size
input_size = self.input_size
assert(input_size == nC)
assert(nB == text_length.numel())
num_steps = text_length.data.max()
num_labels = text_length.data.sum()
targets = torch.zeros(nB, num_steps+1).long().cuda()
start_id = 0
for i in range(nB):
targets[i][1:1+text_length.data[i]] = text.data[start_id:start_id+text_length.data[i]]+1
start_id = start_id+text_length.data[i]
targets = Variable(targets.transpose(0,1).contiguous())
output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data))
max_locs = torch.zeros(num_steps, nB)
max_vals = torch.zeros(num_steps, nB)
for i in range(num_steps):
cur_embeddings = self.char_embeddings.index_select(0, targets[i])
hidden, alpha = self.attention_cell(hidden, feats, cur_embeddings)
output_hiddens[i] = hidden
if self.processed_batches % 500 == 0:
max_val, max_loc = alpha.data.max(1)
max_locs[i] = max_loc.cpu()
max_vals[i] = max_val.cpu()
if self.processed_batches % 500 == 0:
print('max_locs', list(max_locs[0:text_length.data[0],0]))
print('max_vals', list(max_vals[0:text_length.data[0],0]))
new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
b = 0
start = 0
for length in text_length.data:
new_hiddens[start:start+length] = output_hiddens[0:length,b,:]
start = start + length
b = b + 1
probs = self.generator(new_hiddens)
return probs
class CRNN(nn.Module):
def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
super(CRNN, self).__init__()
self.embedding = nn.Embedding(output_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
output = self.embedding(input).view(1, 1, -1)
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = self.softmax(self.out(output[0]))
return output, hidden
def initHidden(self):
result = Variable(torch.zeros(1, 1, self.hidden_size))
return result
class Attentiondecoder(nn.Module):
    """
    Single-step GRU decoder that attends over the encoder outputs.

    Each call consumes the previous token and hidden state and returns the
    log-probabilities of the next token, the new hidden state and the
    attention weights over the ``max_length`` encoder positions.
    """
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=56):
        super(Attentiondecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        # Layers are created in a fixed order: embedding, attention scorer,
        # attention/embedding mixer, dropout, GRU, output projection.
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(2 * hidden_size, max_length)
        self.attn_combine = nn.Linear(2 * hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        # Embed the previous output token, then apply dropout.
        token_vec = self.dropout(self.embedding(input))

        # Attention weights come from the embedded token and the current
        # hidden state, normalised over the encoder positions.
        scorer_in = torch.cat((token_vec, hidden[0]), 1)
        attn_weights = F.softmax(self.attn(scorer_in), dim=1)

        # Weighted sum of the encoder outputs: the encoder tensor is permuted
        # so its first two axes are swapped before the batched product.
        per_batch_feats = encoder_outputs.permute((1, 0, 2))
        attn_applied = torch.matmul(attn_weights.unsqueeze(1), per_batch_feats)

        # Mix the token embedding with the attended feature, then one GRU step.
        mixed = torch.cat((token_vec, attn_applied.squeeze(1)), 1)
        gru_in = F.relu(self.attn_combine(mixed).unsqueeze(0))
        gru_out, hidden = self.gru(gru_in, hidden)

        # Log-probabilities over the output classes.
        log_probs = F.log_softmax(self.out(gru_out[0]), dim=1)
        return log_probs, hidden, attn_weights

    def initHidden(self, batch_size):
        # All-zero initial hidden state of shape (1, batch_size, hidden_size).
        return Variable(torch.zeros(1, batch_size, self.hidden_size))
# def __init__(self, input_size, hidden_size, num_classes, num_embeddings=128):
# super(Attention, self).__init__()
# self.attention_cell = AttentionCell(input_size, hidden_size, num_embeddings)
# self.input_size = input_size
# self.hidden_size = hidden_size
# self.generator = nn.Linear(hidden_size, num_classes)
# self.char_embeddings = Parameter(torch.randn(num_classes+1, num_embeddings))
# self.num_embeddings = num_embeddings
# self.processed_batches = 0
# # targets is nT * nB
# def forward(self, feats, text_length, text):
# # target_txt_decode
# targets =target_txt_decode(batch_size, text_length, text)
# output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
# hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data))
# max_locs = torch.zeros(num_steps, nB)
# max_vals = torch.zeros(num_steps, nB)
# for i in range(num_steps):
# cur_embeddings = self.char_embeddings.index_select(0, targets[i])
# hidden, alpha = self.attention_cell(hidden, feats, cur_embeddings)
# output_hiddens[i] = hidden
# if self.processed_batches % 500 == 0:
# max_val, max_loc = alpha.data.max(1)
# max_locs[i] = max_loc.cpu()
# max_vals[i] = max_val.cpu()
# if self.processed_batches % 500 == 0:
# print('max_locs', list(max_locs[0:text_length.data[0],0]))
# print('max_vals', list(max_vals[0:text_length.data[0],0]))
# new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
# b = 0
# start = 0
# for length in text_length.data:
# new_hiddens[start:start+length] = output_hiddens[0:length,b,:]
# start = start + length
# b = b + 1
# probs = self.generator(new_hiddens)
# return probs
def target_txt_decode(batch_size, text_length, text):
    """Build padded, time-major decoder targets from a flat label tensor.

    ``text`` holds the encoded labels of every sample concatenated back to
    back, and ``text_length[i]`` is how many of them belong to sample ``i``.
    Each sample is laid out as ``GO, label_0 .. label_{n-1}, EOS_TOKEN`` and
    padded with 2 (the '$' filler) up to the longest sample.

    Args:
        batch_size: number of samples in the batch.
        text_length: 1-D integer tensor with the label count of each sample.
        text: 1-D integer tensor with all labels concatenated.

    Returns:
        Long tensor of shape ``(max(text_length) + 2, batch_size)``. It is
        placed on the GPU when CUDA is available, otherwise it stays on the
        CPU.
    """
    longest = int(text_length.data.max())

    # +2 leaves room for the leading GO and the trailing EOS_TOKEN.
    targets = torch.full((batch_size, longest + 2), 2, dtype=torch.long)
    # Previously `.cuda()` was called unconditionally, which raises on
    # CPU-only machines; behaviour on GPU machines is unchanged.
    if torch.cuda.is_available():
        targets = targets.cuda()

    start_id = 0
    for i in range(batch_size):
        n = int(text_length.data[i])
        targets[i, 0] = GO
        targets[i, 1:n + 1] = text.data[start_id:start_id + n]
        targets[i, n + 1] = EOS_TOKEN
        # Advance to the first label of the next sample in the flat tensor.
        start_id += n

    # Callers consume one time step per row, hence the transpose.
    return Variable(targets.transpose(0, 1).contiguous())
# NOTE(review): this span comes from a diff view. The `......@@` hunk-marker lines
# hide part of the file (including most of the nn.Sequential body), and superseded
# lines appear next to their replacements: `self.attention = ...`, the first
# `def forward(self, input, length, text)`, `rnn = self.rnn(conv)`,
# `output = self.attention(...)` and the trailing `return output` look like the
# old version -- confirm against the real file.
class CNN(nn.Module):
'''
Feature extraction with a CNN followed by BiLSTM layers.
'''
def __init__(self, imgH, nc, nh):
super(CNN, self).__init__()
assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
self.cnn = nn.Sequential(
......@@ -129,13 +211,11 @@ class CRNN(nn.Module):
nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True), # 512x4x25
nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)), # 512x2x25
nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(True)) # 512x1x25
#self.cnn = cnn
self.rnn = nn.Sequential(
BidirectionalLSTM(512, nh, nh),
BidirectionalLSTM(nh, nh, nh))
self.attention = Attention(nh, nh, nclass, 256)
def forward(self, input, length, text):
def forward(self, input):
# conv features
conv = self.cnn(input)
b, c, h, w = conv.size()
......@@ -143,8 +223,36 @@ class CRNN(nn.Module):
conv = conv.squeeze(2)
conv = conv.permute(2, 0, 1) # [w, b, c]
# rnn features
rnn = self.rnn(conv)
output = self.attention(rnn, length, text)
# rnn features calculate
encoder_outputs = self.rnn(conv) # seq * batch * n_classes // e.g. 25 x batchsize x 256 (number of hidden units)
return encoder_outputs
return output
class decoder(nn.Module):
    """Decoder module that turns image features into symbols.

    Thin wrapper that owns one ``Attentiondecoder`` and forwards every call
    to it.

    Args:
        nh: hidden size of the wrapped decoder.
        nclass: number of output classes.
        dropout_p: dropout probability passed to the wrapped decoder.
        max_length: number of encoder steps the attention layer scores.
    """

    def __init__(self, nh=256, nclass=13, dropout_p=0.1, max_length=56):
        super(decoder, self).__init__()
        self.hidden_size = nh
        self.decoder = Attentiondecoder(
            hidden_size=nh,
            output_size=nclass,
            dropout_p=dropout_p,
            max_length=max_length,
        )

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step through the wrapped ``Attentiondecoder``."""
        step_result = self.decoder(input, hidden, encoder_outputs)
        return step_result

    def initHidden(self, batch_size):
        """Return an all-zero hidden state of shape (1, batch_size, hidden_size)."""
        zero_state = torch.zeros(1, batch_size, self.hidden_size)
        return Variable(zero_state)
# target_variable = target_txt_decode(b, length, text)
# decoder_input = torch.zeros(b).long().cuda() # 初始化decoder的开始,从0开始输出
# decoder_hidden = self.decoder.initHidden(b).cuda()
# if self.teach_forcing:
# # 教师强制:将目标label作为下一个输入
# for di in range(target_variable.shape[0]): # 最大字符串的长度
# decoder_output, decoder_hidden, decoder_attention = self.decoder(
# decoder_input, decoder_hidden, encoder_outputs)
# loss += criterion(decoder_output, target_variable[di]) # 每次预测一个字符
# decoder_input = target_variable[di] # Teacher forcing/前一次的输出
\ No newline at end of file
# Disabled variant: same training command without --lang, restricted to CUDA
# device 0, experiment name expr_basic, stdout redirected to log.txt.
#CUDA_VISIBLE_DEVICES=0 nohup python main.py --experiment expr_basic --trainlist data/train_list.txt --vallist data/val_list.txt --cuda --adam --lr=0.001 > log.txt &
# Active run: training with --lang, restricted to CUDA device 1, started in the
# background under nohup; stdout is redirected to log_lang.txt.
CUDA_VISIBLE_DEVICES=1 nohup python main.py --lang --experiment expr_basic_lang --trainlist data/train_list.txt --vallist data/val_list.txt --cuda --adam --lr=0.001 > log_lang.txt &
......@@ -5,6 +5,12 @@ import torch
import torch.nn as nn
from torch.autograd import Variable
import collections
from PIL import Image, ImageFilter
import math
import random
import numpy as np
import cv2
# NOTE(review): this span comes from a diff view. `......@@` hunk markers hide part
# of the docstrings, and superseded lines sit next to their replacements (two
# `__init__` headers, two `decode` headers, both the old `sep`-based encode body and
# the new padding-based one, and the old length-based decode body). The docstring
# text is left exactly as found because the old and new triple-quoted blocks are
# interleaved -- confirm against the real file before editing further.
class strLabelConverterForAttention(object):
"""Convert between str and label.
......@@ -17,36 +23,43 @@ class strLabelConverterForAttention(object):
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet, sep):
self.sep = sep
self.alphabet = alphabet.split(sep)
def __init__(self, alphabet):
self.alphabet = alphabet
self.dict = {}
self.dict['SOS'] = 0 # start-of-sequence
self.dict['EOS'] = 1 # end-of-sequence
self.dict['$'] = 2 # blank / padding marker
for i, item in enumerate(self.alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[item] = i
self.dict[item] = i + 3 # alphabet codes start at 3
def encode(self, text):
"""Support batch or single str.
"""对target_label做编码和对齐
对target txt每个字符串的开始加上GO,最后加上EOS,并用最长的字符串做对齐
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
torch.IntTensor targets:max_length × batch_size
"""
if isinstance(text, str):
text = text.split(self.sep)
text = [self.dict[item] for item in text]
length = [len(text)]
elif isinstance(text, collections.Iterable):
length = [len(s.split(self.sep)) for s in text]
text = self.sep.join(text)
text, _ = self.encode(text)
return (torch.LongTensor(text), torch.LongTensor(length))
def decode(self, t, length):
text = [self.encode(s) for s in text] # encode each string
max_length = max([len(x) for x in text]) # length of the longest string, used for padding
nb = len(text)
targets = torch.ones(nb, max_length + 2) * 2 # use 'blank' for padding
for i in range(nb):
targets[i][0] = 0 # start token
targets[i][1:len(text[i]) + 1] = text[i]
targets[i][len(text[i]) + 1] = 1
text = targets.transpose(0, 1).contiguous()
text = text.long()
return torch.LongTensor(text)
def decode(self, t):
"""Decode encoded texts back into strs.
Args:
......@@ -59,22 +72,9 @@ class strLabelConverterForAttention(object):
Returns:
text (str or list of str): texts to convert.
"""
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
return ''.join([self.alphabet[i] for i in t])
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()):
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.LongTensor([l])))
index += l
return texts
texts = list(self.dict.keys())[list(self.dict.values()).index(t)]
return texts
class strLabelConverterForCTC(object):
"""Convert between str and label.
......@@ -213,3 +213,138 @@ def assureRatio(img):
main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
img = main(img)
return img
class halo():
    """Add random Gaussian light spots ("halos") to a grayscale image.

    Args:
        nums: upper bound on the number of spots per image; the actual count
            is drawn uniformly from ``1..nums``.
        u: mean of the Gaussian used to build the spot kernel.
        sigma: variance of that Gaussian (its standard deviation is
            ``sqrt(sigma)``).
        prob: probability that a call applies the effect at all.
    """

    def __init__(self, nums, u=0, sigma=0.2, prob=0.5):
        self.u = u  # mean
        self.sig = math.sqrt(sigma)  # standard deviation
        self.nums = nums
        self.prob = prob

    def create_kernel(self, maxh=32, maxw=50):
        """Return a 2-D Gaussian bump of random size.

        The height is drawn from ``[10, maxh]`` and the width from
        ``[20, maxw]``; both axes sample the Gaussian over mean +/- 3 standard
        deviations.
        """
        # Height is drawn before width, matching the original draw order.
        rows = random.randint(10, maxh)
        cols = random.randint(20, maxw)
        low = self.u - 3 * self.sig
        high = self.u + 3 * self.sig
        x = np.linspace(low, high, rows)
        y = np.linspace(low, high, cols)
        # Broadcasting replaces the former element-by-element double loop.
        sq_dist = (x[:, np.newaxis] - self.u) ** 2 + (y[np.newaxis, :] - self.u) ** 2
        scale = math.sqrt(2 * math.pi) * self.sig
        return np.exp(-sq_dist / (2 * self.sig ** 2)) / scale

    def _add_spots(self, pixels):
        """Return a float64 copy of 2-D ``pixels`` with 1..nums spots added.

        Values are capped at 255 so a later uint8 cast cannot wrap around.
        """
        # np.array(...) always copies, so the result is writable without
        # touching array flags, and np.float64 replaces the removed np.float.
        canvas = np.array(pixels, dtype=np.float64)
        img_h, img_w = canvas.shape
        kernel = self.create_kernel(32, 60)  # at most 32 rows by 60 columns
        k_h, k_w = kernel.shape

        for _ in range(random.randint(1, self.nums)):
            # max(0, ...) keeps randint valid for images smaller than 10 px.
            top = random.randint(0, max(0, img_h - 10))
            left = random.randint(0, max(0, img_w - 10))
            bottom = min(top + k_h, img_h)
            right = min(left + k_w, img_w)
            # Clip a per-spot view of the kernel to the image bounds. The
            # kernel itself is left intact, so one spot near the border no
            # longer shrinks every spot drawn after it.
            patch = kernel[:bottom - top, :right - left]
            canvas[top:bottom, left:right] += patch * 255.0

        canvas[canvas > 255.0] = 255.0
        return canvas

    def __call__(self, img):
        """Return ``img`` as a PIL image, with halos added with probability ``prob``.

        ``img`` must be single-channel (2-D) when the effect is applied.
        """
        if random.random() < self.prob:
            img = self._add_spots(img)
        return Image.fromarray(np.uint8(img))
class MyGaussianBlur(ImageFilter.Filter):
    """Gaussian blur filter that can be restricted to a box inside the image.

    Args:
        radius: blur radius handed to ``gaussian_blur``.
        bounds: optional box; when truthy, only that crop is blurred and then
            pasted back at the same position.
    """
    name = "GaussianBlur"

    def __init__(self, radius=2, bounds=None):
        self.radius, self.bounds = radius, bounds

    def filter(self, image):
        """Blur the whole image, or only the ``bounds`` region when it is set."""
        if not self.bounds:
            return image.gaussian_blur(self.radius)
        blurred_region = image.crop(self.bounds).gaussian_blur(self.radius)
        image.paste(blurred_region, self.bounds)
        return image
class GBlur(object):
    """Apply a Gaussian blur to an image with probability ``prob``.

    The blur radius is drawn once, at construction time, uniformly from the
    integers ``0..radius``; every later call reuses that same filter.
    """

    def __init__(self, radius=2, prob=0.5):
        self.blur = MyGaussianBlur(radius=random.randint(0, radius))
        self.prob = prob

    def __call__(self, img):
        """Return the blurred image, or ``img`` unchanged when the draw misses."""
        should_blur = random.random() < self.prob
        return img.filter(self.blur) if should_blur else img
class RandomBrightness(object):
    """Randomly scale the brightness of an RGB image and return it as grayscale.

    Args:
        prob: threshold compared against ``random.random()``; the adjustment
            runs when the draw is below it, so the default of 1.5 applies it
            on every call.
    """

    def __init__(self, prob=1.5):
        self.prob = prob

    def __call__(self, pil):
        """Return ``pil`` converted to mode 'L', optionally brightness-adjusted.

        Args:
            pil: image in PIL format; it is converted with RGB<->HSV colour
                codes when the adjustment runs.
        """
        pixels = np.asarray(pil)
        if random.random() < self.prob:
            hsv = cv2.cvtColor(pixels, cv2.COLOR_RGB2HSV)
            hue, sat, val = cv2.split(hsv)
            # Pick one scaling factor at random.
            factor = random.choice([0.5, 0.7, 0.9, 1.2, 1.5, 1.7])
            # Scale only the V channel, clamp to 0..255 and cast back to the
            # HSV array's dtype before merging.
            val = np.clip(val * factor, 0, 255).astype(hsv.dtype)
            pixels = cv2.cvtColor(cv2.merge((hue, sat, val)), cv2.COLOR_HSV2RGB)
        return Image.fromarray(np.uint8(pixels)).convert('L')
class randapply(object):
    """Chain several image transforms into one callable.

    Every transform is applied, in order, on each call. This class draws no
    random numbers itself; whether a given effect (halo, blur, ...) actually
    fires is decided inside each transform.

    Args:
        transforms (list or tuple): callables that take an image and return
            an image.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, img):
        """Feed ``img`` through every transform in order and return the result."""
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        # The previous version formatted `self.p`, an attribute this class
        # never sets, so repr() always raised AttributeError. That line is
        # dropped; the rest of the layout is unchanged.
        lines = [self.__class__.__name__ + '(']
        lines.extend(' {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册