Commit 842d25be authored by qingqing01, committed by GitHub

Merge pull request #1094 from qingqing01/pad_op

Padding Operation
......@@ -382,6 +382,15 @@ sampling_id_layer
:members: sampling_id_layer
:noindex:
Slicing and Joining Layers
==========================
pad_layer
-----------
.. automodule:: paddle.trainer_config_helpers.layers
:members: pad_layer
:noindex:
.. _api_trainer_config_helpers_layers_cost_layers:
Cost Layers
......
......@@ -25,6 +25,7 @@ if(WITH_TESTING)
add_simple_unittest(BufferArgTest)
add_simple_unittest(FunctionTest)
add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest(PadOpTest)
endif()
endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PadOp.h"
#include "paddle/math/Vector.h"
namespace paddle {
template <>
void Pad<DEVICE_TYPE_CPU>(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW,
const PadConf& pad) {
int cstart = pad.channelStart, cend = pad.channelEnd;
int hstart = pad.heightStart, hend = pad.heightEnd;
int wstart = pad.widthStart, wend = pad.widthEnd;
int outC = inC + cstart + cend;
int outH = inH + hstart + hend;
int outW = inW + wstart + wend;
for (int i = 0; i < num; i++) {
for (int c = 0; c < inC; c++) {
for (int h = 0; h < inH; h++) {
int inoff = ((i * inC + c) * inH + h) * inW;
int outoff =
((i * outC + c + cstart) * outH + h + hstart) * outW + wstart;
memcpy(outputs + outoff, inputs + inoff, inW * sizeof(real));
}
}
}
}
template <>
void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
const real* outGrad,
const int num,
const int inC,
const int inH,
const int inW,
const PadConf& pad) {
int cstart = pad.channelStart, cend = pad.channelEnd;
int hstart = pad.heightStart, hend = pad.heightEnd;
int wstart = pad.widthStart, wend = pad.widthEnd;
int outC = inC + cstart + cend;
int outH = inH + hstart + hend;
int outW = inW + wstart + wend;
for (int i = 0; i < num; i++) {
for (int c = 0; c < inC; c++) {
for (int h = 0; h < inH; h++) {
int inoff = ((i * inC + c) * inH + h) * inW;
int outoff =
((i * outC + c + cstart) * outH + h + hstart) * outW + wstart;
CpuVector inG = CpuVector(inW, inGrad + inoff);
CpuVector outG = CpuVector(inW, const_cast<real*>(outGrad + outoff));
inG += outG;
}
}
}
}
/**
 * \brief Pad the input with zeros along the specified dimensions.
 *        The struct pad_ contains the padding size for each dimension.
 *        The input and output are 4D tensors. In PadFunc, zeros are only
 *        padded along the 2nd to 4th dimensions.
*
 * Arguments in this Function:
 * \param pad_ A struct holding the padding size for each dimension.
 *             It has six integers: channelStart and channelEnd give the
 *             number of zeros added before and after the input along the
 *             channel dimension; heightStart and heightEnd give the padding
 *             along the height dimension; widthStart and widthEnd give the
 *             padding along the width dimension.
 * \param inputs A 4D tensor; there is exactly one input.
 * \param outputs A 4D tensor holding the padded output value.
*
* For example,
* Input(2,2,2,3) = [
* [ [[1,2,3], [3,4,5]],
* [[2,3,5], [1,6,7]] ],
* [ [[4,3,1], [1,8,7]],
* [[3,8,9], [2,3,5]] ]
 * ] # the shape is (2,2,2,3)
*
* pad_: if channelStart = channelEnd = 1, others are 0.
* Output(2,4,2,3) = [
* [ [[0,0,0], [0,0,0]],
* [[1,2,3], [3,4,5]],
* [[2,3,5], [1,6,7]],
* [[0,0,0], [0,0,0]] ],
* [ [[0,0,0], [0,0,0]],
* [[4,3,1], [1,8,7]],
* [[3,8,9], [2,3,5]],
* [[0,0,0], [0,0,0]] ]
* ] # the shape is (2,4,2,3)
*
* pad_: if widthStart = 1, widthEnd = 2, others are 0.
* Output(2,2,2,6) = [
* [ [[0,1,2,3,0,0], [0,3,4,5,0,0]],
* [[0,2,3,5,0,0], [0,1,6,7,0,0]] ],
* [ [[0,4,3,1,0,0], [0,1,8,7,0,0]],
* [[0,3,8,9,0,0], [0,2,3,5,0,0]] ],
* ] # the shape is (2,2,2,6)
*
* pad_: if heightStart = 1, heightEnd = 1, others are 0.
* Output(2,2,4,3) = [
* [ [[0,0,0], [1,2,3], [3,4,5], [0,0,0]],
* [[0,0,0], [2,3,5], [1,6,7], [0,0,0]] ],
* [ [[0,0,0], [4,3,1], [1,8,7], [0,0,0]],
* [[0,0,0], [3,8,9], [2,3,5], [0,0,0]] ],
* ] # the shape is (2,2,4,3)
*/
template <DeviceType Device>
class PadFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
pad_.channelStart = config.get<int>("cstart");
pad_.channelEnd = config.get<int>("cend");
pad_.heightStart = config.get<int>("hstart");
pad_.heightEnd = config.get<int>("hend");
pad_.widthStart = config.get<int>("wstart");
pad_.widthEnd = config.get<int>("wend");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
size_t num = inputs[0].shape()[0];
size_t inC = inputs[0].shape()[1];
size_t inH = inputs[0].shape()[2];
size_t inW = inputs[0].shape()[3];
typename Tensor<real, Device>::Vector vec(outputs[0].shape().getElements(),
outputs[0].data<real>());
vec.zero();
Pad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
num,
inC,
inH,
inW,
pad_);
}
private:
PadConf pad_;
};
/**
 * \brief The backward propagation of the padding Function. It discards the
 *        gradients at the positions that were padded in the forward pass.
 *
 * Arguments in this Function:
 * \param pad_ The same meaning as in PadFunc.
* \param inputs The gradient with respect to the output value of PadFunc.
* \param outputs The gradient with respect to the input value of PadFunc.
*/
template <DeviceType Device>
class PadGradFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
pad_.channelStart = config.get<int>("cstart");
pad_.channelEnd = config.get<int>("cend");
pad_.heightStart = config.get<int>("hstart");
pad_.heightEnd = config.get<int>("hend");
pad_.widthStart = config.get<int>("wstart");
pad_.widthEnd = config.get<int>("wend");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
size_t num = outputs[0].shape()[0];
size_t inC = outputs[0].shape()[1];
size_t inH = outputs[0].shape()[2];
size_t inW = outputs[0].shape()[3];
if (outputs[0].getArgType() != ADD_TO) {
// for unit test
typename Tensor<real, Device>::Vector tmp(
outputs[0].shape().getElements(), outputs[0].data<real>());
tmp.zero();
}
PadGrad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
num,
inC,
inH,
inW,
pad_);
}
private:
PadConf pad_;
};
REGISTER_TYPED_FUNC(Pad, CPU, PadFunc);
REGISTER_TYPED_FUNC(PadGrad, CPU, PadGradFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(Pad, GPU, PadFunc);
REGISTER_TYPED_FUNC(PadGrad, GPU, PadGradFunc);
#endif
} // namespace paddle
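For reference, the CPU forward pass above is equivalent to the following NumPy sketch. This is illustrative only and not part of this patch; the function name and the dict keys (which mirror PadConf) are hypothetical:

import numpy as np

def pad_forward(inputs, pad):
    # inputs is an NCHW array; pad mirrors PadConf with six integers.
    num, inC, inH, inW = inputs.shape
    outC = inC + pad['cstart'] + pad['cend']
    outH = inH + pad['hstart'] + pad['hend']
    outW = inW + pad['wstart'] + pad['wend']
    outputs = np.zeros((num, outC, outH, outW), dtype=inputs.dtype)
    outputs[:, pad['cstart']:pad['cstart'] + inC,
               pad['hstart']:pad['hstart'] + inH,
               pad['wstart']:pad['wstart'] + inW] = inputs
    return outputs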
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
struct PadConf {
/// how many values to add before the data along the channel dimension.
int channelStart;
/// how many values to add after the data along the channel dimension.
int channelEnd;
/// how many values to add before the data along the height dimension.
int heightStart;
/// how many values to add after the data along the height dimension.
int heightEnd;
/// how many values to add before the data along the width dimension.
int widthStart;
/// how many values to add after the data along the width dimension.
int widthEnd;
};
/**
 * \brief This function pads zeros to the input along the specified dimensions.
 *        The input and output are 4D tensors. Zeros are padded along the 2nd to
 *        the 4th dimension according to the pad argument.
 *
 * \param[out] outputs buffer that receives the padded results.
 * \param[in]  inputs  input data.
 * \param[in]  num     batch size of the input data.
 * \param[in]  inC     channel number of the input data.
 * \param[in]  inH     height of the input data.
 * \param[in]  inW     width of the input data.
 * \param[in]  pad     the padding config; contains the size along each
 *                     padded dimension.
*/
template <DeviceType Device>
void Pad(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW,
const PadConf& pad);
/**
* \brief Padding operation backward.
*
* \param[out] inGrad gradients of previous layer.
* \param[in] outGrad output gradients.
* \param[in] num batch size of input data.
* \param[in] inC channel number of input data.
* \param[in] inH height of input data.
 * \param[in]  inW     width of the input data.
 * \param[in]  pad     the padding config; contains the size along each
 *                     padded dimension.
*/
template <DeviceType Device>
void PadGrad(real* inGrad,
const real* outGrad,
const int num,
const int inC,
const int inH,
const int inW,
const PadConf& pad);
} // namespace paddle
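The backward rule declared above simply accumulates the output gradient at the un-padded positions into the input gradient; a minimal NumPy sketch (illustrative names, not part of this patch):

def pad_backward(in_grad, out_grad, pad):
    # in_grad is NCHW; out_grad has the padded shape; pad mirrors PadConf.
    num, inC, inH, inW = in_grad.shape
    in_grad += out_grad[:, pad['cstart']:pad['cstart'] + inC,
                           pad['hstart']:pad['hstart'] + inH,
                           pad['wstart']:pad['wstart'] + inW]
    return in_grad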
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "PadOp.h"
namespace paddle {
__global__ void KePad(real* outputs, const real* inputs,
int inC, int inH, int inW,
int padc, int padh, int padw,
int outC, int outH, int outW, int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
const int h = (idx / inW) % inH;
const int c = (idx / inW / inH) % inC;
const int n = idx / inW / inH / inC;
const int off = ((n * outC + c + padc) * outH + h + padh) * outW + padw + w;
outputs[off] = inputs[idx];
}
}
template <>
void Pad<DEVICE_TYPE_GPU>(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW,
const PadConf& pad) {
size_t nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024;
int cstart = pad.channelStart, cend = pad.channelEnd;
int hstart = pad.heightStart, hend = pad.heightEnd;
int wstart = pad.widthStart, wend = pad.widthEnd;
int outC = inC + cstart + cend;
int outH = inH + hstart + hend;
int outW = inW + wstart + wend;
KePad<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(outputs, inputs, inC, inH, inW, cstart, hstart, wstart,
outC, outH, outW, nth);
CHECK_SYNC("Pad");
}
__global__ void KePadDiff(real* inGrad, const real* outGrad,
int inC, int inH, int inW,
int padc, int padh, int padw,
int outC, int outH, int outW, int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
const int h = (idx / inW) % inH;
const int c = (idx / inW / inH) % inC;
const int n = idx / inW / inH / inC;
const int off = ((n * outC + c + padc) * outH + h + padh) * outW + padw + w;
inGrad[idx] += outGrad[off];
}
}
template <>
void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
const real* outGrad,
const int num,
const int inC,
const int inH,
const int inW,
const PadConf& pad) {
int nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024;
int cstart = pad.channelStart, cend = pad.channelEnd;
int hstart = pad.heightStart, hend = pad.heightEnd;
int wstart = pad.widthStart, wend = pad.widthEnd;
int outC = inC + cstart + cend;
int outH = inH + hstart + hend;
int outW = inW + wstart + wend;
KePadDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(inGrad, outGrad, inC, inH, inW, cstart, hstart, wstart,
outC, outH, outW, nth);
CHECK_SYNC("PadGrad");
}
} // namespace paddle
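Both GPU kernels launch one thread per input element and recover the (n, c, h, w) coordinates from the flat thread index. A Python sketch (hypothetical helper, for illustration) of that index arithmetic:

def input_to_output_offset(idx, inC, inH, inW, padc, padh, padw, outC, outH, outW):
    # decompose the flat input index into NCHW coordinates
    w = idx % inW
    h = (idx // inW) % inH
    c = (idx // (inW * inH)) % inC
    n = idx // (inW * inH * inC)
    # flat offset of the same element inside the padded output tensor
    return ((n * outC + c + padc) * outH + h + padh) * outW + padw + w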
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
TEST(Pad, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
FunctionCompare compare("Pad",
FuncConfig()
.set("cstart", 2)
.set("cend", 3)
.set("hstart", 1)
.set("hend", 2)
.set("wstart", 3)
.set("wend", 2));
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
TensorShape outDims{
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, inDims));
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, outDims, ASSIGN_TO));
compare.run();
}
}
}
}
}
TEST(PadGrad, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
FunctionCompare compare("PadGrad",
FuncConfig()
.set("cstart", 2)
.set("cend", 3)
.set("hstart", 1)
.set("hend", 2)
.set("wstart", 3)
.set("wend", 2));
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
TensorShape outDims{
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, outDims));
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, inDims, ASSIGN_TO));
compare.run();
}
}
}
}
}
} // namespace paddle
......@@ -55,6 +55,15 @@ public:
numElements();
}
void reshape(std::initializer_list<size_t> dims) {
ndims_ = dims.size();
if (ndims_ > kMinDims) {
dims_.resize(ndims_);
}
dims_.assign(dims);
numElements();
}
// number of dimensions of the tensor
size_t ndims() const { return ndims_; }
......@@ -82,7 +91,7 @@ private:
// init dims_
void initDims(size_t ndims) {
size_t count = ndims < 4 ? 4 : ndims;
size_t count = ndims < kMinDims ? kMinDims : ndims;
dims_.assign(count, 1);
}
......@@ -92,6 +101,7 @@ private:
// number of elements
size_t nelements_;
std::vector<size_t> dims_;
static const size_t kMinDims = 4;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PadLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(pad, PadLayer);
bool PadLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& pad_conf = config_.inputs(0).pad_conf();
auto& img_conf = pad_conf.image_conf();
CHECK_EQ(config_.inputs_size(), 1);
inDims_ = TensorShape(
{0,
img_conf.channels(),
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(),
img_conf.img_size()});
CHECK_EQ(2, pad_conf.pad_c_size());
CHECK_EQ(2, pad_conf.pad_h_size());
CHECK_EQ(2, pad_conf.pad_w_size());
padc_.push_back(pad_conf.pad_c(0));
padc_.push_back(pad_conf.pad_c(1));
padh_.push_back(pad_conf.pad_h(0));
padh_.push_back(pad_conf.pad_h(1));
padw_.push_back(pad_conf.pad_w(0));
padw_.push_back(pad_conf.pad_w(1));
outDims_ = TensorShape(4);
setOutDims(0);
createFunction(forward_,
"Pad",
FuncConfig()
.set("cstart", padc_[0])
.set("cend", padc_[1])
.set("hstart", padh_[0])
.set("hend", padh_[1])
.set("wstart", padw_[0])
.set("wend", padw_[1]));
createFunction(backward_,
"PadGrad",
FuncConfig()
.set("cstart", padc_[0])
.set("cend", padc_[1])
.set("hstart", padh_[0])
.set("hend", padh_[1])
.set("wstart", padw_[0])
.set("wend", padw_[1]));
return true;
}
void PadLayer::setOutDims(const size_t batchSize) {
outDims_.reshape({batchSize,
inDims_[1] + padc_[0] + padc_[1],
inDims_[2] + padh_[0] + padh_[1],
inDims_[3] + padw_[0] + padw_[1]});
}
void PadLayer::setTensorDim(const size_t batchSize) {
CHECK_EQ(static_cast<int>(inputLayers_.size()), 1);
inDims_.setDim(0, batchSize);
int h = inputLayers_[0]->getOutput().getFrameHeight();
if (h != 0) inDims_.setDim(2, h);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
setOutDims(batchSize);
}
void PadLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
setTensorDim(batchSize);
int size = outDims_[1] * outDims_[2] * outDims_[3];
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
REGISTER_TIMER_INFO("PadForward", getName().c_str());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), inDims_);
outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
void PadLayer::backward(const UpdateCallback& callback) {
(void)callback;
REGISTER_TIMER_INFO("PadBackward", getName().c_str());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outDims_);
outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
} // namespace paddle
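setOutDims above derives the output tensor shape directly from the per-dimension pad pairs; a small sketch (illustrative Python, mirroring padc_/padh_/padw_) of the same computation:

def pad_out_dims(n, c, h, w, padc, padh, padw):
    # each pad argument is a [before, after] pair, as in PadLayer
    return (n,
            c + padc[0] + padc[1],
            h + padh[0] + padh[1],
            w + padw[0] + padw[1])

# e.g. a channel pad of [1, 1] on a (2, 2, 2, 3) input yields (2, 4, 2, 3)
assert pad_out_dims(2, 2, 2, 3, [1, 1], [0, 0], [0, 0]) == (2, 4, 2, 3)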
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
 * \brief This layer pads the input with zeros along the specified dimensions.
 *        The input and output are 4D tensors. Zeros are padded along the 2nd to
 *        the 4th dimension according to padc_, padh_ and padw_.
*/
class PadLayer : public Layer {
public:
explicit PadLayer(const LayerConfig& config) : Layer(config) {}
~PadLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
protected:
void setOutDims(const size_t batchSize);
void setTensorDim(const size_t batchSize);
std::vector<int> padc_;
std::vector<int> padh_;
std::vector<int> padw_;
TensorShape inDims_;
TensorShape outDims_;
};
} // namespace paddle
......@@ -310,7 +310,11 @@ TEST(Layer, CTCLayer) {
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "ctc", 100, /* trans */ false, /* useGpu */ useGpu);
testLayerGrad(config,
"ctc",
100,
/* trans */ false, /* useGpu */
useGpu);
}
}
......@@ -587,7 +591,11 @@ TEST(Layer, hsigmoidLayer) {
config.layerConfig.add_inputs();
// Not support GPU now
testLayerGrad(config, "hsigmoid", 100, /* trans */ false, /* useGpu */ false);
testLayerGrad(config,
"hsigmoid",
100,
/* trans */ false, /* useGpu */
false);
}
TEST(Layer, multi_cross) {
......@@ -1022,8 +1030,12 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) {
}
TEST(Layer, NormLayer) {
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true);
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false);
testNormLayer("cmrnorm-projection",
/* trans= */ false, /* useGpu= */
true);
testNormLayer("cmrnorm-projection",
/* trans= */ false, /* useGpu= */
false);
}
void setPoolConfig(TestConfig* config,
......@@ -1563,6 +1575,35 @@ TEST(Layer, MultiplexLayer) {
}
}
TEST(Layer, PadLayer) {
TestConfig config;
config.biasSize = 0;
config.layerConfig.set_type("pad");
int c = 4;
int h = 31;
int w = 36;
size_t size = c * h * w;
config.inputDefs.push_back({INPUT_DATA, "layer_0", size, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
PadConfig* pad = input->mutable_pad_conf();
ImageConfig* image = pad->mutable_image_conf();
image->set_channels(c);
image->set_img_size(h);
image->set_img_size_y(w);
pad->add_pad_c(1);
pad->add_pad_c(2);
pad->add_pad_h(2);
pad->add_pad_h(3);
pad->add_pad_w(3);
pad->add_pad_w(5);
for (auto useGpu : {false, true}) {
testLayerGrad(config, "pad", 10, false, useGpu);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -255,6 +255,13 @@ message PriorBoxConfig {
repeated float variance = 4;
}
message PadConfig {
required ImageConfig image_conf = 1;
repeated uint32 pad_c = 2;
repeated uint32 pad_h = 3;
repeated uint32 pad_w = 4;
}
message LayerInputConfig {
required string input_layer_name = 1;
optional string input_parameter_name = 2;
......@@ -271,6 +278,7 @@ message LayerInputConfig {
optional MaxOutConfig maxout_conf = 11;
optional SppConfig spp_conf = 12;
optional PriorBoxConfig priorbox_conf = 13;
optional PadConfig pad_conf = 14;
}
message LayerConfig {
......
......@@ -493,6 +493,7 @@ class Input(Cfg):
block_expand=None,
maxout=None,
spp=None,
pad=None,
format=None,
nnz=None,
is_static=None,
......@@ -844,6 +845,12 @@ class SpatialPyramidPool(Cfg):
self.add_keys(locals())
@config_class
class Pad(Cfg):
def __init__(self, channels, pad_c, pad_h, pad_w):
self.add_keys(locals())
@config_class
class Norm(Cfg):
def __init__(self,
......@@ -1102,7 +1109,7 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
bilinear_conf.out_size_y = bilinear.out_size_y
def parse_pool(pool, input_layer_name, pool_conf):
def parse_pool(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf.pool_type = pool.pool_type
config_assert(pool.pool_type in [
'max-projection', 'avg-projection', 'cudnn-max-pool', 'cudnn-avg-pool'
......@@ -1127,10 +1134,10 @@ def parse_pool(pool, input_layer_name, pool_conf):
pool_conf.padding_y = default(pool.padding_y, pool_conf.padding)
pool_conf.output_x = cnn_output_size(pool_conf.img_size, pool_conf.size_x,
pool_conf.padding, pool_conf.stride,
False)
not ceil_mode)
pool_conf.output_y = cnn_output_size(pool_conf.img_size_y, pool_conf.size_y,
pool_conf.padding_y,
pool_conf.stride_y, False)
pool_conf.stride_y, not ceil_mode)
def parse_spp(spp, input_layer_name, spp_conf):
......@@ -1803,9 +1810,8 @@ class ConvTransLayer(ConvTransLayerBase):
@config_layer('norm')
class NormLayer(LayerBase):
def __init__(self, name, inputs, device=None):
super(NormLayer, self).__init__(
name, 'norm', 0, inputs=inputs, device=device)
def __init__(self, name, inputs, **xargs):
super(NormLayer, self).__init__(name, 'norm', 0, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
norm_conf = self.config.inputs[input_index].norm_conf
......@@ -1817,23 +1823,22 @@ class NormLayer(LayerBase):
@config_layer('pool')
class PoolLayer(LayerBase):
def __init__(self, name, inputs, device=None):
super(PoolLayer, self).__init__(
name, 'pool', 0, inputs=inputs, device=device)
def __init__(self, name, inputs, ceil_mode=True, **xargs):
super(PoolLayer, self).__init__(name, 'pool', 0, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
pool_conf = self.config.inputs[input_index].pool_conf
parse_pool(self.inputs[input_index].pool, input_layer.name,
pool_conf)
pool_conf, ceil_mode)
self.set_cnn_layer(name, pool_conf.output_y, pool_conf.output_x,
pool_conf.channels)
@config_layer('spp')
class SpatialPyramidPoolLayer(LayerBase):
def __init__(self, name, inputs, device=None):
def __init__(self, name, inputs, **xargs):
super(SpatialPyramidPoolLayer, self).__init__(
name, 'spp', 0, inputs=inputs, device=device)
name, 'spp', 0, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
spp_conf = self.config.inputs[input_index].spp_conf
......@@ -1842,6 +1847,25 @@ class SpatialPyramidPoolLayer(LayerBase):
self.set_cnn_layer(name, 1, output_x, spp_conf.image_conf.channels)
@config_layer('pad')
class PadLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
super(PadLayer, self).__init__(name, 'pad', 0, inputs=inputs, **xargs)
pad = self.inputs[0].pad
self.config.inputs[0].pad_conf.pad_c.extend(pad.pad_c)
self.config.inputs[0].pad_conf.pad_h.extend(pad.pad_h)
self.config.inputs[0].pad_conf.pad_w.extend(pad.pad_w)
input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].pad_conf.image_conf
parse_image(pad, input_layer.name, image_conf)
out_ch = pad.channels + pad.pad_c[0] + pad.pad_c[1]
out_h = image_conf.img_size_y + pad.pad_h[0] + pad.pad_h[1]
out_w = image_conf.img_size + pad.pad_w[0] + pad.pad_w[1]
self.set_cnn_layer(name, out_h, out_w, out_ch)
self.config.size = out_ch * out_h * out_w
@config_layer('batch_norm')
class BatchNormLayer(LayerBase):
layer_type = 'batch_norm'
......@@ -1851,7 +1875,6 @@ class BatchNormLayer(LayerBase):
inputs,
active_type="linear",
bias=True,
device=None,
use_global_stats=True,
moving_average_fraction=0.9,
batch_norm_type=None,
......@@ -1893,7 +1916,6 @@ class BatchNormLayer(LayerBase):
0,
active_type=active_type,
inputs=inputs,
device=device,
**xargs)
if use_global_stats is not None:
......@@ -1927,9 +1949,9 @@ class BatchNormLayer(LayerBase):
@config_layer('trans')
class TransLayer(LayerBase):
def __init__(self, name, inputs, device=None):
def __init__(self, name, inputs, **xargs):
super(TransLayer, self).__init__(
name, 'trans', 0, inputs=inputs, device=device)
name, 'trans', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'TransLayer must have one and only one input')
......@@ -1938,9 +1960,9 @@ class TransLayer(LayerBase):
@config_layer('resize')
class ResizeLayer(LayerBase):
def __init__(self, name, size, inputs, device=None):
def __init__(self, name, size, inputs, **xargs):
super(ResizeLayer, self).__init__(
name, 'resize', size=size, inputs=inputs, device=device)
name, 'resize', size=size, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'ResizeLayer must have one and only one input')
......@@ -1948,9 +1970,9 @@ class ResizeLayer(LayerBase):
@config_layer('blockexpand')
class BlockExpandLayer(LayerBase):
def __init__(self, name, inputs, device=None):
def __init__(self, name, inputs, **xargs):
super(BlockExpandLayer, self).__init__(
name, 'blockexpand', 0, inputs=inputs, device=device)
name, 'blockexpand', 0, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
parse_block_expand(
......
......@@ -108,6 +108,7 @@ __all__ = [
'print_layer',
'priorbox_layer',
'spp_layer',
'pad_layer',
]
......@@ -170,6 +171,7 @@ class LayerType(object):
BLOCK_EXPAND = "blockexpand"
MAXOUT = "maxout"
SPP_LAYER = "spp"
PAD_LAYER = "pad"
PRINT_LAYER = "print"
PRIORBOX_LAYER = "priorbox"
......@@ -1979,7 +1981,8 @@ def img_pool_layer(input,
layer_attr=None,
pool_size_y=None,
stride_y=None,
padding_y=None):
padding_y=None,
ceil_mode=True):
"""
Image pooling Layer.
......@@ -2010,6 +2013,23 @@ def img_pool_layer(input,
:type stride_y: int|None
:param layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:param ceil_mode: Whether to use ceil mode to calculate the output height and width.
Default is True. If set to False, floor mode is used instead.
- ceil_mode=True:
.. math::
w = 1 + int(ceil((input_width + 2 * padding - pool_size) / float(stride)))
h = 1 + int(ceil((input_height + 2 * padding_y - pool_size_y) / float(stride_y)))
- ceil_mode=False:
.. math::
w = 1 + int(floor((input_width + 2 * padding - pool_size) / float(stride)))
h = 1 + int(floor((input_height + 2 * padding_y - pool_size_y) / float(stride_y)))
:type ceil_mode: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
......@@ -2047,6 +2067,7 @@ def img_pool_layer(input,
stride_y=stride_y,
padding_y=padding_y))
],
ceil_mode=ceil_mode,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name,
......@@ -3488,9 +3509,6 @@ def conv_projection(input,
groups=1,
param_attr=None):
"""
ConvProjection with a layer as input.
It performs element-wise multiplication with weight.
Different from img_conv_layer and conv_op, conv_projection is a Projection,
which can be used in mixed_layer and concat_layer. It uses cudnn to implement
conv and only supports GPU mode.
......@@ -3499,7 +3517,7 @@ def conv_projection(input,
.. code-block:: python
proj = conv_projection(img=input1,
proj = conv_projection(input=input1,
filter_size=3,
num_filters=64,
num_channels=64)
......@@ -3582,6 +3600,109 @@ def conv_projection(input,
return proj
@wrap_name_default("pad")
@layer_support()
def pad_layer(input,
pad_c=None,
pad_h=None,
pad_w=None,
name=None,
layer_attr=None):
"""
This operation pads zeros to the input data according to pad_c, pad_h
and pad_w. pad_c, pad_h and pad_w specify the padding size along the
channel, height and width dimensions, respectively. The input data shape
is NCHW.
For example, pad_c=[2,3] means padding 2 zeros before the input data and
3 zeros after it along the channel dimension. pad_h pads zeros along the
height dimension, and pad_w pads zeros along the width dimension.
For example,
.. code-block::
input(2,2,2,3) = [
[ [[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]] ],
[ [[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]] ]
]
pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
output(2,4,2,3) = [
[ [[0,0,0], [0,0,0]],
[[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]],
[[0,0,0], [0,0,0]] ],
[ [[0,0,0], [0,0,0]],
[[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]],
[[0,0,0], [0,0,0]] ]
]
A simple usage is:
.. code-block:: python
pad = pad_layer(input=ipt,
pad_c=[4,4],
pad_h=[0,0],
pad_w=[2,2])
:param input: layer's input.
:type input: LayerOutput
:param pad_c: padding size in channel dimension.
:type pad_c: list|None
:param pad_h: padding size in height dimension.
:type pad_h: list|None
:param pad_w: padding size in width dimension.
:type pad_w: list|None
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:param name: layer name.
:type name: basestring
:return: LayerOutput object.
:rtype: LayerOutput
"""
if pad_c is not None:
assert isinstance(pad_c, collections.Sequence) and len(pad_c) == 2
else:
pad_c = [0, 0]
if pad_h is not None:
assert isinstance(pad_h, collections.Sequence) and len(pad_h) == 2
else:
pad_h = [0, 0]
if pad_w is not None:
assert isinstance(pad_w, collections.Sequence) and len(pad_w) == 2
else:
pad_w = [0, 0]
assert input.num_filters is not None
in_ch = input.num_filters
out_ch = in_ch + pad_c[0] + pad_c[1]
l = Layer(
name=name,
type=LayerType.PAD_LAYER,
inputs=Input(
input.name,
pad=Pad(
channels=in_ch,
pad_c=pad_c,
pad_h=pad_h,
pad_w=pad_w, )),
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name,
layer_type=LayerType.PAD_LAYER,
parents=[input],
num_filters=out_ch,
size=l.config.size)
@wrap_name_default()
@layer_support()
def conv_shift_layer(a, b, name=None, layer_attr=None):
......
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
data = data_layer(name='data', size=2304, height=48, width=42)
conv = img_conv_layer(
input=data,
filter_size=3,
num_channels=1,
num_filters=16,
padding=1,
act=LinearActivation(),
bias_attr=True)
pool = img_pool_layer(
input=conv, num_channels=8, pool_size=2, stride=2, pool_type=MaxPooling())
pad = pad_layer(input=pool, pad_c=[2, 3], pad_h=[1, 2], pad_w=[3, 1])
outputs(pad)