提交 96212132 编写于 作者: X xzl

add max-pool-with-mask c++ impl

上级 720274da
......@@ -44,14 +44,19 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
with_mask_ = false;
if (poolType_ == "max-pool-with-mask") {
setOutput("mask", &mask_);
with_mask_ = true;
}
return true;
}
Layer* PoolLayer::create(const LayerConfig& config) {
CHECK_EQ(config.inputs_size(), 1);
const std::string& pool = config.inputs(0).pool_conf().pool_type();
if (pool == "max-projection" || pool == "avg-projection") {
if (pool == "max-projection" || pool == "avg-projection" ||
pool == "max-pool-with-mask") {
return new PoolProjectionLayer(config);
#ifdef PADDLE_WITH_CUDA
} else if (CudnnPoolLayer::typeCheck(pool)) {
......
......@@ -37,6 +37,8 @@ protected:
int confPaddingY_;
std::string poolType_;
bool with_mask_;
Argument mask_;
public:
explicit PoolLayer(const LayerConfig& config) : Layer(config) {}
......
......@@ -36,6 +36,10 @@ PoolProjection::PoolProjection(const ProjectionConfig& config,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
with_mask_ = false;
if (poolType_ == "max-pool-with-mask") {
with_mask_ = true;
}
}
size_t PoolProjection::getSize() {
......@@ -73,6 +77,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
return new MaxPoolProjection(config, parameter, useGpu);
} else if (pool == "avg-projection") {
return new AvgPoolProjection(config, parameter, useGpu);
} else if (pool == "max-pool-with-mask") {
return new MaxPoolProjection(config, parameter, useGpu);
} else {
LOG(FATAL) << "Unknown pool type: " << pool;
return nullptr;
......@@ -84,6 +90,10 @@ void MaxPoolProjection::forward() {
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
MatrixPtr maskV = out_->value;
if (with_mask_) {
maskV = mask_->value;
}
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
......@@ -95,7 +105,9 @@ void MaxPoolProjection::forward() {
outputY_,
outputX_,
confPaddingY_,
confPadding_);
confPadding_,
maskV,
with_mask_);
}
void MaxPoolProjection::backward(const UpdateCallback& callback) {
......@@ -168,4 +180,26 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) {
confPaddingY_,
confPadding_);
}
void MaxWithMaskPoolProjection::forward() {
size_t width = getSize();
CHECK_EQ(width, out_->value->getWidth());
MatrixPtr inputV = in_->value;
MatrixPtr outV = out_->value;
MatrixPtr maskV = mask_->value;
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
channels_,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
confPaddingY_,
confPadding_,
maskV,
with_mask_);
}
} // namespace paddle
......@@ -28,6 +28,7 @@ protected:
int confPaddingY_, confPadding_;
size_t channels_;
std::string poolType_;
bool with_mask_;
public:
PoolProjection(const ProjectionConfig& config,
......@@ -37,7 +38,6 @@ public:
static PoolProjection* create(const ProjectionConfig& config,
ParameterPtr parameter,
bool useGpu);
const std::string& getPoolType() const { return poolType_; }
size_t getSize();
......@@ -64,4 +64,15 @@ public:
virtual void forward();
virtual void backward(const UpdateCallback& callback = nullptr);
};
class MaxWithMaskPoolProjection : public MaxPoolProjection {
public:
MaxWithMaskPoolProjection(const ProjectionConfig& config,
ParameterPtr parameter,
bool useGpu)
: MaxPoolProjection(config, parameter, useGpu) {}
virtual void forward();
};
} // namespace paddle
......@@ -51,8 +51,16 @@ void PoolProjectionLayer::forward(PassType passType) {
const Argument& in = getInput(0);
int batchSize = in.value->getHeight();
int size = getSize();
if (with_mask_) {
resetSpecifyOutput(mask_,
batchSize,
size,
/* isValueClean */ false,
/* isGradClean */ true);
}
resetOutput(batchSize, size);
poolProjection_->forward(&in, &output_, passType);
poolProjection_->forward(&in, &output_, &mask_, passType);
}
void PoolProjectionLayer::backward(const UpdateCallback& callback) {
......
......@@ -69,6 +69,17 @@ public:
forward();
}
void forward(const Argument* in,
const Argument* out,
const Argument* mask,
PassType passType) {
in_ = in;
out_ = out;
mask_ = mask;
passType_ = passType;
forward();
}
virtual void prefetch(const Argument* in) {}
virtual void forward() = 0;
virtual void backward(const UpdateCallback& callback) = 0;
......@@ -130,6 +141,8 @@ protected:
const Argument* in_;
/// Store `out` passed to forward()
const Argument* out_;
/// Store `mask` passed to forward()
const Argument* mask_;
/// Store `passType` passed to forward()
PassType passType_;
/// Layer forward function
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册