Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
737f2bf3
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
737f2bf3
编写于
12月 05, 2016
作者:
Y
Yu Yang
提交者:
GitHub
12月 05, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #731 from reyoung/feature/fix_testing_style
Simplify the testOnePeriod method.
上级
82774dbb
ba68704e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
33 addition
and
47 deletion
+33
-47
paddle/api/test/run_tests.sh
paddle/api/test/run_tests.sh
+1
-5
paddle/trainer/Tester.cpp
paddle/trainer/Tester.cpp
+12
-18
paddle/trainer/Trainer.cpp
paddle/trainer/Trainer.cpp
+20
-24
未找到文件。
paddle/api/test/run_tests.sh
浏览文件 @
737f2bf3
...
...
@@ -20,11 +20,7 @@ popd > /dev/null
cd
$SCRIPTPATH
if
[
!
-f
../../dist/
*
.whl
]
;
then
# Swig not compiled.
exit
0
fi
rm
.test_env
-rf
rm
-rf
.test_env
virtualenv .test_env
source
.test_env/bin/activate
...
...
paddle/trainer/Tester.cpp
浏览文件 @
737f2bf3
...
...
@@ -17,22 +17,22 @@ limitations under the License. */
#include <fenv.h>
#include <stdio.h>
#include <iostream>
#include <iomanip>
#include <
s
stream>
#include <
io
stream>
#include <limits>
#include <sstream>
#include <google/protobuf/text_format.h>
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/GlobalConstants.h"
#include "TesterConfig.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "TesterConfig.h"
namespace
paddle
{
...
...
@@ -66,6 +66,9 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config,
}
void
Tester
::
startTestPeriod
()
{
if
(
testDataProvider_
)
{
testDataProvider_
->
reset
();
}
testEvaluator_
->
start
();
testContext_
.
cost
=
0
;
testContext_
.
numSamples
=
0
;
...
...
@@ -87,27 +90,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
void
Tester
::
testOnePeriod
()
{
DataBatch
dataBatch
;
int64_t
batchSize
=
config_
->
getOptConfig
().
batch_size
();
int
batches
=
std
::
numeric_limits
<
int
>::
max
();
std
::
vector
<
Argument
>
outArgs
;
startTestPeriod
();
for
(
int
i
=
0
;
i
<
batches
;
++
i
)
{
int
num
=
testDataProvider_
->
getNextBatch
(
batchSize
,
&
dataBatch
);
if
(
num
==
0
)
{
testDataProvider_
->
reset
();
if
(
intconfig_
->
prevBatchState
)
{
gradientMachine_
->
resetState
();
}
break
;
}
while
(
testDataProvider_
->
getNextBatch
(
batchSize
,
&
dataBatch
)
!=
0
)
{
testOneDataBatch
(
dataBatch
,
&
outArgs
);
}
finishTestPeriod
();
}
void
Tester
::
finishTestPeriod
()
{
if
(
intconfig_
->
prevBatchState
)
{
gradientMachine_
->
resetState
();
}
testEvaluator_
->
finish
();
CHECK_GT
(
testContext_
.
numSamples
,
0
)
<<
"There is no samples in your test batch. Possibly "
...
...
paddle/trainer/Trainer.cpp
浏览文件 @
737f2bf3
...
...
@@ -17,36 +17,38 @@ limitations under the License. */
#include <fenv.h>
#include <stdio.h>
#include <iostream>
#include <iomanip>
#include <
s
stream>
#include <
io
stream>
#include <limits>
#include <sstream>
#include <google/protobuf/text_format.h>
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "RemoteParameterUpdater.h"
#include "TesterConfig.h"
#include "ThreadParameterUpdater.h"
#include "RemoteParameterUpdater.h"
#include "TrainerConfigHelper.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
P_DEFINE_string
(
config
,
""
,
"Trainer config file"
);
P_DEFINE_int32
(
test_period
,
0
,
P_DEFINE_int32
(
test_period
,
0
,
"if equal 0, do test on all test data at the end of "
"each pass. While if equal non-zero, do test on all test "
"data every test_period batches"
);
P_DEFINE_bool
(
test_all_data_in_one_period
,
false
,
"This option was deprecated, since we will always do "
"test on all test set "
);
P_DEFINE_bool
(
test_all_data_in_one_period
,
false
,
"This option was deprecated, since we will always do "
"test on all test set "
);
P_DEFINE_bool
(
local
,
true
,
"Train in local mode or not"
);
...
...
@@ -392,10 +394,6 @@ void Trainer::startTrain() {
dataProvider_
->
reset
();
}
if
(
this
->
testDataProvider_
)
{
this
->
testDataProvider_
->
reset
();
}
trainerInternal_
.
getGradientMachine
()
->
start
(
*
config_
,
dataProvider_
);
}
...
...
@@ -630,16 +628,14 @@ void Trainer::test() { tester_->test(); }
std
::
unique_ptr
<
TesterConfig
>
Trainer
::
createTesterConfig
()
{
TesterConfig
*
conf
=
new
TesterConfig
;
if
(
FLAGS_test_period
)
{
LOG
(
WARNING
)
<<
"The meaning of --test_period is changed: "
<<
"if equal 0, do test on all test data at the end of "
<<
"each pass. While if equal non-zero, do test on all test "
<<
"data every test_period batches "
;
LOG
(
WARNING
)
<<
"The meaning of --test_period is changed: "
<<
"if equal 0, do test on all test data at the end of "
<<
"each pass. While if equal non-zero, do test on all test "
<<
"data every test_period batches "
;
}
if
(
FLAGS_test_all_data_in_one_period
)
{
LOG
(
WARNING
)
<<
"--test_all_data_in_one_period was deprecated, since "
<<
"we will always do test on all test set "
;
LOG
(
WARNING
)
<<
"--test_all_data_in_one_period was deprecated, since "
<<
"we will always do test on all test set "
;
}
conf
->
testPeriod
=
FLAGS_test_period
;
conf
->
prevBatchState
=
FLAGS_prev_batch_state
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录