Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9af384f1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9af384f1
编写于
3月 12, 2021
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
try to fix imperative orc unitest error; test=develop
上级
95cceb2d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
17 addition
and
17 deletion
+17
-17
python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py
...id/tests/unittests/test_imperative_ocr_attention_model.py
+17
-17
未找到文件。
python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py
浏览文件 @
9af384f1
...
...
@@ -29,19 +29,19 @@ class Config(object):
config for training
'''
# encoder rnn hidden_size
encoder_size
=
16
encoder_size
=
8
# decoder size for decoder stage
decoder_size
=
16
decoder_size
=
8
# size for word embedding
word_vector_dim
=
16
word_vector_dim
=
8
# max length for label padding
max_length
=
5
max_length
=
3
# optimizer setting
LR
=
1.0
learning_rate_decay
=
None
# batch size to train
batch_size
=
8
batch_size
=
2
# class number to classify
num_classes
=
64
...
...
@@ -55,7 +55,7 @@ class Config(object):
TRAIN_LIST_FILE_NAME
=
"train.list"
# data shape for input image
DATA_SHAPE
=
[
1
,
48
,
38
4
]
DATA_SHAPE
=
[
1
,
16
,
6
4
]
class
ConvBNPool
(
fluid
.
dygraph
.
Layer
):
...
...
@@ -124,13 +124,13 @@ class OCRConv(fluid.dygraph.Layer):
def
__init__
(
self
,
is_test
=
False
,
use_cudnn
=
True
):
super
(
OCRConv
,
self
).
__init__
()
self
.
conv_bn_pool_1
=
ConvBNPool
(
2
,
[
16
,
16
],
[
1
,
16
],
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
2
,
[
8
,
8
],
[
1
,
8
],
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
self
.
conv_bn_pool_2
=
ConvBNPool
(
2
,
[
32
,
32
],
[
16
,
32
],
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
2
,
[
8
,
8
],
[
8
,
8
],
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
self
.
conv_bn_pool_3
=
ConvBNPool
(
2
,
[
64
,
64
],
[
32
,
64
],
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
2
,
[
8
,
8
],
[
8
,
8
],
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
self
.
conv_bn_pool_4
=
ConvBNPool
(
2
,
[
1
28
,
128
],
[
64
,
128
],
2
,
[
1
6
,
16
],
[
8
,
16
],
is_test
=
is_test
,
pool
=
False
,
use_cudnn
=
use_cudnn
)
...
...
@@ -212,9 +212,9 @@ class EncoderNet(fluid.dygraph.Layer):
self
.
ocr_convs
=
OCRConv
(
is_test
=
is_test
,
use_cudnn
=
use_cudnn
)
self
.
fc_1_layer
=
Linear
(
768
,
rnn_hidden_size
*
3
,
param_attr
=
para_attr
,
bias_attr
=
False
)
32
,
rnn_hidden_size
*
3
,
param_attr
=
para_attr
,
bias_attr
=
False
)
self
.
fc_2_layer
=
Linear
(
768
,
rnn_hidden_size
*
3
,
param_attr
=
para_attr
,
bias_attr
=
False
)
32
,
rnn_hidden_size
*
3
,
param_attr
=
para_attr
,
bias_attr
=
False
)
self
.
gru_forward_layer
=
DynamicGRU
(
size
=
rnn_hidden_size
,
h_0
=
h_0
,
...
...
@@ -241,10 +241,9 @@ class EncoderNet(fluid.dygraph.Layer):
transpose_conv_features
=
fluid
.
layers
.
transpose
(
conv_features
,
perm
=
[
0
,
3
,
1
,
2
])
sliced_feature
=
fluid
.
layers
.
reshape
(
transpose_conv_features
,
[
-
1
,
4
8
,
transpose_conv_features
.
shape
[
2
]
*
-
1
,
8
,
transpose_conv_features
.
shape
[
2
]
*
transpose_conv_features
.
shape
[
3
]
],
inplace
=
False
)
...
...
@@ -376,9 +375,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
seed
=
90
epoch_num
=
1
if
core
.
is_compiled_with_cuda
():
batch_num
=
6
batch_num
=
3
else
:
batch_num
=
4
batch_num
=
2
np
.
random
.
seed
=
seed
image_np
=
np
.
random
.
randn
(
Config
.
batch_size
,
Config
.
DATA_SHAPE
[
0
],
Config
.
DATA_SHAPE
[
1
],
...
...
@@ -536,8 +535,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
self
.
assertTrue
(
np
.
array_equal
(
value
,
dy_param_init_value
[
key
]))
for
key
,
value
in
six
.
iteritems
(
static_param_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
]))
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
]
,
rtol
=
1e-05
))
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录