Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
1b2ca6e6
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1b2ca6e6
编写于
8月 30, 2021
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
上级
c9e1077d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
0 addition
and
8 deletion
+0
-8
ppocr/modeling/transforms/tps.py
ppocr/modeling/transforms/tps.py
+0
-2
tools/program.py
tools/program.py
+0
-4
tools/train.py
tools/train.py
+0
-2
未找到文件。
ppocr/modeling/transforms/tps.py
浏览文件 @
1b2ca6e6
...
@@ -326,6 +326,4 @@ class STN_ON(nn.Layer):
...
@@ -326,6 +326,4 @@ class STN_ON(nn.Layer):
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
x
,
_
=
self
.
tps
(
image
,
ctrl_points
)
x
,
_
=
self
.
tps
(
image
,
ctrl_points
)
#print("x:", np.sum(x.numpy()))
# print(x.shape)
return
x
return
x
tools/program.py
浏览文件 @
1b2ca6e6
...
@@ -215,9 +215,6 @@ def train(config,
...
@@ -215,9 +215,6 @@ def train(config,
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
state_dict
=
model
.
state_dict
()
# for key in state_dict:
# print(key)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
avg_loss
=
loss
[
'loss'
]
avg_loss
.
backward
()
avg_loss
.
backward
()
...
@@ -414,7 +411,6 @@ def preprocess(is_train=False):
...
@@ -414,7 +411,6 @@ def preprocess(is_train=False):
yaml
.
dump
(
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
print
(
"log has save in {}/train.log"
.
format
(
save_model_dir
))
else
:
else
:
log_file
=
None
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
...
...
tools/train.py
浏览文件 @
1b2ca6e6
...
@@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer):
...
@@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
character
=
getattr
(
post_process_class
,
'character'
)
print
(
"getattr character:"
,
character
)
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
for
key
in
config
[
'Architecture'
][
"Models"
]:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录