Unverified · Commit d02a68c4 authored by Zhaolong Xing, committed by GitHub

Merge pull request #5768 from NHZlX/add_upsample_layer

Add upsample layer
@@ -370,4 +370,48 @@ extern void hl_maxout_backward(real* inGrad,
size_t featLen,
size_t groups);
/**
* @brief Upsample forward.
* @param[in] inputData input data.
* @param[in] maskData the mask data from MaxPoolWithMaskLayer.
* @param[in] batchSize the batch size of the input.
* @param[in] imgSizeH image height.
* @param[in] imgSizeW image width.
* @param[in] channels the input channels.
* @param[in] outputH the output height.
* @param[in] outputW the output width.
* @param[out] outputData output data.
*/
extern void hl_upsample_forward(real* inputData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* outputData);
/**
* @brief Upsample backward.
* @param[in] outputGradData the output grad data.
* @param[in] maskData the mask data from MaxPoolWithMaskLayer.
* @param[in] batchSize the batch size of the input.
* @param[in] imgSizeH image height.
* @param[in] imgSizeW image width.
* @param[in] channels the input channels.
* @param[in] outputH the output height.
* @param[in] outputW the output width.
* @param[out] inputGradData the input grad data.
*/
extern void hl_upsample_backward(real* outputGradData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* inputGradData);
#endif // HL_CNN_H_
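For orientation, a minimal host-side sketch of driving the forward entry point. The function name and the plain CUDA-runtime glue are illustrative only (in-tree callers go through GpuMatrix::upsampleForward below), and a single-precision build with real == float is assumed. Note the output buffer must be zeroed first, because the kernel writes only the slots named by the mask:

#include <cuda_runtime.h>
#include "hl_cnn.h"
void upsampleForwardSketch(const float* h_in, const float* h_mask,
                           size_t batch, size_t inH, size_t inW,
                           size_t channels, size_t outH, size_t outW,
                           float* h_out) {
  size_t inLen = batch * channels * inH * inW;
  size_t outLen = batch * channels * outH * outW;
  float *d_in, *d_mask, *d_out;
  cudaMalloc(&d_in, inLen * sizeof(float));
  cudaMalloc(&d_mask, inLen * sizeof(float));
  cudaMalloc(&d_out, outLen * sizeof(float));
  cudaMemset(d_out, 0, outLen * sizeof(float));  // unmasked slots stay zero
  cudaMemcpy(d_in, h_in, inLen * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_mask, h_mask, inLen * sizeof(float), cudaMemcpyHostToDevice);
  hl_upsample_forward(d_in, d_mask, batch, inH, inW, channels,
                      outH, outW, d_out);
  cudaDeviceSynchronize();  // the hl kernels run asynchronously
  cudaMemcpy(h_out, d_out, outLen * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_in);
  cudaFree(d_mask);
  cudaFree(d_out);
}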
@@ -224,4 +224,24 @@ inline void hl_maxout_backward(real* inGrad,
size_t featLen,
size_t group) {}
inline void hl_upsample_forward(real* inputData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* outputData) {}
inline void hl_upsample_backward(real* outputGradData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* inputGradData) {}
#endif // HL_CNN_STUB_H_
@@ -1028,3 +1028,79 @@ void hl_maxout_backward(real* inGrad,
num_kernels, inGrad, outGrad, idData, size, featLen, groups);
CHECK_SYNC("hl_maxout_backward failed");
}
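// Forward kernel: scatter each input element to the output slot recorded in the mask.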
__global__ void upsampleForwardCompute(real* input_data,
real* mask_data,
size_t nthreads,
size_t in_h,
size_t in_w,
size_t out_h,
size_t out_w,
real* output_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offset = index / (in_w * in_h) * out_h * out_w;
int upsample_idx = static_cast<int>(mask_data[index]);
output_data[offset + upsample_idx] = input_data[index];
}
}
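// Backward kernel: gather each input gradient from the output slot recorded in the mask.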
__global__ void upsampleBackwardCompute(real* out_grad,
real* mask_data,
size_t nthreads,
size_t in_h,
size_t in_w,
size_t out_h,
size_t out_w,
real* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offset = index / (in_w * in_h) * out_h * out_w;
int upsample_idx = static_cast<int>(mask_data[index]);
input_grad[index] = out_grad[offset + upsample_idx];
}
}
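To make the mask arithmetic concrete (values assumed for illustration): with in_h = in_w = 2 and out_h = out_w = 4, thread index 5 falls in the second 2x2 input plane (5 / 4 == 1), so offset = 1 * 16 = 16; if mask_data[5] holds 10, the forward kernel scatters input_data[5] to output_data[26], and the backward kernel gathers the gradient for index 5 from that same slot.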
void hl_upsample_forward(real* inputData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* outputData) {
int num_kernels = batchSize * imgSizeH * imgSizeW * channels;
int blocks = (num_kernels + 1024 - 1) / 1024;
upsampleForwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(inputData,
maskData,
num_kernels,
imgSizeH,
imgSizeW,
outputH,
outputW,
outputData);
CHECK_SYNC("hl_upsample_forward failed");
}
void hl_upsample_backward(real* outputGradData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* inputGradData) {
int num_kernels = batchSize * imgSizeH * imgSizeW * channels;
int blocks = (num_kernels + 1024 - 1) / 1024;
upsampleBackwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(outputGradData,
maskData,
num_kernels,
imgSizeH,
imgSizeW,
outputH,
outputW,
inputGradData);
CHECK_SYNC("hl_upsample_backward failed");
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "UpsampleLayer.h"
#include "iostream"
namespace paddle {
REGISTER_LAYER(upsample, UpsampleLayer);
size_t UpsampleLayer::getOutputSize() {
if (upsampleSize_ == 0) {
upsampleSize_ = imgSize_ * scale_ - static_cast<int>(padOutX_);
upsampleSizeY_ = imgSizeY_ * scaleY_ - static_cast<int>(padOutY_);
}
return upsampleSize_ * upsampleSizeY_ * channels_;
}
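As a concrete check (numbers assumed): with imgSize_ = imgSizeY_ = 4, scale_ = scaleY_ = 2, and no output padding, upsampleSize_ = upsampleSizeY_ = 8 and getOutputSize() returns 64 * channels_; with padOutX_ set, the width becomes 4 * 2 - 1 = 7, recovering an odd pre-pooling width.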
bool UpsampleLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2U);
CHECK_EQ(config_.inputs_size(), 2);
const auto& conf = config_.inputs(0).upsample_conf();
const auto& img_conf = conf.image_conf();
imgSizeY_ =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
imgSize_ = img_conf.img_size();
channels_ = img_conf.channels();
CHECK((conf.has_upsample_size()) || (conf.has_scale()))
<< "scale or upsample_size is required.";
if (conf.has_upsample_size()) {
upsampleSize_ = conf.upsample_size();
upsampleSizeY_ = upsampleSize_;
if (conf.has_upsample_size_y()) {
upsampleSizeY_ = conf.upsample_size_y();
}
} else {
if (!conf.has_scale_y()) {
scale_ = scaleY_ = conf.scale();
CHECK_GT(static_cast<int>(scale_), 1);
} else {
scale_ = conf.scale();
scaleY_ = conf.scale_y();
}
padOutX_ = conf.pad_out_x();
padOutY_ = conf.pad_out_y();
CHECK(!padOutX_ || scale_ == 2)
<< "Output width padding compensation requires scale_ == 2";
CHECK(!padOutY_ || scaleY_ == 2)
<< "Output height padding compensation requires scaleY_ == 2";
upsampleSize_ = upsampleSizeY_ = 0;
}
return true;
}
void UpsampleLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr input = getInputValue(0);
MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;
size_t batchSize = input->getHeight();
size_t outSize = getOutputSize();
CHECK_EQ(input->getWidth(), mask->getWidth());
CHECK_EQ(mask->getHeight(), batchSize);
resetOutput(batchSize, outSize);
MatrixPtr output = getOutputValue();
output->upsampleForward(*input,
*mask,
imgSize_,
imgSizeY_,
channels_,
upsampleSize_,
upsampleSizeY_);
}
void UpsampleLayer::backward(const UpdateCallback& callback) {
MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;
MatrixPtr inputGrad = getInputGrad(0);
MatrixPtr outputGrad = getOutputGrad();
inputGrad->upsampleBackward(*outputGrad,
*mask,
imgSize_,
imgSizeY_,
channels_,
upsampleSize_,
upsampleSizeY_);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
/**
* This layer reverses (transposes) the max-pooling process.
* It takes two inputs: the first is the input data, and
* the second is the mask data produced by the max-pool-with-mask layer.
*/
class UpsampleLayer : public Layer {
public:
explicit UpsampleLayer(const LayerConfig& config) : Layer(config) {}
~UpsampleLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
size_t getOutputSize();
protected:
size_t scale_, scaleY_;
size_t upsampleSize_, upsampleSizeY_;
size_t padOutX_, padOutY_;
size_t imgSize_, imgSizeY_;
size_t channels_;
};
} // namespace paddle
@@ -27,6 +27,7 @@ gserver_test(test_BatchNorm)
gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)
gserver_test(test_Upsample)
set(PYTHON_PATH
${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "LayerGradUtil.h"
#include "paddle/math/MathUtils.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle;
void setPoolConfig(TestConfig* config,
PoolConfig* pool,
const string& poolType) {
(*config).biasSize = 0;
(*config).layerConfig.set_type("pool");
(*config).layerConfig.set_num_filters(1);
int kw = 2, kh = 2;
int pw = 0, ph = 0;
int sw = 2, sh = 2;
pool->set_pool_type(poolType);
pool->set_channels(2);
pool->set_size_x(kw);
pool->set_size_y(kh);
pool->set_start(0);
pool->set_padding(pw);
pool->set_padding_y(ph);
pool->set_stride(sw);
pool->set_stride_y(sh);
int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false);
int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false);
pool->set_output_x(ow);
pool->set_output_y(oh);
}
LayerPtr doOneUpsampleTest(MatrixPtr& inputMat,
const string& poolType,
bool use_gpu,
real* tempGradData) {
/* prepare maxPoolWithMaskLayer */
TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 128, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
PoolConfig* pool = input->mutable_pool_conf();
pool->set_img_size(8);
pool->set_img_size_y(8);
setPoolConfig(&config, pool, "max-pool-with-mask");
config.layerConfig.set_size(pool->output_x() * pool->output_y() *
pool->channels());
config.layerConfig.set_name("MaxPoolWithMask");
std::vector<DataLayerPtr> dataLayers;
LayerMap layerMap;
vector<Argument> datas;
initDataLayer(config,
&dataLayers,
&datas,
&layerMap,
"MaxPoolWithMask",
1,
false,
use_gpu);
dataLayers[0]->getOutputValue()->copyFrom(*inputMat);
FLAGS_use_gpu = use_gpu;
std::vector<ParameterPtr> parameters;
LayerPtr maxPoolingWithMaskOutputLayer;
initTestLayer(config, &layerMap, &parameters, &maxPoolingWithMaskOutputLayer);
maxPoolingWithMaskOutputLayer->forward(PASS_GC);
/* prepare the upsample layer */
LayerConfig upsampleLayerConfig;
upsampleLayerConfig.set_type("upsample");
LayerInputConfig* input1 = upsampleLayerConfig.add_inputs();
upsampleLayerConfig.add_inputs();
UpsampleConfig* upsampleConfig = input1->mutable_upsample_conf();
upsampleConfig->set_scale(2);
ImageConfig* imageConfig = upsampleConfig->mutable_image_conf();
imageConfig->set_channels(2);
imageConfig->set_img_size(4);
imageConfig->set_img_size_y(4);
upsampleLayerConfig.set_size(2 * 8 * 8);
upsampleLayerConfig.set_name("upsample");
for (size_t i = 0; i < 2; i++) {
LayerInputConfig& inputTemp = *(upsampleLayerConfig.mutable_inputs(i));
inputTemp.set_input_layer_name("MaxPoolWithMask");
}
LayerPtr upsampleLayer;
ParameterMap parameterMap;
upsampleLayer = Layer::create(upsampleLayerConfig);
layerMap[upsampleLayerConfig.name()] = upsampleLayer;
upsampleLayer->init(layerMap, parameterMap);
upsampleLayer->setNeedGradient(true);
upsampleLayer->forward(PASS_GC);
upsampleLayer->getOutputGrad()->copyFrom(tempGradData, 128);
upsampleLayer->backward();
return upsampleLayer;
}
TEST(Layer, upsampleLayerFwd) {
bool useGpu = false;
MatrixPtr inputMat;
MatrixPtr inputGPUMat;
MatrixPtr tempGradMat;
inputMat = Matrix::create(1, 128, false, useGpu);
inputMat->randomizeUniform();
tempGradMat = Matrix::create(1, 128, false, useGpu);
tempGradMat->randomizeUniform();
real* data = inputMat->getData();
real* tempGradData = tempGradMat->getData();
LayerPtr upsampleLayerCPU =
doOneUpsampleTest(inputMat, "max-pool-with-mask", useGpu, tempGradData);
#ifdef PADDLE_WITH_CUDA
useGpu = true;
inputGPUMat = Matrix::create(1, 128, false, useGpu);
inputGPUMat->copyFrom(data, 128);
LayerPtr upsampleLayerGPU = doOneUpsampleTest(
inputGPUMat, "max-pool-with-mask", useGpu, tempGradData);
checkMatrixEqual(upsampleLayerCPU->getOutput("").value,
upsampleLayerGPU->getOutput("").value);
checkMatrixEqual(upsampleLayerCPU->getPrev(0)->getOutputGrad(),
upsampleLayerGPU->getPrev(0)->getOutputGrad());
#endif
}
@@ -1024,6 +1024,66 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) {
LOG(INFO) << "the diffCnt is " << diffCnt;
}
void GpuMatrix::upsampleForward(Matrix& input,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW) {
CHECK(input.useGpu_ == true) << "Matrix type mismatch";
CHECK(mask.useGpu_ == true) << "Matrix type mismatch";
real* inputData = input.getData();
real* maskData = mask.getData();
real* outData = data_;
size_t batch = input.getHeight();
CHECK(imgSizeH * imgSizeW * channels == input.getWidth());
CHECK(imgSizeH * imgSizeW * channels == mask.getWidth());
CHECK_EQ(batch, this->getHeight());
CHECK(width_ == outputH * outputW * channels);
hl_upsample_forward(inputData,
maskData,
batch,
imgSizeH,
imgSizeW,
channels,
outputH,
outputW,
outData);
}
void GpuMatrix::upsampleBackward(Matrix& outputGrad,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW) {
CHECK(outputGrad.useGpu_ == true) << "Matrix type mismatch";
CHECK(mask.useGpu_ == true) << "Matrix type mismatch";
real* outputGradData = outputGrad.getData();
real* maskData = mask.getData();
real* inputGradData = data_;
size_t batch = outputGrad.getHeight();
CHECK(imgSizeH * imgSizeW == this->getWidth() / channels);
CHECK_EQ(batch, this->getHeight());
CHECK_EQ(channels * outputH * outputW, outputGrad.getWidth());
hl_upsample_backward(outputGradData,
maskData,
batch,
imgSizeH,
imgSizeW,
channels,
outputH,
outputW,
inputGradData);
}
void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
@@ -1986,6 +2046,72 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ(info, 0);
}
void CpuMatrix::upsampleForward(Matrix& input,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW) {
real* inputData = input.getData();
real* maskData = mask.getData();
real* outData = data_;
size_t inLength = imgSizeH * imgSizeW;
size_t outLength = outputH * outputW;
size_t batch = input.getHeight();
CHECK(inLength == input.getWidth() / channels);
CHECK_EQ(batch, this->getHeight());
CHECK_EQ(channels * outLength, this->getWidth());
for (size_t k = 0; k < batch; k++) {
for (size_t c = 0; c < channels; c++) {
for (size_t i = 0; i < inLength; i++) {
size_t out_index = static_cast<int>(maskData[i]);
if (out_index >= outLength) {
LOG(FATAL) << "upsample index " << out_index << " out of range.";
}
outData[out_index] = inputData[i];
}
inputData += inLength;
maskData += inLength;
outData += outLength;
}
}
}
void CpuMatrix::upsampleBackward(Matrix& outputGrad,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW) {
real* outputGradData = outputGrad.getData();
real* maskData = mask.getData();
real* inputGradData = data_;
size_t inLength = imgSizeH * imgSizeW;
size_t outLength = outputH * outputW;
size_t batch = outputGrad.getHeight();
CHECK(inLength == this->getWidth() / channels);
CHECK_EQ(batch, this->getHeight());
CHECK_EQ(channels * outLength, outputGrad.getWidth());
for (size_t k = 0; k < batch; k++) {
for (size_t c = 0; c < channels; c++) {
for (size_t i = 0; i < inLength; i++) {
size_t out_index = static_cast<int>(maskData[i]);
if (out_index >= outLength) {
LOG(FATAL) << "upsample index " << out_index << " out of range.";
}
inputGradData[i] = outputGradData[out_index];
}
inputGradData += inLength;
maskData += inLength;
outputGradData += outLength;
}
}
}
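The two loops above are exact adjoints: the forward pass scatters each input element to the slot named by the mask, and the backward pass gathers from that same slot. A self-contained sketch of the invariant for one channel plane (names hypothetical; mask indices assumed in range and distinct, as they are for pooling argmax positions):

#include <cassert>
#include <cstddef>
// Scatter by mask, gather back, and recover the input exactly.
void upsamplePlaneRoundTrip(const float* input, const float* mask,
                            size_t inLength, size_t outLength) {
  float* out = new float[outLength]();   // zero-initialized plane
  for (size_t i = 0; i < inLength; ++i)  // forward: scatter
    out[static_cast<size_t>(mask[i])] = input[i];
  for (size_t i = 0; i < inLength; ++i)  // backward: gather
    assert(out[static_cast<size_t>(mask[i])] == input[i]);
  delete[] out;
}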
void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
......
@@ -859,6 +859,26 @@ public:
LOG(FATAL) << "Not implemented";
}
virtual void upsampleForward(Matrix& input,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW) {
LOG(FATAL) << "Not implemented";
}
virtual void upsampleBackward(Matrix& outputGrad,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW) {
LOG(FATAL) << "Not implemented";
}
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value, if the maskMatP is not NULL, it will
@@ -1420,6 +1440,22 @@ public:
void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
void upsampleForward(Matrix& input,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW);
void upsampleBackward(Matrix& outputGrad,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
@@ -1694,6 +1730,22 @@ public:
MatrixPtr clone(size_t height, size_t width, bool useGpu = false);
void upsampleForward(Matrix& input,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW);
void upsampleBackward(Matrix& outputGrad,
Matrix& mask,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
......
@@ -323,6 +323,16 @@ message ClipConfig {
required double max = 2;
}
message UpsampleConfig {
required ImageConfig image_conf = 1;
optional uint32 scale = 2 [ default = 2 ];
optional uint32 scale_y = 3 [ default = 2 ];
optional bool pad_out_x = 4 [ default = false ];
optional bool pad_out_y = 5 [ default = false ];
optional uint32 upsample_size = 6;
optional uint32 upsample_size_y = 7;
}
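Note that although scale and scale_y carry proto defaults of 2, UpsampleLayer::init still requires scale or upsample_size to be set explicitly, and upsample_size takes precedence over scale when both are present.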
message ROIPoolConfig {
required uint32 pooled_width = 1;
required uint32 pooled_height = 2;
@@ -359,6 +369,7 @@ message LayerInputConfig {
optional ClipConfig clip_conf = 18;
optional ScaleSubRegionConfig scale_sub_region_conf = 19;
optional ROIPoolConfig roi_pool_conf = 20;
optional UpsampleConfig upsample_conf = 21;
}
message LayerConfig {
......
@@ -471,6 +471,7 @@ class Input(Cfg):
maxout=None,
spp=None,
pad=None,
upsample=None,
format=None,
nnz=None,
is_static=None,
@@ -983,6 +984,13 @@ class Pad(Cfg):
self.add_keys(locals())
@config_class
class Upsample(Cfg):
def __init__(self, scale, scale_y, pad_out_x, pad_out_y, upsample_size,
upsample_size_y):
self.add_keys(locals())
@config_class
class Norm(Cfg):
def __init__(self,
@@ -2380,6 +2388,46 @@ class SpatialPyramidPoolLayer(LayerBase):
self.set_cnn_layer(name, 1, output_x, spp_conf.image_conf.channels)
@config_layer('upsample')
class UpsampleLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
super(UpsampleLayer, self).__init__(
name, 'upsample', 0, inputs=inputs, **xargs)
input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].upsample_conf.image_conf
image_conf.img_size = input_layer.width
image_conf.img_size_y = input_layer.height
image_conf.channels = input_layer.size / (input_layer.width *
input_layer.height)
upsample = self.inputs[0].upsample
output_x = 0
output_y = 0
output_size = 0
if upsample.scale:
self.config.inputs[0].upsample_conf.scale = upsample.scale
self.config.inputs[0].upsample_conf.scale_y = upsample.scale_y
output_x = input_layer.width * upsample.scale
output_y = input_layer.height * upsample.scale_y
self.config.inputs[0].upsample_conf.pad_out_x = upsample.pad_out_x
self.config.inputs[0].upsample_conf.pad_out_y = upsample.pad_out_y
if upsample.upsample_size:
self.config.inputs[
0].upsample_conf.upsample_size = upsample.upsample_size
self.config.inputs[
0].upsample_conf.upsample_size_y = upsample.upsample_size_y
output_x = upsample.upsample_size
output_y = upsample.upsample_size_y
output_size = image_conf.channels * output_x * output_y
self.set_layer_height_width(output_y, output_x)
self.set_layer_depth(input_layer.depth)
self.set_layer_size(output_size)
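For example (values assumed): a 4x4, 2-channel input with scale = scale_y = 2 yields output_x = output_y = 8 and output_size = 2 * 8 * 8 = 128, which is exactly the layer size exercised by test_Upsample above.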
@config_layer('pad')
class PadLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......
@@ -148,6 +148,7 @@ __all__ = [
'resize_layer',
'sub_seq_layer',
'scale_sub_region_layer',
'upsample_layer',
'factorization_machine',
]
@@ -166,6 +167,7 @@ class LayerType(object):
SEQUENCE_RESHAPE = 'seqreshape'
POOLING_MAX = 'max'
POOLING_AVG = 'average'
UPSAMPLE_LAYER = 'upsample'
FC_LAYER = 'fc'
COST = 'cost'
COSINE_SIM_VEC = 'cos_vm'
@@ -3014,6 +3016,83 @@ def img_pool3d_layer(input,
size=l.config.size)
@wrap_name_default("upsample")
@layer_support()
def upsample_layer(input,
name=None,
scale=None,
scale_y=None,
upsample_size=None,
upsample_size_y=None,
pad_out_x=False,
pad_out_y=False,
layer_attr=None):
"""
This layer performs unpooling, i.e. it reverses the max-pooling process.
The input should be a list of length 2: the first element is the layer to
be upsampled, and the second is the max-pool-with-mask layer that produced
the pooling mask.
The example usage is:
.. code-block:: python
pool1 = paddle.v2.layer.img_pool(input=input, pool_size=2, stride=2,
pool_type=paddle.pooling.MaxWithMask())
upsample = paddle.v2.layer.upsample(input=[layer1, pool1])
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: contains the layer to be upsampled and the max-pool-with-mask
layer that supplies the mask.
:type input: list | tuple | collections.Sequence
:param scale: outputSize = scale * inputSize.
:type scale: int | None
:param scale_y: the scale of the y dimension; defaults to scale when None.
:type scale_y: int | None
:param upsample_size: explicitly specify the output size of the x dimension.
:type upsample_size: int | None
:param upsample_size_y: explicitly specify the output size of the y dimension.
:type upsample_size_y: int | None
:param pad_out_x: trim the output width by one to recover an odd pre-pooling
width. Only effective when scale is 2.
:type pad_out_x: bool
:param pad_out_y: trim the output height by one to recover an odd pre-pooling
height. Only effective when scale_y is 2.
:type pad_out_y: bool
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert (scale is not None) or (upsample_size is not None), \
'either scale or upsample_size must be specified'
assert len(input) == 2, 'layer input size must be 2'
assert input[1].layer_type == LayerType.POOL_LAYER, \
'the second input should be the MaxPoolWithMaskLayer'
scale_y = scale_y \
if scale_y is not None else scale
upsample_size_y = upsample_size_y \
if upsample_size_y is not None else upsample_size
layer_type = LayerType.UPSAMPLE_LAYER
layer = Layer(
name=name,
type=layer_type,
inputs=[
Input(
input[0].name,
upsample=Upsample(scale, scale_y, pad_out_x, pad_out_y,
upsample_size, upsample_size_y)),
Input(input[1].name)
],
**ExtraLayerAttribute.to_kwargs(layer_attr))
sz = layer.config.size
return LayerOutput(name, layer_type=layer_type, parents=input, size=sz)
@wrap_name_default("spp")
@layer_support()
def spp_layer(input,
......