Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
df28da76
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
df28da76
编写于
9月 16, 2016
作者:
Q
qingqing01
提交者:
emailweixu
9月 15, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
try to fix bug for CTCErrorEvaluator.cpp when batch_size > 1 (#82)
* try to fix bug for ctc_error_evaluator
上级
703cce35
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
26 addition
and
1 deletion
+26
-1
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
+4
-1
paddle/gserver/tests/test_Evaluator.cpp
paddle/gserver/tests/test_Evaluator.cpp
+22
-0
未找到文件。
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
浏览文件 @
df28da76
...
@@ -207,7 +207,7 @@ public:
...
@@ -207,7 +207,7 @@ public:
real
err
=
0
;
real
err
=
0
;
err
=
editDistance
(
err
=
editDistance
(
output
.
value
->
getData
()
+
output
.
value
->
getWidth
()
*
outputStarts
[
i
],
output
.
value
->
getData
()
+
output
.
value
->
getWidth
()
*
outputStarts
[
i
],
output
.
value
->
getHeight
()
,
output
.
value
->
getWidth
(),
output
Starts
[
i
+
1
]
-
outputStarts
[
i
]
,
output
.
value
->
getWidth
(),
label
.
ids
->
getData
()
+
labelStarts
[
i
],
label
.
ids
->
getData
()
+
labelStarts
[
i
],
labelStarts
[
i
+
1
]
-
labelStarts
[
i
]);
labelStarts
[
i
+
1
]
-
labelStarts
[
i
]);
...
@@ -224,6 +224,9 @@ public:
...
@@ -224,6 +224,9 @@ public:
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
arguments
.
push_back
(
nn
.
getLayer
(
name
)
->
getOutput
());
arguments
.
push_back
(
nn
.
getLayer
(
name
)
->
getOutput
());
}
}
}
virtual
void
updateSamplesNum
(
const
std
::
vector
<
Argument
>&
arguments
)
{
numSequences_
+=
arguments
[
1
].
getNumSequences
();
numSequences_
+=
arguments
[
1
].
getNumSequences
();
}
}
...
...
paddle/gserver/tests/test_Evaluator.cpp
浏览文件 @
df28da76
...
@@ -87,18 +87,31 @@ void testEvaluator(TestConfig testConf, string testEvaluatorName,
...
@@ -87,18 +87,31 @@ void testEvaluator(TestConfig testConf, string testEvaluatorName,
return
;
return
;
}
}
ICpuGpuVectorPtr
sequenceStartPositions
;
if
(
testConf
.
inputDefs
[
i
].
inputType
==
INPUT_SEQUENCE_DATA
||
testConf
.
inputDefs
[
i
].
inputType
==
INPUT_SEQUENCE_LABEL
)
{
if
(
!
sequenceStartPositions
)
{
generateSequenceStartPositions
(
batchSize
,
sequenceStartPositions
);
}
data
.
sequenceStartPositions
=
sequenceStartPositions
;
}
arguments
.
push_back
(
data
);
arguments
.
push_back
(
data
);
}
}
Evaluator
*
testEvaluator
=
Evaluator
::
create
(
testConf
.
evaluatorConfig
);
Evaluator
*
testEvaluator
=
Evaluator
::
create
(
testConf
.
evaluatorConfig
);
double
totalScore
=
0.0
;
double
totalScore
=
0.0
;
testEvaluator
->
start
();
totalScore
+=
testEvaluator
->
evalImp
(
arguments
);
totalScore
+=
testEvaluator
->
evalImp
(
arguments
);
testEvaluator
->
updateSamplesNum
(
arguments
);
testEvaluator
->
updateSamplesNum
(
arguments
);
testEvaluator
->
finish
();
LOG
(
INFO
)
<<
*
testEvaluator
;
LOG
(
INFO
)
<<
*
testEvaluator
;
double
totalScore2
=
0.0
;
double
totalScore2
=
0.0
;
if
(
testConf
.
testAccumulate
)
{
if
(
testConf
.
testAccumulate
)
{
testEvaluator
->
start
();
totalScore2
+=
testEvaluator
->
evalImp
(
arguments
);
totalScore2
+=
testEvaluator
->
evalImp
(
arguments
);
testEvaluator
->
finish
();
EXPECT_LE
(
fabs
(
totalScore
-
totalScore2
),
1.0e-5
);
EXPECT_LE
(
fabs
(
totalScore
-
totalScore2
),
1.0e-5
);
}
}
}
}
...
@@ -202,6 +215,15 @@ TEST(Evaluator, precision_recall) {
...
@@ -202,6 +215,15 @@ TEST(Evaluator, precision_recall) {
false
);
false
);
}
}
TEST
(
Evaluator
,
ctc_error_evaluator
)
{
TestConfig
config
;
config
.
evaluatorConfig
.
set_type
(
"ctc_edit_distance"
);
config
.
inputDefs
.
push_back
({
INPUT_SEQUENCE_DATA
,
"output"
,
32
});
config
.
inputDefs
.
push_back
({
INPUT_SEQUENCE_LABEL
,
"label"
,
1
});
testEvaluatorAll
(
config
,
"ctc_error_evaluator"
,
100
);
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
initMain
(
argc
,
argv
);
initMain
(
argc
,
argv
);
FLAGS_thread_local_rand_use_global_seed
=
true
;
FLAGS_thread_local_rand_use_global_seed
=
true
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录