提交 70e04683 编写于 作者: Q qijun

add getSize method for PoolProjection

上级 bdc9d10a
......@@ -32,6 +32,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
}
void MaxPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,
......@@ -55,6 +57,8 @@ void MaxPoolProjection::backward(const UpdateCallback& callback) {
}
void AvgPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
outV->avgPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,
......
......@@ -51,6 +51,26 @@ public:
static PoolProjection* create(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu);
const std::string& getPoolType() const { return poolType_; }
size_t getSize() {
imgSizeY_ = in_->getFrameHeight();
imgSize_ = in_->getFrameWidth();
const PoolConfig& conf = config_.pool_conf();
if (imgSizeY_ == 0) {
imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
}
if (imgSize_ == 0) {
imgSize_ = conf.img_size();
}
outputY_ = outputSize(imgSizeY_, sizeY_, confPaddingY_, strideY_,
/* caffeMode */ false);
outputX_ = outputSize(imgSize_, sizeX_, confPadding_, stride_,
/* caffeMode */ false);
const_cast<Argument*>(out_)->setFrameHeight(outputY_);
const_cast<Argument*>(out_)->setFrameWidth(outputX_);
return outputY_ * outputX_ * channels_;
}
};
class MaxPoolProjection : public PoolProjection {
......
......@@ -38,8 +38,6 @@ size_t PoolProjectionLayer::getSize() {
layerSize = outputH_ * outputW_ * channels_;
getOutput().setFrameHeight(outputH_);
getOutput().setFrameWidth(outputW_);
return layerSize;
}
......
......@@ -70,10 +70,6 @@ size_t SpatialPyramidPoolLayer::getSize() {
size_t outputW = (std::pow(4, pyramidHeight_) - 1) / (4 - 1);
layerSize = outputH * outputW * channels_;
getOutput().setFrameHeight(outputH);
getOutput().setFrameWidth(outputW);
return layerSize;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册