提交 70e04683 编写于 作者: Q qijun

add getSize method for PoolProjection

上级 bdc9d10a
...@@ -32,6 +32,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config, ...@@ -32,6 +32,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
} }
void MaxPoolProjection::forward() { void MaxPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value; MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value; MatrixPtr outV = out_->value;
outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_, outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,
...@@ -55,6 +57,8 @@ void MaxPoolProjection::backward(const UpdateCallback& callback) { ...@@ -55,6 +57,8 @@ void MaxPoolProjection::backward(const UpdateCallback& callback) {
} }
void AvgPoolProjection::forward() { void AvgPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value; MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value; MatrixPtr outV = out_->value;
outV->avgPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_, outV->avgPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,
......
...@@ -51,6 +51,26 @@ public: ...@@ -51,6 +51,26 @@ public:
static PoolProjection* create(const ProjectionConfig& config, static PoolProjection* create(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu); ParameterPtr parameter, bool useGpu);
const std::string& getPoolType() const { return poolType_; } const std::string& getPoolType() const { return poolType_; }
size_t getSize() {
imgSizeY_ = in_->getFrameHeight();
imgSize_ = in_->getFrameWidth();
const PoolConfig& conf = config_.pool_conf();
if (imgSizeY_ == 0) {
imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
}
if (imgSize_ == 0) {
imgSize_ = conf.img_size();
}
outputY_ = outputSize(imgSizeY_, sizeY_, confPaddingY_, strideY_,
/* caffeMode */ false);
outputX_ = outputSize(imgSize_, sizeX_, confPadding_, stride_,
/* caffeMode */ false);
const_cast<Argument*>(out_)->setFrameHeight(outputY_);
const_cast<Argument*>(out_)->setFrameWidth(outputX_);
return outputY_ * outputX_ * channels_;
}
}; };
class MaxPoolProjection : public PoolProjection { class MaxPoolProjection : public PoolProjection {
......
...@@ -38,8 +38,6 @@ size_t PoolProjectionLayer::getSize() { ...@@ -38,8 +38,6 @@ size_t PoolProjectionLayer::getSize() {
layerSize = outputH_ * outputW_ * channels_; layerSize = outputH_ * outputW_ * channels_;
getOutput().setFrameHeight(outputH_);
getOutput().setFrameWidth(outputW_);
return layerSize; return layerSize;
} }
......
...@@ -70,10 +70,6 @@ size_t SpatialPyramidPoolLayer::getSize() { ...@@ -70,10 +70,6 @@ size_t SpatialPyramidPoolLayer::getSize() {
size_t outputW = (std::pow(4, pyramidHeight_) - 1) / (4 - 1); size_t outputW = (std::pow(4, pyramidHeight_) - 1) / (4 - 1);
layerSize = outputH * outputW * channels_; layerSize = outputH * outputW * channels_;
getOutput().setFrameHeight(outputH);
getOutput().setFrameWidth(outputW);
return layerSize; return layerSize;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册