提交 fcf177fc 编写于 作者: Q qijun

reuse code of PoolProjection in PoolProjectionLayer

上级 cdac60f6
...@@ -52,10 +52,8 @@ bool PoolLayer::init(const LayerMap& layerMap, ...@@ -52,10 +52,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
Layer* PoolLayer::create(const LayerConfig& config) { Layer* PoolLayer::create(const LayerConfig& config) {
CHECK_EQ(config.inputs_size(), 1); CHECK_EQ(config.inputs_size(), 1);
const std::string& pool = config.inputs(0).pool_conf().pool_type(); const std::string& pool = config.inputs(0).pool_conf().pool_type();
if (pool == "max-projection") { if (pool == "max-projection" || pool == "avg-projection") {
return new MaxPoolProjectionLayer(config); return new PoolProjectionLayer(config);
} else if (pool == "avg-projection") {
return new AvgPoolProjectionLayer(config);
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
} else if (CudnnPoolLayer::typeCheck(pool)) { } else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config); return new CudnnPoolLayer(config);
......
...@@ -21,9 +21,9 @@ REGISTER_PROJECTION_CREATE_FUNC(pool2, &PoolProjection::create); ...@@ -21,9 +21,9 @@ REGISTER_PROJECTION_CREATE_FUNC(pool2, &PoolProjection::create);
PoolProjection* PoolProjection::create(const ProjectionConfig& config, PoolProjection* PoolProjection::create(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu) { ParameterPtr parameter, bool useGpu) {
const std::string& pool = config.pool_conf().pool_type(); const std::string& pool = config.pool_conf().pool_type();
if (pool == "max") { if (pool == "max-projection") {
return new MaxPoolProjection(config, parameter, useGpu); return new MaxPoolProjection(config, parameter, useGpu);
} else if (pool == "avg") { } else if (pool == "avg-projection") {
return new AvgPoolProjection(config, parameter, useGpu); return new AvgPoolProjection(config, parameter, useGpu);
} else { } else {
LOG(FATAL) << "Unknown pool type: " << pool; LOG(FATAL) << "Unknown pool type: " << pool;
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
size_t PoolProjectionLayer::getSize() { size_t PoolProjectionLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL); CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0; size_t layerSize = 0;
...@@ -41,71 +42,20 @@ size_t PoolProjectionLayer::getSize() { ...@@ -41,71 +42,20 @@ size_t PoolProjectionLayer::getSize() {
return layerSize; return layerSize;
} }
void MaxPoolProjectionLayer::forward(PassType passType) { void PoolProjectionLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one ROW */
MatrixPtr input = getInputValue(0);
int batchSize = input->getHeight();
int size = getSize();
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
outV->maxPoolForward(*input, imgSizeH_, imgSizeW_, channels_,
sizeX_, sizeY_, strideY_, stride_,
outputH_, outputW_, confPaddingY_, confPadding_);
}
void MaxPoolProjectionLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == getInputGrad(0)) {
return;
}
/* Do derivation */
MatrixPtr outGrad = getOutputGrad();
MatrixPtr inputV = getInputValue(0);
MatrixPtr outV = getOutputValue();
MatrixPtr inputGrad = getInputGrad(0);
inputGrad->maxPoolBackward(*inputV, imgSizeH_, imgSizeW_, *outGrad, *outV,
sizeX_, sizeY_,
strideY_, stride_, outputH_, outputW_, 1, 1,
confPaddingY_, confPadding_);
}
void AvgPoolProjectionLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
const Argument& in = getInput(0);
/* malloc memory for the output_ if necessary */ int batchSize = in.value->getHeight();
/* note: one sample correspond to one ROW */
MatrixPtr input = getInputValue(0);
int batchSize = input->getHeight();
int size = getSize(); int size = getSize();
resetOutput(batchSize, size); resetOutput(batchSize, size);
poolProjection_->forward(&in, &output_, passType);
MatrixPtr outV = getOutputValue();
outV->avgPoolForward(*input, imgSizeH_, imgSizeW_, channels_,
sizeX_, sizeY_, strideY_, stride_,
outputH_, outputW_, confPaddingY_, confPadding_);
} }
void AvgPoolProjectionLayer::backward(const UpdateCallback& callback) { void PoolProjectionLayer::backward(const UpdateCallback& callback) {
(void)callback; (void)callback;
if (NULL == getInputGrad(0)) { if (NULL == getInputGrad(0)) {
return; return;
} }
/* Do derivation */ poolProjection_->backward(callback);
MatrixPtr outputGrad = getOutputGrad();
MatrixPtr inputGrad = getInputGrad(0);
inputGrad->avgPoolBackward(*outputGrad, imgSizeH_, imgSizeW_,
sizeX_, sizeY_, strideY_, stride_,
outputH_, outputW_, 1, 1,
confPaddingY_, confPadding_);
} }
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#pragma once #pragma once
#include "PoolLayer.h" #include "PoolLayer.h"
#include "PoolProjection.h"
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include <vector> #include <vector>
...@@ -27,35 +28,19 @@ class PoolProjectionLayer : public PoolLayer { ...@@ -27,35 +28,19 @@ class PoolProjectionLayer : public PoolLayer {
protected: protected:
size_t imgSizeH_, imgSizeW_; size_t imgSizeH_, imgSizeW_;
size_t outputH_, outputW_; size_t outputH_, outputW_;
std::unique_ptr<PoolProjection> poolProjection_;
ProjectionConfig projectionConfig_;
public: public:
size_t getSize(); size_t getSize();
explicit PoolProjectionLayer(const LayerConfig& config) : PoolLayer(config) {}
};
/**
* @brief A layer for max pooling
*/
class MaxPoolProjectionLayer : public PoolProjectionLayer {
public:
explicit MaxPoolProjectionLayer(const LayerConfig& config)
: PoolProjectionLayer(config) {}
~MaxPoolProjectionLayer() {}
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
};
/**
* @brief A layer for average pooling
*/
class AvgPoolProjectionLayer : public PoolProjectionLayer {
public:
explicit AvgPoolProjectionLayer(const LayerConfig& config)
: PoolProjectionLayer(config) {}
~AvgPoolProjectionLayer() {}
virtual void forward(PassType passType); virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr); virtual void backward(const UpdateCallback& callback = nullptr);
explicit PoolProjectionLayer(const LayerConfig& config)
: PoolLayer(config) {
PoolConfig* conf = projectionConfig_.mutable_pool_conf();
*conf = config_.inputs(0).pool_conf();
poolProjection_.reset(PoolProjection::create(projectionConfig_, nullptr,
useGpu_));
}
}; };
} // namespace paddle } // namespace paddle
...@@ -897,12 +897,10 @@ void testSppLayer(const string& poolType, const int pyramidHeight, bool trans, ...@@ -897,12 +897,10 @@ void testSppLayer(const string& poolType, const int pyramidHeight, bool trans,
TEST(Layer, SpatialPyramidPoolLayer) { TEST(Layer, SpatialPyramidPoolLayer) {
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
testSppLayer("avg", 1, false, useGpu); for (auto pyramidHeight : {1, 2, 3}) {
testSppLayer("avg", 3, false, useGpu); testSppLayer("avg-projection", pyramidHeight, false, useGpu);
testSppLayer("avg", 5, false, useGpu); testSppLayer("max-projection", pyramidHeight, false, useGpu);
testSppLayer("max", 1, false, useGpu); }
testSppLayer("max", 3, false, useGpu);
testSppLayer("avg", 5, false, useGpu);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册