提交 96212132 编写于 作者: X xzl

add max-pool-with-mask c++ impl

上级 720274da
...@@ -44,14 +44,19 @@ bool PoolLayer::init(const LayerMap& layerMap, ...@@ -44,14 +44,19 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
with_mask_ = false;
if (poolType_ == "max-pool-with-mask") {
setOutput("mask", &mask_);
with_mask_ = true;
}
return true; return true;
} }
Layer* PoolLayer::create(const LayerConfig& config) { Layer* PoolLayer::create(const LayerConfig& config) {
CHECK_EQ(config.inputs_size(), 1); CHECK_EQ(config.inputs_size(), 1);
const std::string& pool = config.inputs(0).pool_conf().pool_type(); const std::string& pool = config.inputs(0).pool_conf().pool_type();
if (pool == "max-projection" || pool == "avg-projection") { if (pool == "max-projection" || pool == "avg-projection" ||
pool == "max-pool-with-mask") {
return new PoolProjectionLayer(config); return new PoolProjectionLayer(config);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
} else if (CudnnPoolLayer::typeCheck(pool)) { } else if (CudnnPoolLayer::typeCheck(pool)) {
......
...@@ -37,6 +37,8 @@ protected: ...@@ -37,6 +37,8 @@ protected:
int confPaddingY_; int confPaddingY_;
std::string poolType_; std::string poolType_;
bool with_mask_;
Argument mask_;
public: public:
explicit PoolLayer(const LayerConfig& config) : Layer(config) {} explicit PoolLayer(const LayerConfig& config) : Layer(config) {}
......
...@@ -36,6 +36,10 @@ PoolProjection::PoolProjection(const ProjectionConfig& config, ...@@ -36,6 +36,10 @@ PoolProjection::PoolProjection(const ProjectionConfig& config,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
with_mask_ = false;
if (poolType_ == "max-pool-with-mask") {
with_mask_ = true;
}
} }
size_t PoolProjection::getSize() { size_t PoolProjection::getSize() {
...@@ -73,6 +77,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config, ...@@ -73,6 +77,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
return new MaxPoolProjection(config, parameter, useGpu); return new MaxPoolProjection(config, parameter, useGpu);
} else if (pool == "avg-projection") { } else if (pool == "avg-projection") {
return new AvgPoolProjection(config, parameter, useGpu); return new AvgPoolProjection(config, parameter, useGpu);
} else if (pool == "max-pool-with-mask") {
return new MaxPoolProjection(config, parameter, useGpu);
} else { } else {
LOG(FATAL) << "Unknown pool type: " << pool; LOG(FATAL) << "Unknown pool type: " << pool;
return nullptr; return nullptr;
...@@ -84,6 +90,10 @@ void MaxPoolProjection::forward() { ...@@ -84,6 +90,10 @@ void MaxPoolProjection::forward() {
CHECK_EQ(width, out_->value->getWidth()); CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value; MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value; MatrixPtr outV = out_->value;
MatrixPtr maskV = out_->value;
if (with_mask_) {
maskV = mask_->value;
}
outV->maxPoolForward(*inputV, outV->maxPoolForward(*inputV,
imgSizeY_, imgSizeY_,
imgSize_, imgSize_,
...@@ -95,7 +105,9 @@ void MaxPoolProjection::forward() { ...@@ -95,7 +105,9 @@ void MaxPoolProjection::forward() {
outputY_, outputY_,
outputX_, outputX_,
confPaddingY_, confPaddingY_,
confPadding_); confPadding_,
maskV,
with_mask_);
} }
void MaxPoolProjection::backward(const UpdateCallback& callback) { void MaxPoolProjection::backward(const UpdateCallback& callback) {
...@@ -168,4 +180,26 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) { ...@@ -168,4 +180,26 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) {
confPaddingY_, confPaddingY_,
confPadding_); confPadding_);
} }
void MaxWithMaskPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
MatrixPtr maskV = mask_->value;
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
channels_,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
confPaddingY_,
confPadding_,
maskV,
with_mask_);
}
} // namespace paddle } // namespace paddle
...@@ -28,6 +28,7 @@ protected: ...@@ -28,6 +28,7 @@ protected:
int confPaddingY_, confPadding_; int confPaddingY_, confPadding_;
size_t channels_; size_t channels_;
std::string poolType_; std::string poolType_;
bool with_mask_;
public: public:
PoolProjection(const ProjectionConfig& config, PoolProjection(const ProjectionConfig& config,
...@@ -37,7 +38,6 @@ public: ...@@ -37,7 +38,6 @@ public:
static PoolProjection* create(const ProjectionConfig& config, static PoolProjection* create(const ProjectionConfig& config,
ParameterPtr parameter, ParameterPtr parameter,
bool useGpu); bool useGpu);
const std::string& getPoolType() const { return poolType_; } const std::string& getPoolType() const { return poolType_; }
size_t getSize(); size_t getSize();
...@@ -64,4 +64,15 @@ public: ...@@ -64,4 +64,15 @@ public:
virtual void forward(); virtual void forward();
virtual void backward(const UpdateCallback& callback = nullptr); virtual void backward(const UpdateCallback& callback = nullptr);
}; };
class MaxWithMaskPoolProjection : public MaxPoolProjection {
public:
MaxWithMaskPoolProjection(const ProjectionConfig& config,
ParameterPtr parameter,
bool useGpu)
: MaxPoolProjection(config, parameter, useGpu) {}
virtual void forward();
};
} // namespace paddle } // namespace paddle
...@@ -51,8 +51,16 @@ void PoolProjectionLayer::forward(PassType passType) { ...@@ -51,8 +51,16 @@ void PoolProjectionLayer::forward(PassType passType) {
const Argument& in = getInput(0); const Argument& in = getInput(0);
int batchSize = in.value->getHeight(); int batchSize = in.value->getHeight();
int size = getSize(); int size = getSize();
if (with_mask_) {
resetSpecifyOutput(mask_,
batchSize,
size,
/* isValueClean */ false,
/* isGradClean */ true);
}
resetOutput(batchSize, size); resetOutput(batchSize, size);
poolProjection_->forward(&in, &output_, passType); poolProjection_->forward(&in, &output_, &mask_, passType);
} }
void PoolProjectionLayer::backward(const UpdateCallback& callback) { void PoolProjectionLayer::backward(const UpdateCallback& callback) {
......
...@@ -69,6 +69,17 @@ public: ...@@ -69,6 +69,17 @@ public:
forward(); forward();
} }
void forward(const Argument* in,
const Argument* out,
const Argument* mask,
PassType passType) {
in_ = in;
out_ = out;
mask_ = mask;
passType_ = passType;
forward();
}
virtual void prefetch(const Argument* in) {} virtual void prefetch(const Argument* in) {}
virtual void forward() = 0; virtual void forward() = 0;
virtual void backward(const UpdateCallback& callback) = 0; virtual void backward(const UpdateCallback& callback) = 0;
...@@ -130,6 +141,8 @@ protected: ...@@ -130,6 +141,8 @@ protected:
const Argument* in_; const Argument* in_;
/// Store `out` passed to forward() /// Store `out` passed to forward()
const Argument* out_; const Argument* out_;
/// Store `mask` passed to forward()
const Argument* mask_;
/// Store `passType` passed to forward() /// Store `passType` passed to forward()
PassType passType_; PassType passType_;
/// Layer forward function /// Layer forward function
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册