Commit d636e112 authored by Alexander Mordvintsev

removed ANN digits recognition

added deskew for SVM and KNearest recognition sample

Parent: f2e78eed
digits.py (old version, removed by this commit):

'''
Neural network digit recognition sample.

Usage:
   digits.py

Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains a neural network classifier on it and evaluates
its classification accuracy.
'''

import numpy as np
import cv2
from common import mosaic

def unroll_responses(responses, class_n):
    '''[1, 0, 2, ...] -> [[0, 1, 0], [1, 0, 0], [0, 0, 1], ...]'''
    sample_n = len(responses)
    new_responses = np.zeros((sample_n, class_n), np.float32)
    new_responses[np.arange(sample_n), responses] = 1
    return new_responses

SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10
digits_img = cv2.imread('digits.png', 0)

# prepare dataset
h, w = digits_img.shape
digits = [np.hsplit(row, w/SZ) for row in np.vsplit(digits_img, h/SZ)]
digits = np.float32(digits).reshape(-1, SZ*SZ)
N = len(digits)
labels = np.repeat(np.arange(CLASS_N), N/CLASS_N)

# split it onto train and test subsets
shuffle = np.random.permutation(N)
train_n = int(0.9*N)
digits_train, digits_test = np.split(digits[shuffle], [train_n])
labels_train, labels_test = np.split(labels[shuffle], [train_n])

# train model
model = cv2.ANN_MLP()
layer_sizes = np.int32([SZ*SZ, 25, CLASS_N])
model.create(layer_sizes)
params = dict( term_crit = (cv2.TERM_CRITERIA_COUNT, 100, 0.01),
               train_method = cv2.ANN_MLP_TRAIN_PARAMS_BACKPROP,
               bp_dw_scale = 0.001,
               bp_moment_scale = 0.0 )
print 'training...'
labels_train_unrolled = unroll_responses(labels_train, CLASS_N)
model.train(digits_train, labels_train_unrolled, None, params=params)
model.save('dig_nn.dat')
model.load('dig_nn.dat')

def evaluate(model, samples, labels):
    '''Evaluates classifier performance on a given labeled sample set.'''
    ret, resp = model.predict(samples)
    resp = resp.argmax(-1)
    error_mask = (resp == labels)
    accuracy = error_mask.mean()
    return accuracy, error_mask

# evaluate model
train_accuracy, _ = evaluate(model, digits_train, labels_train)
print 'train accuracy: ', train_accuracy
test_accuracy, test_error_mask = evaluate(model, digits_test, labels_test)
print 'test accuracy: ', test_accuracy

# visualize test results
vis = []
for img, flag in zip(digits_test, test_error_mask):
    img = np.uint8(img).reshape(SZ, SZ)
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if not flag:
        img[...,:2] = 0
    vis.append(img)
vis = mosaic(25, vis)
cv2.imshow('test', vis)
cv2.waitKey()

digits.py (new version, added by this commit):

'''
SVM and KNearest digit recognition.

Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains SVM and KNearest classifiers on it and evaluates
their accuracy. Moment-based image deskew is used to improve
the recognition accuracy.

Usage:
   digits.py
'''

import numpy as np
import cv2
from multiprocessing.pool import ThreadPool
from common import clock, mosaic

SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10

def load_digits(fn):
    print 'loading "%s" ...' % fn
    digits_img = cv2.imread(fn, 0)
    h, w = digits_img.shape
    digits = [np.hsplit(row, w/SZ) for row in np.vsplit(digits_img, h/SZ)]
    digits = np.array(digits).reshape(-1, SZ, SZ)
    labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
    return digits, labels

def deskew(img):
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11']/m['mu02']
    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
    return img

class StatModel(object):
    def load(self, fn):
        self.model.load(fn)
    def save(self, fn):
        self.model.save(fn)

class KNearest(StatModel):
    def __init__(self, k = 3):
        self.k = k
        self.model = cv2.KNearest()

    def train(self, samples, responses):
        self.model = cv2.KNearest()
        self.model.train(samples, responses)

    def predict(self, samples):
        retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
        return results.ravel()

class SVM(StatModel):
    def __init__(self, C = 1, gamma = 0.5):
        self.params = dict( kernel_type = cv2.SVM_RBF,
                            svm_type = cv2.SVM_C_SVC,
                            C = C,
                            gamma = gamma )
        self.model = cv2.SVM()

    def train(self, samples, responses):
        self.model = cv2.SVM()
        self.model.train(samples, responses, params = self.params)

    def predict(self, samples):
        return self.model.predict_all(samples).ravel()

def evaluate_model(model, digits, samples, labels):
    resp = model.predict(samples)
    err = (labels != resp).mean()
    print 'error: %.2f %%' % (err*100)

    confusion = np.zeros((10, 10), np.int32)
    for i, j in zip(labels, resp):
        confusion[i, j] += 1
    print 'confusion matrix:'
    print confusion
    print

    vis = []
    for img, flag in zip(digits, resp == labels):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if not flag:
            img[...,:2] = 0
        vis.append(img)
    return mosaic(25, vis)

if __name__ == '__main__':
    print __doc__

    digits, labels = load_digits('digits.png')

    print 'preprocessing...'
    # shuffle digits
    rand = np.random.RandomState(12345)
    shuffle = rand.permutation(len(digits))
    digits, labels = digits[shuffle], labels[shuffle]

    digits2 = map(deskew, digits)
    samples = np.float32(digits2).reshape(-1, SZ*SZ) / 255.0

    train_n = int(0.9*len(samples))
    cv2.imshow('test set', mosaic(25, digits[train_n:]))
    digits_train, digits_test = np.split(digits2, [train_n])
    samples_train, samples_test = np.split(samples, [train_n])
    labels_train, labels_test = np.split(labels, [train_n])

    print 'training KNearest...'
    model = KNearest(k=1)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('KNearest test', vis)

    print 'training SVM...'
    model = SVM(C=4.66, gamma=0.08)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('SVM test', vis)

    cv2.waitKey(0)
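A note on the deskew step introduced above: the second-order central moments give skew = mu11/mu02, an estimate of how strongly the strokes lean, and warpAffine with M = [[1, skew, -0.5*SZ*skew], [0, 1, 0]] plus WARP_INVERSE_MAP shears the patch back upright before the pixels reach the classifiers. A minimal self-contained sketch of the same correction; the slanted synthetic stroke below is only an illustration, not data from digits.png:

import numpy as np
import cv2

SZ = 20  # patch size, same as in the sample above

def deskew(img):
    # moment-based shear correction, as in digits.py above
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11']/m['mu02']
    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
    return cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)

# synthetic test image: a vertical bar sheared to the right as y grows
img = np.zeros((SZ, SZ), np.float32)
for y in range(SZ):
    x = int(round(10 + 0.4*(y - SZ/2)))
    img[y, max(x-1, 0):min(x+2, SZ)] = 1.0

m = cv2.moments(img)
print('skew before: %.3f' % (m['mu11']/m['mu02']))
m2 = cv2.moments(deskew(img))
print('skew after:  %.3f' % (m2['mu11']/m2['mu02']))

Run on its own, the second printed skew should come out close to zero, which is the whole point of the preprocessing step.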
Second file in this commit (cross-validation and parameter search for the SVM and KNearest digit classifiers):

import numpy as np
import cv2
from multiprocessing.pool import ThreadPool

SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10

def load_base(fn):
    print 'loading "%s" ...' % fn
    digits_img = cv2.imread(fn, 0)
    h, w = digits_img.shape
    digits = [np.hsplit(row, w/SZ) for row in np.vsplit(digits_img, h/SZ)]
    digits = np.array(digits).reshape(-1, SZ, SZ)
    digits = np.float32(digits).reshape(-1, SZ*SZ) / 255.0
    labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
    return digits, labels

def cross_validate(model_class, params, samples, labels, kfold = 4, pool = None):
    n = len(samples)
    folds = np.array_split(np.arange(n), kfold)
    def f(i):
        model = model_class(**params)
        test_idx = folds[i]
        train_idx = list(folds)
        train_idx.pop(i)
        train_idx = np.hstack(train_idx)
        train_samples, train_labels = samples[train_idx], labels[train_idx]
        test_samples, test_labels = samples[test_idx], labels[test_idx]
        model.train(train_samples, train_labels)
        resp = model.predict(test_samples)
        score = (resp != test_labels).mean()
        print ".",
        return score
    if pool is None:
        scores = map(f, xrange(kfold))
    else:
        scores = pool.map(f, xrange(kfold))
    return np.mean(scores)

class StatModel(object):
    def load(self, fn):
        self.model.load(fn)
    def save(self, fn):
        self.model.save(fn)

class KNearest(StatModel):
    def __init__(self, k = 3):
        self.k = k

    @staticmethod
    def adjust(samples, labels):
        print 'adjusting KNearest ...'
        best_err, best_k = np.inf, -1
        for k in xrange(1, 11):
            err = cross_validate(KNearest, dict(k=k), samples, labels)
            if err < best_err:
                best_err, best_k = err, k
            print 'k = %d, error: %.2f %%' % (k, err*100)
        best_params = dict(k=best_k)
        print 'best params:', best_params
        return best_params

    def train(self, samples, responses):
        self.model = cv2.KNearest()
        self.model.train(samples, responses)

    def predict(self, samples):
        retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
        return results.ravel()

class SVM(StatModel):
    def __init__(self, C = 1, gamma = 0.5):
        self.params = dict( kernel_type = cv2.SVM_RBF,
                            svm_type = cv2.SVM_C_SVC,
                            C = C,
                            gamma = gamma )

    @staticmethod
    def adjust(samples, labels):
        Cs = np.logspace(0, 5, 10, base=2)
        gammas = np.logspace(-7, -2, 10, base=2)
        scores = np.zeros((len(Cs), len(gammas)))
        scores[:] = np.nan

        print 'adjusting SVM (may take a long time) ...'
        def f(job):
            i, j = job
            params = dict(C = Cs[i], gamma=gammas[j])
            score = cross_validate(SVM, params, samples, labels)
            scores[i, j] = score
            nready = np.isfinite(scores).sum()
            print '%d / %d (best error: %.2f %%, last: %.2f %%)' % (nready, scores.size, np.nanmin(scores)*100, score*100)
        pool = ThreadPool(processes=cv2.getNumberOfCPUs())
        pool.map(f, np.ndindex(*scores.shape))
        print scores

        i, j = np.unravel_index(scores.argmin(), scores.shape)
        best_params = dict(C = Cs[i], gamma=gammas[j])
        print 'best params:', best_params
        print 'best error: %.2f %%' % (scores.min()*100)
        return best_params

    def train(self, samples, responses):
        self.model = cv2.SVM()
        self.model.train(samples, responses, params = self.params)

    def predict(self, samples):
        return self.model.predict_all(samples).ravel()

def main_adjustSVM(samples, labels):
    params = SVM.adjust(samples, labels)
    print 'training SVM on all samples ...'
    model = SVM(**params)
    model.train(samples, labels)
    print 'saving "digits_svm.dat" ...'
    model.save('digits_svm.dat')

def main_adjustKNearest(samples, labels):
    params = KNearest.adjust(samples, labels)

def main_showSVM(samples, labels):
    from common import mosaic

    # samples and labels arrive already shuffled by the caller
    train_n = int(0.9*len(samples))
    digits_train, digits_test = np.split(samples, [train_n])
    labels_train, labels_test = np.split(labels, [train_n])

    print 'training SVM ...'
    model = SVM(C=2.16, gamma=0.0536)
    model.train(digits_train, labels_train)
    train_err = (model.predict(digits_train) != labels_train).mean()
    resp_test = model.predict(digits_test)
    test_err = (resp_test != labels_test).mean()
    print 'train errors: %.2f %%' % (train_err*100)
    print 'test errors: %.2f %%' % (test_err*100)

    # visualize test results
    vis = []
    for img, flag in zip(digits_test, resp_test == labels_test):
        img = np.uint8(img*255).reshape(SZ, SZ)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if not flag:
            img[...,:2] = 0
        vis.append(img)
    vis = mosaic(25, vis)
    cv2.imshow('test', vis)
    cv2.waitKey()

if __name__ == '__main__':
    samples, labels = load_base('digits.png')
    shuffle = np.random.permutation(len(samples))
    samples, labels = samples[shuffle], labels[shuffle]

    #main_adjustSVM(samples, labels)
    #main_adjustKNearest(samples, labels)
    main_showSVM(samples, labels)
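Both files target the OpenCV 2.4-era Python bindings (cv2.KNearest, cv2.SVM, Python 2 print statements). For readers on OpenCV 3 or 4, a rough sketch of the same train/predict flow using the cv2.ml module; this assumes a cv2.ml-enabled build, the random stand-in data is illustrative only, and the C/gamma values simply mirror the defaults of the SVM wrapper above rather than tuned settings:

import numpy as np
import cv2

# stand-in data: 500 random 400-dimensional samples with 10 class labels
# (in practice, substitute the deskewed and flattened digits from digits.png)
rng = np.random.RandomState(0)
samples = rng.rand(500, 400).astype(np.float32)
labels = rng.randint(0, 10, 500).astype(np.int32)

# KNearest via cv2.ml
knn = cv2.ml.KNearest_create()
knn.train(samples, cv2.ml.ROW_SAMPLE, labels)
_ret, results, _neigh, _dist = knn.findNearest(samples, 3)
print('KNearest train error: %.2f %%' % ((results.ravel() != labels).mean()*100))

# C-SVC with an RBF kernel via cv2.ml
svm = cv2.ml.SVM_create()
svm.setType(cv2.ml.SVM_C_SVC)
svm.setKernel(cv2.ml.SVM_RBF)
svm.setC(1.0)
svm.setGamma(0.5)
svm.train(samples, cv2.ml.ROW_SAMPLE, labels)
_ret, resp = svm.predict(samples)
print('SVM train error: %.2f %%' % ((resp.ravel() != labels).mean()*100))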