提交 567871f0 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into batch_norm

...@@ -206,8 +206,8 @@ TEST(Layer, convTransLayerFwd2) { ...@@ -206,8 +206,8 @@ TEST(Layer, convTransLayerFwd2) {
/* filter_size */ 5, /* filter_size */ 5,
result); result);
float resultData[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4, real resultData[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4,
4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1}; 4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1};
result->setData(resultData); result->setData(resultData);
doOneConvtTest(/* imgSize */ 5, doOneConvtTest(/* imgSize */ 5,
/* output_x */ 2, /* output_x */ 2,
...@@ -216,8 +216,8 @@ TEST(Layer, convTransLayerFwd2) { ...@@ -216,8 +216,8 @@ TEST(Layer, convTransLayerFwd2) {
/* filter_size */ 4, /* filter_size */ 4,
result); result);
float resultData2[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4, real resultData2[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4,
4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1}; 4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1};
result->setData(resultData2); result->setData(resultData2);
doOneConvtTest(/* imgSize */ 5, doOneConvtTest(/* imgSize */ 5,
/* output_x */ 2, /* output_x */ 2,
...@@ -226,8 +226,8 @@ TEST(Layer, convTransLayerFwd2) { ...@@ -226,8 +226,8 @@ TEST(Layer, convTransLayerFwd2) {
/* filter_size */ 5, /* filter_size */ 5,
result); result);
float resultData3[] = {1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 4, real resultData3[] = {1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 4,
2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1}; 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1};
result->setData(resultData3); result->setData(resultData3);
doOneConvtTest(/* imgSize */ 5, doOneConvtTest(/* imgSize */ 5,
/* output_x */ 2, /* output_x */ 2,
......
...@@ -106,8 +106,8 @@ TEST(Layer, convParaUnified) { ...@@ -106,8 +106,8 @@ TEST(Layer, convParaUnified) {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
MatrixPtr input, resultCpu, resultGpu; MatrixPtr input, resultCpu, resultGpu;
input = Matrix::create(1, 4 * 4, false, false); input = Matrix::create(1, 4 * 4, false, false);
float inputData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; real inputData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
float param[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 8, 7, 6, 5, 4, 3, 2, 1}; real param[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 8, 7, 6, 5, 4, 3, 2, 1};
input->setData(inputData); input->setData(inputData);
...@@ -137,26 +137,9 @@ TEST(Layer, convParaUnified) { ...@@ -137,26 +137,9 @@ TEST(Layer, convParaUnified) {
checkMatrixEqual(resultCpu, resultGpu); checkMatrixEqual(resultCpu, resultGpu);
input = Matrix::create(1, 3 * 3 * 2, false, false); input = Matrix::create(1, 3 * 3 * 2, false, false);
float inputData2[] = {1, real inputData2[] = {
2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
3, real param2[] = {1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1};
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18};
float param2[] = {1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1};
input->setData(inputData2); input->setData(inputData2);
...@@ -185,7 +168,7 @@ TEST(Layer, convParaUnified) { ...@@ -185,7 +168,7 @@ TEST(Layer, convParaUnified) {
true); true);
checkMatrixEqual(resultCpu, resultGpu); checkMatrixEqual(resultCpu, resultGpu);
float param3[] = {1, 2, 3, 4, 4, 3, 2, 1}; real param3[] = {1, 2, 3, 4, 4, 3, 2, 1};
resultCpu = doOneConvTest(/* imgSize */ 3, resultCpu = doOneConvTest(/* imgSize */ 3,
/* output_x */ 2, /* output_x */ 2,
......
...@@ -38,7 +38,7 @@ public: ...@@ -38,7 +38,7 @@ public:
virtual void startPass() {} virtual void startPass() {}
// called by Trainer then finishing a pass, ruturn true if pass accepted // called by Trainer then finishing a pass, ruturn true if pass accepted
virtual bool finishPass(real cost = 0) { return true; } virtual bool finishPass() { return true; }
// called by Trainer before backward() of a batch // called by Trainer before backward() of a batch
// Return the type of pass it needs. This pass type will be passed // Return the type of pass it needs. This pass type will be passed
...@@ -112,9 +112,9 @@ public: ...@@ -112,9 +112,9 @@ public:
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); }); [&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
} }
virtual bool finishPass(real cost = 0) { virtual bool finishPass() {
syncThreadPool_->execPlusOwner( syncThreadPool_->execPlusOwner(
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); }); [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); });
return true; return true;
} }
......
...@@ -102,9 +102,9 @@ public: ...@@ -102,9 +102,9 @@ public:
* @param cost sum cost during one pass. * @param cost sum cost during one pass.
* @return true if accept (used for owlqn). * @return true if accept (used for owlqn).
*/ */
virtual bool finishPass(real cost) { virtual bool finishPass() {
optimizer_->finishPass(); optimizer_->finishPass();
return ParameterUpdater::finishPass(cost); return ParameterUpdater::finishPass();
} }
/** /**
...@@ -220,9 +220,9 @@ public: ...@@ -220,9 +220,9 @@ public:
averager_->startPass(); averager_->startPass();
SgdLocalUpdater::startPass(); SgdLocalUpdater::startPass();
} }
virtual bool finishPass(real cost) { virtual bool finishPass() {
averager_->finishPass(); averager_->finishPass();
return SgdLocalUpdater::finishPass(cost); return SgdLocalUpdater::finishPass();
} }
/// apply the averaged parameter to PARAMETER_VALUE /// apply the averaged parameter to PARAMETER_VALUE
......
...@@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() { ...@@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() {
} }
} }
bool RemoteParameterUpdater::finishPass(real cost) { bool RemoteParameterUpdater::finishPass() {
if (localUpdater_) { if (localUpdater_) {
localUpdater_->finishPass(); localUpdater_->finishPass();
} }
...@@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() { ...@@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() {
} }
} }
bool SparseRemoteParameterUpdater::finishPass(real cost) { bool SparseRemoteParameterUpdater::finishPass() {
if (config_.algorithm() == TrainAlgorithm::SGD) { if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassFinish(); parameterClient_->waitPassFinish();
} else { } else {
......
...@@ -90,7 +90,7 @@ public: ...@@ -90,7 +90,7 @@ public:
*/ */
virtual void finishBatch(real cost); virtual void finishBatch(real cost);
virtual void startPass(); virtual void startPass();
virtual bool finishPass(real cost); virtual bool finishPass();
#ifndef PADDLE_DISABLE_TIMER #ifndef PADDLE_DISABLE_TIMER
virtual void setForwardbackwardTime(uint64_t delta) { virtual void setForwardbackwardTime(uint64_t delta) {
...@@ -281,7 +281,7 @@ public: ...@@ -281,7 +281,7 @@ public:
/// send all sparse related parameters to all pservers /// send all sparse related parameters to all pservers
virtual void finishBatch(real cost); virtual void finishBatch(real cost);
virtual void startPass(); virtual void startPass();
virtual bool finishPass(real cost); virtual bool finishPass();
virtual void apply(); virtual void apply();
virtual void restore(); virtual void restore();
......
...@@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() { ...@@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() {
} }
} }
bool SgdThreadUpdater::finishPass(real cost) { bool SgdThreadUpdater::finishPass() {
catchUpWith(); catchUpWith();
for (auto& para : parameters_) { for (auto& para : parameters_) {
......
...@@ -47,7 +47,7 @@ public: ...@@ -47,7 +47,7 @@ public:
virtual void startPass(); virtual void startPass();
// Use the finishPass() function of the base optimizer. // Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost); virtual bool finishPass();
virtual void init(const std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize); virtual PassType startBatch(int64_t batchSize);
......
...@@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) { ...@@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) {
trainerInternal_.getGradientMachine()->onPassEnd(); trainerInternal_.getGradientMachine()->onPassEnd();
bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost); bool accepted = trainerInternal_.getParameterUpdater()->finishPass();
globalStat.setThreadInfo(true); globalStat.setThreadInfo(true);
globalStat.printAllStatus(); globalStat.printAllStatus();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册