Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0a0c55d2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a0c55d2
编写于
11月 09, 2016
作者:
W
wangyanfei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
more friendly test options
上级
c6a0298e
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
53 addition
and
32 deletion
+53
-32
paddle/trainer/Tester.cpp
paddle/trainer/Tester.cpp
+13
-5
paddle/trainer/Tester.h
paddle/trainer/Tester.h
+1
-1
paddle/trainer/TesterConfig.h
paddle/trainer/TesterConfig.h
+8
-3
paddle/trainer/Trainer.cpp
paddle/trainer/Trainer.cpp
+31
-23
未找到文件。
paddle/trainer/Tester.cpp
浏览文件 @
0a0c55d2
...
@@ -90,13 +90,20 @@ void Tester::testOneDataBatch(
...
@@ -90,13 +90,20 @@ void Tester::testOneDataBatch(
testContext_
.
numSamples
+=
dataBatch
.
getSize
();
testContext_
.
numSamples
+=
dataBatch
.
getSize
();
}
}
void
Tester
::
testOnePeriod
()
{
void
Tester
::
testOnePeriod
(
bool
finishPass
)
{
DataBatch
dataBatch
;
DataBatch
dataBatch
;
int64_t
batchSize
=
config_
->
getOptConfig
().
batch_size
();
int64_t
batchSize
=
config_
->
getOptConfig
().
batch_size
();
bool
testAllData
=
bool
testAllData
=
intconfig_
->
testPeriod
==
0
||
intconfig_
->
testAllDataInOnePeriod
;
(
!
finishPass
&&
!
intconfig_
->
testBatchesWhileTraining
)
||
int
batches
=
(
finishPass
&&
!
intconfig_
->
testBatchesWhileEnd
);
testAllData
?
std
::
numeric_limits
<
int
>::
max
()
:
intconfig_
->
testPeriod
;
int
batches
;
if
(
testAllData
)
{
batches
=
std
::
numeric_limits
<
int
>::
max
();
}
else
{
batches
=
finishPass
?
intconfig_
->
testBatchesWhileEnd
:
intconfig_
->
testBatchesWhileTraining
;
}
std
::
vector
<
Argument
>
outArgs
;
std
::
vector
<
Argument
>
outArgs
;
...
@@ -108,7 +115,8 @@ void Tester::testOnePeriod() {
...
@@ -108,7 +115,8 @@ void Tester::testOnePeriod() {
if
(
intconfig_
->
prevBatchState
)
{
if
(
intconfig_
->
prevBatchState
)
{
gradientMachine_
->
resetState
();
gradientMachine_
->
resetState
();
}
}
if
(
testAllData
)
{
if
((
!
finishPass
&&
!
intconfig_
->
testBatchesWhileTraining
)
||
(
finishPass
&&
!
intconfig_
->
testBatchesWhileEnd
))
{
break
;
break
;
}
else
{
}
else
{
num
=
testDataProvider_
->
getNextBatch
(
batchSize
,
&
dataBatch
);
num
=
testDataProvider_
->
getNextBatch
(
batchSize
,
&
dataBatch
);
...
...
paddle/trainer/Tester.h
浏览文件 @
0a0c55d2
...
@@ -67,7 +67,7 @@ public:
...
@@ -67,7 +67,7 @@ public:
* It is convenience to test small set of data when test data set is large and
* It is convenience to test small set of data when test data set is large and
* is training at same time.
* is training at same time.
*/
*/
void
testOnePeriod
();
void
testOnePeriod
(
bool
finishPass
=
true
);
void
startTestPeriod
();
void
startTestPeriod
();
void
finishTestPeriod
();
void
finishTestPeriod
();
void
testOneDataBatch
(
const
DataBatch
&
dataBatch
,
void
testOneDataBatch
(
const
DataBatch
&
dataBatch
,
...
...
paddle/trainer/TesterConfig.h
浏览文件 @
0a0c55d2
...
@@ -38,12 +38,17 @@ struct TesterConfig {
...
@@ -38,12 +38,17 @@ struct TesterConfig {
/**
/**
* indicate test period
* indicate test period
*/
*/
int
testPeriod
;
int
testPeriod
WhileTraining
;
/**
/**
* indicate
whether testing data in one period
* indicate
how many batches are used for testing under training
*/
*/
bool
testAllDataInOnePeriod
;
bool
testBatchesWhileTraining
;
/**
* indicate how many batches are used for testing at pass end
*/
bool
testBatchesWhileEnd
;
/**
/**
* indicate whether to save previous batch state
* indicate whether to save previous batch state
...
...
paddle/trainer/Trainer.cpp
浏览文件 @
0a0c55d2
...
@@ -40,31 +40,28 @@ limitations under the License. */
...
@@ -40,31 +40,28 @@ limitations under the License. */
#include "TrainerConfigHelper.h"
#include "TrainerConfigHelper.h"
P_DEFINE_string
(
config
,
""
,
"Trainer config file"
);
P_DEFINE_string
(
config
,
""
,
"Trainer config file"
);
P_DEFINE_int32
(
test_period
,
0
,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test on all test data"
);
P_DEFINE_int32
(
test_batches_while_training
,
0
,
P_DEFINE_int32
(
test_period
,
0
,
"This option was deprecated, use test_period_while_training "
" instead. "
);
P_DEFINE_int32
(
test_period_while_training
,
0
,
"Run test every so many train batches."
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If not 0, test log_period batches."
" If 0, test nothing."
);
P_DEFINE_int32
(
test_batches_while_training
,
1000
,
"test test_batches_while_training batches if test_period != 0."
" If 0, test on all test data"
);
" If 0, test on all test data"
);
P_DEFINE_int32
(
test_batches_while_end
,
0
,
P_DEFINE_int32
(
test_batches_while_end
,
0
,
"Run test every so many train batches."
"test test_batches_while_end batches at pass end."
" 0 for testing after each pass."
" Always run test at pass end."
" If not 0, test log_period batches."
" If not 0, test test_batches_while_end batches."
" If 0, test on all test data"
);
" If 0, test on all test data."
);
P_DEFINE_bool
(
test_all_data_in_one_period
,
false
,
"This option was deprecated, use test_batches_while_training "
"and test_batches_while_end instead"
);
P_DEFINE_bool
(
local
,
true
,
"Train in local mode or not"
);
P_DEFINE_bool
(
local
,
true
,
"Train in local mode or not"
);
P_DEFINE_bool
(
test_all_data_in_one_period
,
false
,
"true will test all data in one test peroid."
"Otherwise test (batch_size * log_peroid) data in one test period."
);
P_DEFINE_int32
(
average_test_period
,
0
,
P_DEFINE_int32
(
average_test_period
,
0
,
"Do test on average parameter every so"
"Do test on average parameter every so"
" many batches. MUST be devided by FLAGS_log_period."
" many batches. MUST be devided by FLAGS_log_period."
...
@@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
...
@@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
FOR_TIMING
(
globalStat
.
reset
());
FOR_TIMING
(
globalStat
.
reset
());
}
}
if
(
testDataProvider_
&&
FLAGS_test_period
>
0
&&
if
(
testDataProvider_
&&
FLAGS_test_period
_while_training
>
0
&&
trainPassContext_
.
batchId
%
FLAGS_test_period
==
0
)
{
trainPassContext_
.
batchId
%
FLAGS_test_period
_while_training
==
0
)
{
tester_
->
testOnePeriod
();
tester_
->
testOnePeriod
(
false
);
}
}
if
(
FLAGS_saving_period_by_batches
>
0
&&
if
(
FLAGS_saving_period_by_batches
>
0
&&
...
@@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
...
@@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
0
==
FLAGS_trainer_id
)
{
0
==
FLAGS_trainer_id
)
{
trainerInternal_
.
getParameterUpdater
()
->
catchUpWith
();
trainerInternal_
.
getParameterUpdater
()
->
catchUpWith
();
if
(
testDataProvider_
)
{
if
(
testDataProvider_
)
{
tester_
->
testOnePeriod
();
tester_
->
testOnePeriod
(
false
);
}
}
paramUtil_
->
saveParametersOnePass
(
paramUtil_
->
saveParametersOnePass
(
trainPassContext_
.
passId
,
trainPassContext_
.
passInnerId
);
trainPassContext_
.
passId
,
trainPassContext_
.
passInnerId
);
...
@@ -636,8 +633,19 @@ void Trainer::test() {
...
@@ -636,8 +633,19 @@ void Trainer::test() {
std
::
unique_ptr
<
TesterConfig
>
Trainer
::
createTesterConfig
()
{
std
::
unique_ptr
<
TesterConfig
>
Trainer
::
createTesterConfig
()
{
TesterConfig
*
conf
=
new
TesterConfig
;
TesterConfig
*
conf
=
new
TesterConfig
;
conf
->
testPeriod
=
FLAGS_test_period
;
if
(
FLAGS_test_period
)
{
conf
->
testAllDataInOnePeriod
=
FLAGS_test_all_data_in_one_period
;
LOG
(
WARNING
)
<<
"--test_period was deprecated, use --test_period_while_training"
<<
"--test_batches_while_training --test_batches_while_end instead."
;
}
if
(
FLAGS_test_all_data_in_one_period
)
{
LOG
(
WARNING
)
<<
"--test_all_data_in_one_period was deprecated, use"
<<
" --test_batches_while_training and --test_batches_while_end instead"
;
}
conf
->
testPeriodWhileTraining
=
FLAGS_test_period_while_training
;
conf
->
testBatchesWhileTraining
=
FLAGS_test_batches_while_training
;
conf
->
testBatchesWhileEnd
=
FLAGS_test_batches_while_end
;
conf
->
prevBatchState
=
FLAGS_prev_batch_state
;
conf
->
prevBatchState
=
FLAGS_prev_batch_state
;
conf
->
logPeriod
=
FLAGS_log_period
;
conf
->
logPeriod
=
FLAGS_log_period
;
conf
->
loadsaveParametersInPserver
=
FLAGS_loadsave_parameters_in_pserver
;
conf
->
loadsaveParametersInPserver
=
FLAGS_loadsave_parameters_in_pserver
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录