Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0a0c55d2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a0c55d2
编写于
11月 09, 2016
作者:
W
wangyanfei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
more friendly test options
上级
c6a0298e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
53 addition
and
32 deletion
+53
-32
paddle/trainer/Tester.cpp
paddle/trainer/Tester.cpp
+13
-5
paddle/trainer/Tester.h
paddle/trainer/Tester.h
+1
-1
paddle/trainer/TesterConfig.h
paddle/trainer/TesterConfig.h
+8
-3
paddle/trainer/Trainer.cpp
paddle/trainer/Trainer.cpp
+31
-23
未找到文件。
paddle/trainer/Tester.cpp
浏览文件 @
0a0c55d2
...
@@ -90,13 +90,20 @@ void Tester::testOneDataBatch(
...
@@ -90,13 +90,20 @@ void Tester::testOneDataBatch(
testContext_
.
numSamples
+=
dataBatch
.
getSize
();
testContext_
.
numSamples
+=
dataBatch
.
getSize
();
}
}
void
Tester
::
testOnePeriod
()
{
void
Tester
::
testOnePeriod
(
bool
finishPass
)
{
DataBatch
dataBatch
;
DataBatch
dataBatch
;
int64_t
batchSize
=
config_
->
getOptConfig
().
batch_size
();
int64_t
batchSize
=
config_
->
getOptConfig
().
batch_size
();
bool
testAllData
=
bool
testAllData
=
intconfig_
->
testPeriod
==
0
||
intconfig_
->
testAllDataInOnePeriod
;
(
!
finishPass
&&
!
intconfig_
->
testBatchesWhileTraining
)
||
int
batches
=
(
finishPass
&&
!
intconfig_
->
testBatchesWhileEnd
);
testAllData
?
std
::
numeric_limits
<
int
>::
max
()
:
intconfig_
->
testPeriod
;
int
batches
;
if
(
testAllData
)
{
batches
=
std
::
numeric_limits
<
int
>::
max
();
}
else
{
batches
=
finishPass
?
intconfig_
->
testBatchesWhileEnd
:
intconfig_
->
testBatchesWhileTraining
;
}
std
::
vector
<
Argument
>
outArgs
;
std
::
vector
<
Argument
>
outArgs
;
...
@@ -108,7 +115,8 @@ void Tester::testOnePeriod() {
...
@@ -108,7 +115,8 @@ void Tester::testOnePeriod() {
if
(
intconfig_
->
prevBatchState
)
{
if
(
intconfig_
->
prevBatchState
)
{
gradientMachine_
->
resetState
();
gradientMachine_
->
resetState
();
}
}
if
(
testAllData
)
{
if
((
!
finishPass
&&
!
intconfig_
->
testBatchesWhileTraining
)
||
(
finishPass
&&
!
intconfig_
->
testBatchesWhileEnd
))
{
break
;
break
;
}
else
{
}
else
{
num
=
testDataProvider_
->
getNextBatch
(
batchSize
,
&
dataBatch
);
num
=
testDataProvider_
->
getNextBatch
(
batchSize
,
&
dataBatch
);
...
...
paddle/trainer/Tester.h
浏览文件 @
0a0c55d2
...
@@ -67,7 +67,7 @@ public:
...
@@ -67,7 +67,7 @@ public:
* It is convenience to test small set of data when test data set is large and
* It is convenience to test small set of data when test data set is large and
* is training at same time.
* is training at same time.
*/
*/
void
testOnePeriod
();
void
testOnePeriod
(
bool
finishPass
=
true
);
void
startTestPeriod
();
void
startTestPeriod
();
void
finishTestPeriod
();
void
finishTestPeriod
();
void
testOneDataBatch
(
const
DataBatch
&
dataBatch
,
void
testOneDataBatch
(
const
DataBatch
&
dataBatch
,
...
...
paddle/trainer/TesterConfig.h
浏览文件 @
0a0c55d2
...
@@ -38,12 +38,17 @@ struct TesterConfig {
...
@@ -38,12 +38,17 @@ struct TesterConfig {
/**
/**
* indicate test period
* indicate test period
*/
*/
int
testPeriod
;
int
testPeriod
WhileTraining
;
/**
/**
* indicate
whether testing data in one period
* indicate
how many batches are used for testing under training
*/
*/
bool
testAllDataInOnePeriod
;
bool
testBatchesWhileTraining
;
/**
* indicate how many batches are used for testing at pass end
*/
bool
testBatchesWhileEnd
;
/**
/**
* indicate whether to save previous batch state
* indicate whether to save previous batch state
...
...
paddle/trainer/Trainer.cpp
浏览文件 @
0a0c55d2
...
@@ -40,31 +40,28 @@ limitations under the License. */
...
@@ -40,31 +40,28 @@ limitations under the License. */
#include "TrainerConfigHelper.h"
#include "TrainerConfigHelper.h"
P_DEFINE_string
(
config
,
""
,
"Trainer config file"
);
P_DEFINE_string
(
config
,
""
,
"Trainer config file"
);
P_DEFINE_int32
(
test_period
,
0
,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test on all test data"
);
P_DEFINE_int32
(
test_batches_while_training
,
0
,
P_DEFINE_int32
(
test_period
,
0
,
"This option was deprecated, use test_period_while_training "
" instead. "
);
P_DEFINE_int32
(
test_period_while_training
,
0
,
"Run test every so many train batches."
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If not 0, test log_period batches."
" If 0, test nothing."
);
P_DEFINE_int32
(
test_batches_while_training
,
1000
,
"test test_batches_while_training batches if test_period != 0."
" If 0, test on all test data"
);
" If 0, test on all test data"
);
P_DEFINE_int32
(
test_batches_while_end
,
0
,
P_DEFINE_int32
(
test_batches_while_end
,
0
,
"Run test every so many train batches."
"test test_batches_while_end batches at pass end."
" 0 for testing after each pass."
" Always run test at pass end."
" If not 0, test log_period batches."
" If not 0, test test_batches_while_end batches."
" If 0, test on all test data"
);
" If 0, test on all test data."
);
P_DEFINE_bool
(
test_all_data_in_one_period
,
false
,
"This option was deprecated, use test_batches_while_training "
"and test_batches_while_end instead"
);
P_DEFINE_bool
(
local
,
true
,
"Train in local mode or not"
);
P_DEFINE_bool
(
local
,
true
,
"Train in local mode or not"
);
P_DEFINE_bool
(
test_all_data_in_one_period
,
false
,
"true will test all data in one test peroid."
"Otherwise test (batch_size * log_peroid) data in one test period."
);
P_DEFINE_int32
(
average_test_period
,
0
,
P_DEFINE_int32
(
average_test_period
,
0
,
"Do test on average parameter every so"
"Do test on average parameter every so"
" many batches. MUST be devided by FLAGS_log_period."
" many batches. MUST be devided by FLAGS_log_period."
...
@@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
...
@@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
FOR_TIMING
(
globalStat
.
reset
());
FOR_TIMING
(
globalStat
.
reset
());
}
}
if
(
testDataProvider_
&&
FLAGS_test_period
>
0
&&
if
(
testDataProvider_
&&
FLAGS_test_period
_while_training
>
0
&&
trainPassContext_
.
batchId
%
FLAGS_test_period
==
0
)
{
trainPassContext_
.
batchId
%
FLAGS_test_period
_while_training
==
0
)
{
tester_
->
testOnePeriod
();
tester_
->
testOnePeriod
(
false
);
}
}
if
(
FLAGS_saving_period_by_batches
>
0
&&
if
(
FLAGS_saving_period_by_batches
>
0
&&
...
@@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
...
@@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
0
==
FLAGS_trainer_id
)
{
0
==
FLAGS_trainer_id
)
{
trainerInternal_
.
getParameterUpdater
()
->
catchUpWith
();
trainerInternal_
.
getParameterUpdater
()
->
catchUpWith
();
if
(
testDataProvider_
)
{
if
(
testDataProvider_
)
{
tester_
->
testOnePeriod
();
tester_
->
testOnePeriod
(
false
);
}
}
paramUtil_
->
saveParametersOnePass
(
paramUtil_
->
saveParametersOnePass
(
trainPassContext_
.
passId
,
trainPassContext_
.
passInnerId
);
trainPassContext_
.
passId
,
trainPassContext_
.
passInnerId
);
...
@@ -636,8 +633,19 @@ void Trainer::test() {
...
@@ -636,8 +633,19 @@ void Trainer::test() {
std
::
unique_ptr
<
TesterConfig
>
Trainer
::
createTesterConfig
()
{
std
::
unique_ptr
<
TesterConfig
>
Trainer
::
createTesterConfig
()
{
TesterConfig
*
conf
=
new
TesterConfig
;
TesterConfig
*
conf
=
new
TesterConfig
;
conf
->
testPeriod
=
FLAGS_test_period
;
if
(
FLAGS_test_period
)
{
conf
->
testAllDataInOnePeriod
=
FLAGS_test_all_data_in_one_period
;
LOG
(
WARNING
)
<<
"--test_period was deprecated, use --test_period_while_training"
<<
"--test_batches_while_training --test_batches_while_end instead."
;
}
if
(
FLAGS_test_all_data_in_one_period
)
{
LOG
(
WARNING
)
<<
"--test_all_data_in_one_period was deprecated, use"
<<
" --test_batches_while_training and --test_batches_while_end instead"
;
}
conf
->
testPeriodWhileTraining
=
FLAGS_test_period_while_training
;
conf
->
testBatchesWhileTraining
=
FLAGS_test_batches_while_training
;
conf
->
testBatchesWhileEnd
=
FLAGS_test_batches_while_end
;
conf
->
prevBatchState
=
FLAGS_prev_batch_state
;
conf
->
prevBatchState
=
FLAGS_prev_batch_state
;
conf
->
logPeriod
=
FLAGS_log_period
;
conf
->
logPeriod
=
FLAGS_log_period
;
conf
->
loadsaveParametersInPserver
=
FLAGS_loadsave_parameters_in_pserver
;
conf
->
loadsaveParametersInPserver
=
FLAGS_loadsave_parameters_in_pserver
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录