Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Opencv
提交
d636e112
O
Opencv
项目概览
Greenplum
/
Opencv
10 个月 前同步成功
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
Opencv
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d636e112
编写于
6月 27, 2012
作者:
A
Alexander Mordvintsev
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
removed ANN digits recognition
added deskew for SVN and KNearest recognition sample
上级
f2e78eed
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
118 addition
and
229 deletion
+118
-229
samples/python2/digits.py
samples/python2/digits.py
+118
-68
samples/python2/digits2.py
samples/python2/digits2.py
+0
-161
未找到文件。
samples/python2/digits.py
浏览文件 @
d636e112
'''
Neural network digit recognition sample.
SVN and KNearest digit recognition.
Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains a SVN and KNearest classifiers on it and evaluates
their accuracy. Moment-based image deskew is used to improve
the recognition accuracy.
Usage:
digits.py
Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains a neural network classifier on it and evaluates
its classification accuracy.
'''
import
numpy
as
np
import
cv2
from
common
import
mosaic
def
unroll_responses
(
responses
,
class_n
):
'''[1, 0, 2, ...] -> [[0, 1, 0], [1, 0, 0], [0, 0, 1], ...]'''
sample_n
=
len
(
responses
)
new_responses
=
np
.
zeros
((
sample_n
,
class_n
),
np
.
float32
)
new_responses
[
np
.
arange
(
sample_n
),
responses
]
=
1
return
new_responses
from
multiprocessing.pool
import
ThreadPool
from
common
import
clock
,
mosaic
SZ
=
20
# size of each digit is SZ x SZ
CLASS_N
=
10
digits_img
=
cv2
.
imread
(
'digits.png'
,
0
)
# prepare dataset
h
,
w
=
digits_img
.
shape
digits
=
[
np
.
hsplit
(
row
,
w
/
SZ
)
for
row
in
np
.
vsplit
(
digits_img
,
h
/
SZ
)]
digits
=
np
.
float32
(
digits
).
reshape
(
-
1
,
SZ
*
SZ
)
N
=
len
(
digits
)
labels
=
np
.
repeat
(
np
.
arange
(
CLASS_N
),
N
/
CLASS_N
)
# split it onto train and test subsets
shuffle
=
np
.
random
.
permutation
(
N
)
train_n
=
int
(
0.9
*
N
)
digits_train
,
digits_test
=
np
.
split
(
digits
[
shuffle
],
[
train_n
])
labels_train
,
labels_test
=
np
.
split
(
labels
[
shuffle
],
[
train_n
])
# train model
model
=
cv2
.
ANN_MLP
()
layer_sizes
=
np
.
int32
([
SZ
*
SZ
,
25
,
CLASS_N
])
model
.
create
(
layer_sizes
)
params
=
dict
(
term_crit
=
(
cv2
.
TERM_CRITERIA_COUNT
,
100
,
0.01
),
train_method
=
cv2
.
ANN_MLP_TRAIN_PARAMS_BACKPROP
,
bp_dw_scale
=
0.001
,
bp_moment_scale
=
0.0
)
print
'training...'
labels_train_unrolled
=
unroll_responses
(
labels_train
,
CLASS_N
)
model
.
train
(
digits_train
,
labels_train_unrolled
,
None
,
params
=
params
)
model
.
save
(
'dig_nn.dat'
)
model
.
load
(
'dig_nn.dat'
)
def
evaluate
(
model
,
samples
,
labels
):
'''Evaluates classifier preformance on a given labeled samples set.'''
ret
,
resp
=
model
.
predict
(
samples
)
resp
=
resp
.
argmax
(
-
1
)
error_mask
=
(
resp
==
labels
)
accuracy
=
error_mask
.
mean
()
return
accuracy
,
error_mask
# evaluate model
train_accuracy
,
_
=
evaluate
(
model
,
digits_train
,
labels_train
)
print
'train accuracy: '
,
train_accuracy
test_accuracy
,
test_error_mask
=
evaluate
(
model
,
digits_test
,
labels_test
)
print
'test accuracy: '
,
test_accuracy
# visualize test results
vis
=
[]
for
img
,
flag
in
zip
(
digits_test
,
test_error_mask
):
img
=
np
.
uint8
(
img
).
reshape
(
SZ
,
SZ
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
not
flag
:
img
[...,:
2
]
=
0
vis
.
append
(
img
)
vis
=
mosaic
(
25
,
vis
)
cv2
.
imshow
(
'test'
,
vis
)
cv2
.
waitKey
()
def
load_digits
(
fn
):
print
'loading "%s" ...'
%
fn
digits_img
=
cv2
.
imread
(
fn
,
0
)
h
,
w
=
digits_img
.
shape
digits
=
[
np
.
hsplit
(
row
,
w
/
SZ
)
for
row
in
np
.
vsplit
(
digits_img
,
h
/
SZ
)]
digits
=
np
.
array
(
digits
).
reshape
(
-
1
,
SZ
,
SZ
)
labels
=
np
.
repeat
(
np
.
arange
(
CLASS_N
),
len
(
digits
)
/
CLASS_N
)
return
digits
,
labels
def
deskew
(
img
):
m
=
cv2
.
moments
(
img
)
if
abs
(
m
[
'mu02'
])
<
1e-2
:
return
img
.
copy
()
skew
=
m
[
'mu11'
]
/
m
[
'mu02'
]
M
=
np
.
float32
([[
1
,
skew
,
-
0.5
*
SZ
*
skew
],
[
0
,
1
,
0
]])
img
=
cv2
.
warpAffine
(
img
,
M
,
(
SZ
,
SZ
),
flags
=
cv2
.
WARP_INVERSE_MAP
|
cv2
.
INTER_LINEAR
)
return
img
class
StatModel
(
object
):
def
load
(
self
,
fn
):
self
.
model
.
load
(
fn
)
def
save
(
self
,
fn
):
self
.
model
.
save
(
fn
)
class
KNearest
(
StatModel
):
def
__init__
(
self
,
k
=
3
):
self
.
k
=
k
self
.
model
=
cv2
.
KNearest
()
def
train
(
self
,
samples
,
responses
):
self
.
model
=
cv2
.
KNearest
()
self
.
model
.
train
(
samples
,
responses
)
def
predict
(
self
,
samples
):
retval
,
results
,
neigh_resp
,
dists
=
self
.
model
.
find_nearest
(
samples
,
self
.
k
)
return
results
.
ravel
()
class
SVM
(
StatModel
):
def
__init__
(
self
,
C
=
1
,
gamma
=
0.5
):
self
.
params
=
dict
(
kernel_type
=
cv2
.
SVM_RBF
,
svm_type
=
cv2
.
SVM_C_SVC
,
C
=
C
,
gamma
=
gamma
)
self
.
model
=
cv2
.
SVM
()
def
train
(
self
,
samples
,
responses
):
self
.
model
=
cv2
.
SVM
()
self
.
model
.
train
(
samples
,
responses
,
params
=
self
.
params
)
def
predict
(
self
,
samples
):
return
self
.
model
.
predict_all
(
samples
).
ravel
()
def
evaluate_model
(
model
,
digits
,
samples
,
labels
):
resp
=
model
.
predict
(
samples
)
err
=
(
labels
!=
resp
).
mean
()
print
'error: %.2f %%'
%
(
err
*
100
)
confusion
=
np
.
zeros
((
10
,
10
),
np
.
int32
)
for
i
,
j
in
zip
(
labels
,
resp
):
confusion
[
i
,
j
]
+=
1
print
'confusion matrix:'
print
confusion
print
vis
=
[]
for
img
,
flag
in
zip
(
digits
,
resp
==
labels
):
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
not
flag
:
img
[...,:
2
]
=
0
vis
.
append
(
img
)
return
mosaic
(
25
,
vis
)
if
__name__
==
'__main__'
:
print
__doc__
digits
,
labels
=
load_digits
(
'digits.png'
)
print
'preprocessing...'
# shuffle digits
rand
=
np
.
random
.
RandomState
(
12345
)
shuffle
=
rand
.
permutation
(
len
(
digits
))
digits
,
labels
=
digits
[
shuffle
],
labels
[
shuffle
]
digits2
=
map
(
deskew
,
digits
)
samples
=
np
.
float32
(
digits2
).
reshape
(
-
1
,
SZ
*
SZ
)
/
255.0
train_n
=
int
(
0.9
*
len
(
samples
))
cv2
.
imshow
(
'test set'
,
mosaic
(
25
,
digits
[
train_n
:]))
digits_train
,
digits_test
=
np
.
split
(
digits2
,
[
train_n
])
samples_train
,
samples_test
=
np
.
split
(
samples
,
[
train_n
])
labels_train
,
labels_test
=
np
.
split
(
labels
,
[
train_n
])
print
'training KNearest...'
model
=
KNearest
(
k
=
1
)
model
.
train
(
samples_train
,
labels_train
)
vis
=
evaluate_model
(
model
,
digits_test
,
samples_test
,
labels_test
)
cv2
.
imshow
(
'KNearest test'
,
vis
)
print
'training SVM...'
model
=
SVM
(
C
=
4.66
,
gamma
=
0.08
)
model
.
train
(
samples_train
,
labels_train
)
vis
=
evaluate_model
(
model
,
digits_test
,
samples_test
,
labels_test
)
cv2
.
imshow
(
'SVM test'
,
vis
)
cv2
.
waitKey
(
0
)
samples/python2/digits2.py
已删除
100644 → 0
浏览文件 @
f2e78eed
import
numpy
as
np
import
cv2
from
multiprocessing.pool
import
ThreadPool
SZ
=
20
# size of each digit is SZ x SZ
CLASS_N
=
10
def
load_base
(
fn
):
print
'loading "%s" ...'
%
fn
digits_img
=
cv2
.
imread
(
fn
,
0
)
h
,
w
=
digits_img
.
shape
digits
=
[
np
.
hsplit
(
row
,
w
/
SZ
)
for
row
in
np
.
vsplit
(
digits_img
,
h
/
SZ
)]
digits
=
np
.
array
(
digits
).
reshape
(
-
1
,
SZ
,
SZ
)
digits
=
np
.
float32
(
digits
).
reshape
(
-
1
,
SZ
*
SZ
)
/
255.0
labels
=
np
.
repeat
(
np
.
arange
(
CLASS_N
),
len
(
digits
)
/
CLASS_N
)
return
digits
,
labels
def
cross_validate
(
model_class
,
params
,
samples
,
labels
,
kfold
=
4
,
pool
=
None
):
n
=
len
(
samples
)
folds
=
np
.
array_split
(
np
.
arange
(
n
),
kfold
)
def
f
(
i
):
model
=
model_class
(
**
params
)
test_idx
=
folds
[
i
]
train_idx
=
list
(
folds
)
train_idx
.
pop
(
i
)
train_idx
=
np
.
hstack
(
train_idx
)
train_samples
,
train_labels
=
samples
[
train_idx
],
labels
[
train_idx
]
test_samples
,
test_labels
=
samples
[
test_idx
],
labels
[
test_idx
]
model
.
train
(
train_samples
,
train_labels
)
resp
=
model
.
predict
(
test_samples
)
score
=
(
resp
!=
test_labels
).
mean
()
print
"."
,
return
score
if
pool
is
None
:
scores
=
map
(
f
,
xrange
(
kfold
))
else
:
scores
=
pool
.
map
(
f
,
xrange
(
kfold
))
return
np
.
mean
(
scores
)
class
StatModel
(
object
):
def
load
(
self
,
fn
):
self
.
model
.
load
(
fn
)
def
save
(
self
,
fn
):
self
.
model
.
save
(
fn
)
class
KNearest
(
StatModel
):
def
__init__
(
self
,
k
=
3
):
self
.
k
=
k
@
staticmethod
def
adjust
(
samples
,
labels
):
print
'adjusting KNearest ...'
best_err
,
best_k
=
np
.
inf
,
-
1
for
k
in
xrange
(
1
,
11
):
err
=
cross_validate
(
KNearest
,
dict
(
k
=
k
),
samples
,
labels
)
if
err
<
best_err
:
best_err
,
best_k
=
err
,
k
print
'k = %d, error: %.2f %%'
%
(
k
,
err
*
100
)
best_params
=
dict
(
k
=
best_k
)
print
'best params:'
,
best_params
return
best_params
def
train
(
self
,
samples
,
responses
):
self
.
model
=
cv2
.
KNearest
()
self
.
model
.
train
(
samples
,
responses
)
def
predict
(
self
,
samples
):
retval
,
results
,
neigh_resp
,
dists
=
self
.
model
.
find_nearest
(
samples
,
self
.
k
)
return
results
.
ravel
()
class
SVM
(
StatModel
):
def
__init__
(
self
,
C
=
1
,
gamma
=
0.5
):
self
.
params
=
dict
(
kernel_type
=
cv2
.
SVM_RBF
,
svm_type
=
cv2
.
SVM_C_SVC
,
C
=
C
,
gamma
=
gamma
)
@
staticmethod
def
adjust
(
samples
,
labels
):
Cs
=
np
.
logspace
(
0
,
5
,
10
,
base
=
2
)
gammas
=
np
.
logspace
(
-
7
,
-
2
,
10
,
base
=
2
)
scores
=
np
.
zeros
((
len
(
Cs
),
len
(
gammas
)))
scores
[:]
=
np
.
nan
print
'adjusting SVM (may take a long time) ...'
def
f
(
job
):
i
,
j
=
job
params
=
dict
(
C
=
Cs
[
i
],
gamma
=
gammas
[
j
])
score
=
cross_validate
(
SVM
,
params
,
samples
,
labels
)
scores
[
i
,
j
]
=
score
nready
=
np
.
isfinite
(
scores
).
sum
()
print
'%d / %d (best error: %.2f %%, last: %.2f %%)'
%
(
nready
,
scores
.
size
,
np
.
nanmin
(
scores
)
*
100
,
score
*
100
)
pool
=
ThreadPool
(
processes
=
cv2
.
getNumberOfCPUs
())
pool
.
map
(
f
,
np
.
ndindex
(
*
scores
.
shape
))
print
scores
i
,
j
=
np
.
unravel_index
(
scores
.
argmin
(),
scores
.
shape
)
best_params
=
dict
(
C
=
Cs
[
i
],
gamma
=
gammas
[
j
])
print
'best params:'
,
best_params
print
'best error: %.2f %%'
%
(
scores
.
min
()
*
100
)
return
best_params
def
train
(
self
,
samples
,
responses
):
self
.
model
=
cv2
.
SVM
()
self
.
model
.
train
(
samples
,
responses
,
params
=
self
.
params
)
def
predict
(
self
,
samples
):
return
self
.
model
.
predict_all
(
samples
).
ravel
()
def
main_adjustSVM
(
samples
,
labels
):
params
=
SVM
.
adjust
(
samples
,
labels
)
print
'training SVM on all samples ...'
model
=
SVN
(
**
params
)
model
.
train
(
samples
,
labels
)
print
'saving "digits_svm.dat" ...'
model
.
save
(
'digits_svm.dat'
)
def
main_adjustKNearest
(
samples
,
labels
):
params
=
KNearest
.
adjust
(
samples
,
labels
)
def
main_showSVM
(
samples
,
labels
):
from
common
import
mosaic
train_n
=
int
(
0.9
*
len
(
samples
))
digits_train
,
digits_test
=
np
.
split
(
samples
[
shuffle
],
[
train_n
])
labels_train
,
labels_test
=
np
.
split
(
labels
[
shuffle
],
[
train_n
])
print
'training SVM ...'
model
=
SVM
(
C
=
2.16
,
gamma
=
0.0536
)
model
.
train
(
digits_train
,
labels_train
)
train_err
=
(
model
.
predict
(
digits_train
)
!=
labels_train
).
mean
()
resp_test
=
model
.
predict
(
digits_test
)
test_err
=
(
resp_test
!=
labels_test
).
mean
()
print
'train errors: %.2f %%'
%
(
train_err
*
100
)
print
'test errors: %.2f %%'
%
(
test_err
*
100
)
# visualize test results
vis
=
[]
for
img
,
flag
in
zip
(
digits_test
,
resp_test
==
labels_test
):
img
=
np
.
uint8
(
img
*
255
).
reshape
(
SZ
,
SZ
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
not
flag
:
img
[...,:
2
]
=
0
vis
.
append
(
img
)
vis
=
mosaic
(
25
,
vis
)
cv2
.
imshow
(
'test'
,
vis
)
cv2
.
waitKey
()
if
__name__
==
'__main__'
:
samples
,
labels
=
load_base
(
'digits.png'
)
shuffle
=
np
.
random
.
permutation
(
len
(
samples
))
samples
,
labels
=
samples
[
shuffle
],
labels
[
shuffle
]
#main_adjustSVM(samples, labels)
#main_adjustKNearest(samples, labels)
main_showSVM
(
samples
,
labels
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录