提交 fcf177fc 编写于 作者: Q qijun

reuse code of PoolProjection in PoolProjectionLayer

上级 cdac60f6
......@@ -52,10 +52,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
Layer* PoolLayer::create(const LayerConfig& config) {
CHECK_EQ(config.inputs_size(), 1);
const std::string& pool = config.inputs(0).pool_conf().pool_type();
if (pool == "max-projection") {
return new MaxPoolProjectionLayer(config);
} else if (pool == "avg-projection") {
return new AvgPoolProjectionLayer(config);
if (pool == "max-projection" || pool == "avg-projection") {
return new PoolProjectionLayer(config);
#ifndef PADDLE_ONLY_CPU
} else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config);
......
......@@ -21,9 +21,9 @@ REGISTER_PROJECTION_CREATE_FUNC(pool2, &PoolProjection::create);
PoolProjection* PoolProjection::create(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu) {
const std::string& pool = config.pool_conf().pool_type();
if (pool == "max") {
if (pool == "max-projection") {
return new MaxPoolProjection(config, parameter, useGpu);
} else if (pool == "avg") {
} else if (pool == "avg-projection") {
return new AvgPoolProjection(config, parameter, useGpu);
} else {
LOG(FATAL) << "Unknown pool type: " << pool;
......
......@@ -19,6 +19,7 @@ limitations under the License. */
namespace paddle {
size_t PoolProjectionLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0;
......@@ -41,71 +42,20 @@ size_t PoolProjectionLayer::getSize() {
return layerSize;
}
void MaxPoolProjectionLayer::forward(PassType passType) {
  Layer::forward(passType);

  /* Size the output buffer; one sample occupies one row. */
  MatrixPtr inMat = getInputValue(0);
  const int numSamples = inMat->getHeight();
  const int outSize = getSize();
  resetOutput(numSamples, outSize);

  /* Apply max pooling over the input feature map. */
  MatrixPtr outMat = getOutputValue();
  outMat->maxPoolForward(*inMat, imgSizeH_, imgSizeW_, channels_,
                         sizeX_, sizeY_, strideY_, stride_,
                         outputH_, outputW_, confPaddingY_, confPadding_);
}
void MaxPoolProjectionLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == getInputGrad(0)) {
return;
}
/* Do derivation */
MatrixPtr outGrad = getOutputGrad();
MatrixPtr inputV = getInputValue(0);
MatrixPtr outV = getOutputValue();
MatrixPtr inputGrad = getInputGrad(0);
inputGrad->maxPoolBackward(*inputV, imgSizeH_, imgSizeW_, *outGrad, *outV,
sizeX_, sizeY_,
strideY_, stride_, outputH_, outputW_, 1, 1,
confPaddingY_, confPadding_);
}
// NOTE(review): this span is a rendered diff hunk — removed and added
// lines are interleaved without +/- markers, so it is not compilable
// as-is. Removed: AvgPoolProjectionLayer::forward (direct avgPoolForward
// call). Added: PoolProjectionLayer::forward, which delegates to the
// poolProjection_ member instead.
// removed signature:
void AvgPoolProjectionLayer::forward(PassType passType) {
// added signature:
void PoolProjectionLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one ROW */
// removed: fetch the raw input matrix directly.
MatrixPtr input = getInputValue(0);
int batchSize = input->getHeight();
// added: fetch the whole input Argument so it can be handed to the
// projection below.
const Argument& in = getInput(0);
int batchSize = in.value->getHeight();
int size = getSize();
resetOutput(batchSize, size);
// removed: in-place average pooling on the output matrix.
MatrixPtr outV = getOutputValue();
outV->avgPoolForward(*input, imgSizeH_, imgSizeW_, channels_,
sizeX_, sizeY_, strideY_, stride_,
outputH_, outputW_, confPaddingY_, confPadding_);
// added: delegate the pooling computation to the PoolProjection.
poolProjection_->forward(&in, &output_, passType);
}
// NOTE(review): diff hunk with removed and added lines interleaved —
// not compilable as-is. Removed: AvgPoolProjectionLayer::backward
// (direct avgPoolBackward call). Added: PoolProjectionLayer::backward,
// which delegates to poolProjection_.
// removed signature:
void AvgPoolProjectionLayer::backward(const UpdateCallback& callback) {
// added signature:
void PoolProjectionLayer::backward(const UpdateCallback& callback) {
(void)callback;
// Skip derivation when the input does not receive gradients.
if (NULL == getInputGrad(0)) {
return;
}
/* Do derivation */
// removed: explicit average-pool gradient computation.
MatrixPtr outputGrad = getOutputGrad();
MatrixPtr inputGrad = getInputGrad(0);
inputGrad->avgPoolBackward(*outputGrad, imgSizeH_, imgSizeW_,
sizeX_, sizeY_, strideY_, stride_,
outputH_, outputW_, 1, 1,
confPaddingY_, confPadding_);
// added: delegate the gradient computation to the PoolProjection.
poolProjection_->backward(callback);
}
} // namespace paddle
......@@ -16,6 +16,7 @@ limitations under the License. */
#pragma once
#include "PoolLayer.h"
#include "PoolProjection.h"
#include "paddle/math/Matrix.h"
#include <vector>
......@@ -27,35 +28,19 @@ class PoolProjectionLayer : public PoolLayer {
// NOTE(review): diff hunk over the PoolProjectionLayer header (the
// enclosing `class PoolProjectionLayer : public PoolLayer {` line is in
// the hunk header above, outside this span). Removed and added lines
// are interleaved — not compilable as-is.
protected:
// Cached input/output image dimensions used by getSize().
size_t imgSizeH_, imgSizeW_;
size_t outputH_, outputW_;
// added: the layer now owns a PoolProjection that performs the actual
// pooling, configured from projectionConfig_.
std::unique_ptr<PoolProjection> poolProjection_;
ProjectionConfig projectionConfig_;
public:
size_t getSize();
// removed: the old trivial constructor (replaced by the one further
// below that builds poolProjection_).
explicit PoolProjectionLayer(const LayerConfig& config) : PoolLayer(config) {}
};
// removed: the two pool-type-specific subclasses below are deleted by
// this commit; their behavior moves into PoolProjection.
/**
* @brief A layer for max pooling
*/
class MaxPoolProjectionLayer : public PoolProjectionLayer {
public:
explicit MaxPoolProjectionLayer(const LayerConfig& config)
: PoolProjectionLayer(config) {}
~MaxPoolProjectionLayer() {}
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
};
/**
* @brief A layer for average pooling
*/
class AvgPoolProjectionLayer : public PoolProjectionLayer {
public:
explicit AvgPoolProjectionLayer(const LayerConfig& config)
: PoolProjectionLayer(config) {}
~AvgPoolProjectionLayer() {}
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
// added: new constructor — copies the input's pool config into
// projectionConfig_ and instantiates the matching PoolProjection
// (nullptr parameter: the projection holds no trainable weights here).
explicit PoolProjectionLayer(const LayerConfig& config)
: PoolLayer(config) {
PoolConfig* conf = projectionConfig_.mutable_pool_conf();
*conf = config_.inputs(0).pool_conf();
poolProjection_.reset(PoolProjection::create(projectionConfig_, nullptr,
useGpu_));
}
};
} // namespace paddle
......@@ -897,12 +897,10 @@ void testSppLayer(const string& poolType, const int pyramidHeight, bool trans,
// NOTE(review): diff hunk — the six explicit calls below are the
// removed pre-change test cases; the inner pyramidHeight loop is the
// added replacement (pool types renamed to "avg-projection" /
// "max-projection").
TEST(Layer, SpatialPyramidPoolLayer) {
for (auto useGpu : {false, true}) {
// removed cases; note the sixth line repeats ("avg", 5, ...) —
// presumably ("max", 5, ...) was intended, TODO confirm against the
// original test file.
testSppLayer("avg", 1, false, useGpu);
testSppLayer("avg", 3, false, useGpu);
testSppLayer("avg", 5, false, useGpu);
testSppLayer("max", 1, false, useGpu);
testSppLayer("max", 3, false, useGpu);
testSppLayer("avg", 5, false, useGpu);
// added: loop over pyramid heights for both projection pool types.
for (auto pyramidHeight : {1, 2, 3}) {
testSppLayer("avg-projection", pyramidHeight, false, useGpu);
testSppLayer("max-projection", pyramidHeight, false, useGpu);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册