Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Cloud IDE
CnOCR
提交
3e2cc77b
CnOCR
项目概览
Cloud IDE
/
CnOCR
落后 Fork 源项目 19 个版本
Fork自
breezedeus / CnOCR
通知
47
Star
3
Fork
30
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
CnOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
3e2cc77b
编写于
3月 27, 2020
作者:
B
breezedeus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add lr and optimizer params
上级
65f55c97
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 addition
and
24 deletion
+24
-24
scripts/cnocr_train.py
scripts/cnocr_train.py
+24
-24
未找到文件。
scripts/cnocr_train.py
浏览文件 @
3e2cc77b
...
...
@@ -80,7 +80,13 @@ def parse_args():
default
=
2
,
)
parser
.
add_argument
(
"--gpu"
,
help
=
"Number of GPUs for training [Default 0]"
,
type
=
int
"--gpu"
,
help
=
"Number of GPUs for training [Default 0]"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--optimizer"
,
help
=
"optimizer for training [Default: Adam]"
,
type
=
str
,
default
=
'Adam'
,
)
parser
.
add_argument
(
'--epoch'
,
type
=
int
,
default
=
20
,
help
=
'train epochs [Default: 20]'
...
...
@@ -90,6 +96,12 @@ def parse_args():
type
=
int
,
help
=
'load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]'
,
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
None
,
help
=
'learning rate [Default: None, means lr from hp will be used]'
,
)
parser
.
add_argument
(
"--prefix"
,
help
=
"Checkpoint prefix [Default '{}']"
.
format
(
default_model_prefix
),
...
...
@@ -98,36 +110,16 @@ def parse_args():
parser
.
add_argument
(
"--loss"
,
help
=
"'ctc' or 'warpctc' loss [Default 'ctc']"
,
default
=
'ctc'
)
parser
.
add_argument
(
"--num_proc"
,
help
=
"Number CAPTCHA generating processes [Default 4]"
,
type
=
int
,
default
=
4
,
)
parser
.
add_argument
(
"--font_path"
,
help
=
"Path to ttf font file or directory containing ttf files"
)
return
parser
.
parse_args
()
def
get_fonts
(
path
):
fonts
=
list
()
if
os
.
path
.
isdir
(
path
):
for
filename
in
os
.
listdir
(
path
):
if
filename
.
endswith
(
'.ttf'
)
or
filename
.
endswith
(
'.ttc'
):
fonts
.
append
(
os
.
path
.
join
(
path
,
filename
))
else
:
fonts
.
append
(
path
)
return
fonts
def
run_cn_ocr
(
args
):
def
train_cnocr
(
args
):
head
=
'%(asctime)-15s %(message)s'
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
format
=
head
)
args
.
prefix
=
'{}-{}'
.
format
(
args
.
prefix
,
args
.
model_name
)
hp
=
CnHyperparams
()
hp
.
_num_epoch
=
args
.
epoch
hp
=
_update_hp
(
hp
,
args
)
network
,
hp
=
gen_network
(
args
.
model_name
,
hp
)
metrics
=
CtcMetrics
(
hp
.
seq_length
)
...
...
@@ -147,6 +139,14 @@ def run_cn_ocr(args):
)
def
_update_hp
(
hp
,
args
):
hp
.
_num_epoch
=
args
.
epoch
hp
.
optimizer
=
args
.
optimizer
if
args
.
lr
is
not
None
:
hp
.
_learning_rate
=
args
.
lr
return
hp
def
_gen_iters
(
hp
,
train_fp_prefix
,
val_fp_prefix
,
use_train_image_aug
):
height
,
width
=
hp
.
img_height
,
hp
.
img_width
augs
=
None
...
...
@@ -192,4 +192,4 @@ def _gen_iters(hp, train_fp_prefix, val_fp_prefix, use_train_image_aug):
if
__name__
==
'__main__'
:
args
=
parse_args
()
run_cn_
ocr
(
args
)
train_cn
ocr
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录