提交 29f25fbe 编写于 作者: W wanghaoshuang

Add pixel softmax layer for FCN model

1. Add switch function for switching image dimensions order
2. Add CpuMatrix::backwardSoftmax function
3. Add pixel softmax layer, python wrapper and grad_test
上级 98378968
......@@ -37,6 +37,7 @@ if(WITH_GPU)
add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest)
add_simple_unittest(SwitchOpTest)
endif()
add_simple_unittest(ConvOpTest)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "SwitchOp.h"
#include "paddle/math/Vector.h"
namespace paddle {
/**
 * \brief CPU implementation of the NCHW -> NHWC layout switch.
 *
 * The input buffer is read linearly in (n, c, h, w) order and every element
 * is written to outputs[((n * inH + h) * inW + w) * inC + c].
 *
 * \param[out] outputs destination buffer, filled in NHWC order.
 * \param[in]  inputs  source buffer, read in NCHW order.
 * \param[in]  num     batch size.
 * \param[in]  inC     number of channels.
 * \param[in]  inH     image height.
 * \param[in]  inW     image width.
 */
template <>
void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inC,
                                const int inH,
                                const int inW) {
  int src = 0;  // linear read position, advanced once per element
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < inC; ++c) {
      for (int h = 0; h < inH; ++h) {
        // Index of the first pixel of row h of sample n (in pixels).
        const int rowBase = (n * inH + h) * inW;
        for (int w = 0; w < inW; ++w) {
          outputs[(rowBase + w) * inC + c] = inputs[src++];
        }
      }
    }
  }
}
/**
 * \brief CPU implementation of the NHWC -> NCHW layout switch.
 *
 * The input buffer is read linearly in (n, h, w, c) order and every element
 * is written to outputs[((n * inC + c) * inH + h) * inW + w].
 *
 * \param[out] outputs destination buffer, filled in NCHW order.
 * \param[in]  inputs  source buffer, read in NHWC order.
 * \param[in]  num     batch size.
 * \param[in]  inH     image height.
 * \param[in]  inW     image width.
 * \param[in]  inC     number of channels.
 */
template <>
void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inH,
                                const int inW,
                                const int inC) {
  int src = 0;  // linear read position, advanced once per element
  for (int n = 0; n < num; ++n) {
    for (int h = 0; h < inH; ++h) {
      for (int w = 0; w < inW; ++w) {
        for (int c = 0; c < inC; ++c) {
          // Index of row h inside channel plane c of sample n.
          const int planeRow = (n * inC + c) * inH + h;
          outputs[planeRow * inW + w] = inputs[src++];
        }
      }
    }
  }
}
/**
 * \brief Function that switches the dimension order of a 4D tensor from
 *        'batch_size, channels, height, width' (NCHW) to
 *        'batch_size, height, width, channels' (NHWC).
 *
 * Argument in this Function:
 * \param inputs  One 4D tensor. Its shape is read as
 *                (batch_size, channels, height, width).
 * \param outputs One tensor with the same number of elements as the input;
 *                it receives the data in NHWC order. Only its element count
 *                and data pointer are used.
 */
template <DeviceType Device>
class NCHW2NHWCFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    // The switch is a pure permutation: every output element is written
    // exactly once, provided both buffers hold the same number of elements.
    // That also makes zero-filling the output beforehand unnecessary.
    CHECK_EQ(inputs[0].shape().getElements(),
             outputs[0].shape().getElements());

    // The device implementations take int dimensions; narrow explicitly.
    const int num = static_cast<int>(inputs[0].shape()[0]);
    const int inC = static_cast<int>(inputs[0].shape()[1]);
    const int inH = static_cast<int>(inputs[0].shape()[2]);
    const int inW = static_cast<int>(inputs[0].shape()[3]);

    NCHW2NHWC<Device>(
        outputs[0].data<real>(), inputs[0].data<real>(), num, inC, inH, inW);
  }
};
/**
 * \brief Function that switches the dimension order of a 4D tensor from
 *        'batch_size, height, width, channels' (NHWC) to
 *        'batch_size, channels, height, width' (NCHW).
 *
 * Argument in this Function:
 * \param inputs  One 4D tensor. Its shape is read as
 *                (batch_size, height, width, channels).
 * \param outputs One tensor with the same number of elements as the input;
 *                it receives the data in NCHW order. Only its data pointer
 *                and element count are used.
 */
template <DeviceType Device>
class NHWC2NCHWFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    // The switch is a pure permutation, so both buffers must hold the same
    // number of elements; otherwise the write would run out of bounds.
    CHECK_EQ(inputs[0].shape().getElements(),
             outputs[0].shape().getElements());

    // The device implementations take int dimensions; narrow explicitly.
    const int num = static_cast<int>(inputs[0].shape()[0]);
    const int inH = static_cast<int>(inputs[0].shape()[1]);
    const int inW = static_cast<int>(inputs[0].shape()[2]);
    const int inC = static_cast<int>(inputs[0].shape()[3]);

    NHWC2NCHW<Device>(
        outputs[0].data<real>(), inputs[0].data<real>(), num, inH, inW, inC);
  }
};
// Register both layout-switch functions for the CPU device under the names
// "NCHW2NHWC" and "NHWC2NCHW".
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
#ifndef PADDLE_ONLY_CPU
// The GPU variants are registered only when PADDLE_ONLY_CPU is not defined.
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
 * \brief This function switches the dimension order of an image input.
 * The input and output are 4D tensors. Switches order 'batch_size,
 * channels, height, width' to order 'batch_size, height, width, channels'.
 *
 * \param[out] outputs save results.
 * \param[in] inputs input data.
 * \param[in] num batch size of input data.
 * \param[in] inC channel number of input data.
 * \param[in] inH height of input data.
 * \param[in] inW width of input data.
 */
template <DeviceType Device>
void NCHW2NHWC(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW);
/**
* \brief This funtion switch dimension order of image input.
* The input and output is a 4D tensor. Switch order 'batch_size,
*height, width, channels' to
* order 'batch_size, channels, height, width'.
*
* \param[out] inGrad gradients of previous layer.
* \param[in] outGrad output gradients.
* \param[in] num batch size of input data.
* \param[in] inH height of input data.
* \param[in] inW with of input data.
* \param[in] inC channel number of input data.
*/
template <DeviceType Device>
void NHWC2NCHW(real* inGrad,
const real* outGrad,
const int num,
const int inH,
const int inW,
const int inC);
} // namespace paddle
/* Copyright (c) 2016 Paddle
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "SwitchOp.h"
namespace paddle {
/**
 * \brief Kernel for the NCHW -> NHWC switch: each thread with a global index
 *        below nthreads copies one element.
 *
 * The thread index is treated as the linear NCHW index of the element and is
 * decomposed into (n, c, h, w) to compute the NHWC destination offset.
 */
__global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
                            int inC, int inH, int inW,
                            int nthreads) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= nthreads) {
    return;
  }
  // Peel the coordinates off the linear index, fastest dimension first.
  int rest = idx;
  const int w = rest % inW;
  rest /= inW;
  const int h = rest % inH;
  rest /= inH;
  const int c = rest % inC;
  const int n = rest / inC;
  outputs[((n * inH + h) * inW + w) * inC + c] = inputs[idx];
}
/**
 * \brief GPU implementation of the NCHW -> NHWC layout switch.
 *
 * Launches KeNCHW2NHWC with one thread per element on STREAM_DEFAULT.
 *
 * \param[out] outputs destination buffer, filled in NHWC order.
 * \param[in]  inputs  source buffer, read in NCHW order.
 * \param[in]  num     batch size.
 * \param[in]  inC     number of channels.
 * \param[in]  inH     image height.
 * \param[in]  inW     image width.
 */
template <>
void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inC,
                                const int inH,
                                const int inW) {
  // The kernel takes the element count as int, so compute it as int.
  const int nth = num * inC * inH * inW;
  if (nth <= 0) {
    // Nothing to copy. A grid with zero blocks is not a valid launch
    // configuration, so skip the launch entirely.
    return;
  }
  const int blockSize = 1024;
  // Ceiling division written so that it cannot overflow for large nth.
  const int gridSize = (nth - 1) / blockSize + 1;
  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
      (outputs, inputs, inC, inH, inW, nth);
  CHECK_SYNC("NCHW2NHWC");
}
/**
 * \brief Kernel for the NHWC -> NCHW switch: each thread with a global index
 *        below nthreads copies one element.
 *
 * The thread index is treated as the linear NHWC index of the element and is
 * decomposed into (n, h, w, c) to compute the NCHW destination offset.
 */
__global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
                            int inH, int inW, int inC,
                            int nthreads) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= nthreads) {
    return;
  }
  // Peel the coordinates off the linear index, fastest dimension first.
  int rest = idx;
  const int c = rest % inC;
  rest /= inC;
  const int w = rest % inW;
  rest /= inW;
  const int h = rest % inH;
  const int n = rest / inH;
  outputs[((n * inC + c) * inH + h) * inW + w] = inputs[idx];
}
/**
 * \brief GPU implementation of the NHWC -> NCHW layout switch.
 *
 * Launches KeNHWC2NCHW with one thread per element on STREAM_DEFAULT.
 *
 * \param[out] outputs destination buffer, filled in NCHW order.
 * \param[in]  inputs  source buffer, read in NHWC order.
 * \param[in]  num     batch size.
 * \param[in]  inH     image height.
 * \param[in]  inW     image width.
 * \param[in]  inC     number of channels.
 */
template <>
void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inH,
                                const int inW,
                                const int inC) {
  const int nth = num * inC * inH * inW;
  if (nth <= 0) {
    // Nothing to copy. A grid with zero blocks is not a valid launch
    // configuration, so skip the launch entirely.
    return;
  }
  const int blockSize = 1024;
  // Ceiling division written so that it cannot overflow for large nth.
  const int gridSize = (nth - 1) / blockSize + 1;
  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
      (outputs, inputs, inH, inW, inC, nth);
  CHECK_SYNC("NHWC2NCHW");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
// Compares the CPU and GPU implementations of both layout-switch functions
// ("NCHW2NHWC" and "NHWC2NCHW") over a grid of batch, channel and image sizes.
// The suite is named after the functions under test rather than "Pad".
TEST(SwitchOp, real) {
  for (size_t numSamples : {1, 4, 8, 16}) {
    for (size_t channels : {1, 4, 8, 16}) {
      for (size_t imgSizeH : {1, 4, 8, 16}) {
        for (size_t imgSizeW : {1, 4, 8, 16}) {
          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
          const TensorShape nchwDims{numSamples, channels, imgSizeH, imgSizeW};
          const TensorShape nhwcDims{numSamples, imgSizeH, imgSizeW, channels};
          // true: NHWC -> NCHW, false: NCHW -> NHWC.
          for (bool toNchw : {true, false}) {
            CpuGpuFuncCompare compare(toNchw ? "NHWC2NCHW" : "NCHW2NHWC",
                                      FuncConfig());
            compare.addInputs(
                BufferArg(VALUE_TYPE_FLOAT, toNchw ? nhwcDims : nchwDims));
            compare.addOutputs(BufferArg(
                VALUE_TYPE_FLOAT, toNchw ? nchwDims : nhwcDims, ASSIGN_TO));
            compare.run();
          }
        }
      }
    }
  }
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PixelSoftmaxLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(pixel_softmax, PixelSoftmaxLayer);

/**
 * \brief Reads the image geometry of input 0, creates the two layout-switch
 *        functions and prepares the tensor shapes used by forward/backward.
 *
 * \param layerMap     forwarded to Layer::init.
 * \param parameterMap forwarded to Layer::init.
 * \return false if the base class initialization fails, true otherwise.
 */
bool PixelSoftmaxLayer::init(const LayerMap& layerMap,
                             const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  if (!Layer::init(layerMap, parameterMap)) {
    return false;
  }

  const auto& img_conf = config_.inputs(0).image_conf();
  // img_size_y is optional; fall back to the (square) img_size.
  inH_ =
      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
  inW_ = img_conf.img_size();
  inC_ = img_conf.channels();

  createFunction(forward_, "NCHW2NHWC", FuncConfig());
  createFunction(backward_, "NHWC2NCHW", FuncConfig());

  // inDims_ describes the layer-side buffers, which are in NCHW order and are
  // what "NCHW2NHWC" reads; outDims_ describes the temporary NHWC buffers of
  // shape (batch * H * W) x C on which the softmax runs. The batch dimension
  // is filled in by forward(). With the two shapes the other way round the
  // switch functions would read H as the channel count and the softmax would
  // not be taken over channels.
  inDims_ = TensorShape({0, inC_, inH_, inW_});
  outDims_ = TensorShape({0, inH_, inW_, inC_});
  return true;
}
/**
 * \brief Forward pass: switch the input to the temporary layout, apply
 *        softmax to every row of the temporary matrix, and switch the result
 *        back into the layer output.
 */
void PixelSoftmaxLayer::forward(PassType passType) {
  Layer::forward(passType);

  const size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();

  // Both temporaries have one row per pixel and one column per channel.
  Matrix::resizeOrCreate(
      tmpInput_, batchSize * inH_ * inW_, inC_, false, useGpu_);
  Matrix::resizeOrCreate(
      tmpOutput_, batchSize * inH_ * inW_, inC_, false, useGpu_);
  tmpOutput_->zeroMem();
  resetOutput(batchSize, inH_ * inW_ * inC_);

  inDims_.setDim(0, batchSize);
  outDims_.setDim(0, batchSize);

  // switch NCHW to NHWC: input value -> tmpInput_
  BufferArgs switchIn;
  BufferArgs switchOut;
  switchIn.addArg(*getInputValue(0), inDims_);
  switchOut.addArg(*tmpInput_, outDims_);
  forward_[0]->calc(switchIn, switchOut);

  // softmax forward; the result is kept in tmpOutput_ for the backward pass
  tmpInput_->softmax(*tmpOutput_);

  // switch NHWC to NCHW: tmpOutput_ -> layer output value
  BufferArgs restoreIn;
  BufferArgs restoreOut;
  restoreIn.addArg(*tmpOutput_, outDims_);
  restoreOut.addArg(*getOutputValue(), inDims_);
  backward_[0]->calc(restoreIn, restoreOut);
}
void PixelSoftmaxLayer::backward(const UpdateCallback& callback) {
(void)callback;
REGISTER_TIMER_INFO("PixelSoftmaxBackward", getName().c_str());
// switch NCHW to NHWC
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), inDims_);
outputs.addArg(*tmpInput_, outDims_);
forward_[0]->calc(inputs, outputs);
// softmax backward and save grad result into tmpOutput_
tmpInput_->softmaxBackward(*tmpOutput_);
// switch NHWC to NCHW
BufferArgs inputs_1;
BufferArgs outputs_1;
inputs_1.addArg(*tmpInput_, outDims_);
outputs_1.addArg(*getInputGrad(0), inDims_);
backward_[0]->calc(inputs_1, outputs_1);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
 * \brief Layer that computes softmax along the channel dimension of an image
 *        input.
 *
 * The data is switched into a temporary matrix with one column per channel,
 * softmax is applied to every row, and the result is switched back.
 */
class PixelSoftmaxLayer : public Layer {
public:
  explicit PixelSoftmaxLayer(const LayerConfig& config) : Layer(config) {}

  ~PixelSoftmaxLayer() = default;

  /// Reads the image configuration and creates the layout-switch functions.
  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;

  void backward(const UpdateCallback& callback = nullptr) override;

protected:
  /// Channel count, height and width taken from the input's image config.
  uint32_t inC_;
  uint32_t inH_;
  uint32_t inW_;

  /// Shape passed along with the layer-side buffers (values and gradients).
  TensorShape inDims_;
  /// Shape passed along with the temporary buffers below.
  TensorShape outDims_;

  /// Temporary: switched input in forward, switched gradient in backward.
  MatrixPtr tmpInput_;
  /// Temporary: softmax result of the forward pass.
  MatrixPtr tmpOutput_;
};
} // namespace paddle
......@@ -1792,6 +1792,25 @@ TEST(Layer, RowConvLayer) {
}
}
// Gradient check of the pixel_softmax layer on a 4-channel 16x16 image
// (4 * 16 * 16 = 1024 input values), run without and with GPU.
TEST(Layer, PixelSoftmaxLayer) {
  TestConfig config;

  // config input_0
  config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
  ImageConfig* imageConf =
      config.layerConfig.add_inputs()->mutable_image_conf();
  imageConf->set_channels(4);
  imageConf->set_img_size(16);
  imageConf->set_img_size_y(16);

  // config softmax layer
  config.layerConfig.set_type("pixel_softmax");
  config.layerConfig.set_name("pixelSofrmaxLayer");

  for (bool onGpu : {false, true}) {
    testLayerGrad(config, "pixel_softmax", 100, false, onGpu, true, 2);
  }
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -3385,6 +3385,27 @@ void CpuMatrix::oneHotCrossEntropyWithSelfNormBp(Matrix& output,
real* out = output.getData(); \
for (size_t i = 0; i < numSamples; ++i, grad += dim, out += dim)
/**
 * \brief Softmax backward step on this matrix, using outputV as the second
 *        operand.
 *
 * Computes the element-wise product of this matrix and outputV into
 * sftmaxDot_, merges it into the single-column matrix sftmaxSum_, and then
 * calls softmaxDerivative(outputV, *sftmaxSum_) on this matrix.
 *
 * \param outputV CPU matrix with the same height and width as this matrix.
 */
void CpuMatrix::softmaxBackward(Matrix& outputV) {
  CHECK(!outputV.useGpu()) << "Matrix type are not equal";
  CHECK(getHeight() == outputV.getHeight() &&
        getWidth() == outputV.getWidth())
      << "Matrix dimensions are not equal";

  // Scratch matrices are members, so they are reused across calls.
  Matrix::resizeOrCreate(
      sftmaxDot_, height_, width_, /* trans */ false, useGpu_);
  Matrix::resizeOrCreate(sftmaxSum_, height_, 1, /* trans */ false, useGpu_);

  sftmaxDot_->dotMul(*this, outputV);
  sftmaxSum_->colMerge(*sftmaxDot_);
  softmaxDerivative(outputV, *sftmaxSum_);
}
void CpuMatrix::softmax(Matrix& output) {
CHECK(!output.useGpu());
......
......@@ -1456,6 +1456,10 @@ public:
};
class CpuMatrix : public Matrix {
private:
MatrixPtr sftmaxSum_;
MatrixPtr sftmaxDot_;
public:
CpuMatrix(size_t height, size_t width, bool trans = false);
CpuMatrix(real* data, size_t height, size_t width, bool trans = false)
......@@ -1728,6 +1732,7 @@ public:
Matrix& prevGrad2);
void softmax(Matrix& output);
void softmaxBackward(Matrix& outputV);
void sequenceSoftmax(Matrix& output, const IVector& index);
void softmaxDerivative(Matrix& output, Matrix& sftmaxSum);
......
......@@ -3171,6 +3171,22 @@ class RecurrentLayerGroup(LayerBase):
name, 'recurrent_layer_group', 0, inputs=[], device=device)
@config_layer('pixel_softmax')
class PixelSoftmaxLayer(LayerBase):
    """Config for the 'pixel_softmax' layer.

    Fills the image_conf of input 0 (width, height, channels) from the input
    layer and marks this layer as a CNN layer with the same geometry.

    :param name: name of this layer.
    :param inputs: inputs of this layer, forwarded to LayerBase. The keyword
                   is `inputs` because that is what the body and the
                   `Layer(inputs=..., name=..., type=...)` call in the
                   wrapper use; the previous signature took `input` and then
                   referenced the undefined name `inputs`.
    :param xargs: extra keyword arguments forwarded to LayerBase.
    """

    def __init__(self, name, inputs, **xargs):
        super(PixelSoftmaxLayer, self).__init__(
            name, 'pixel_softmax', 0, inputs=inputs, **xargs)
        input_layer = self.get_input_layer(0)
        image_conf = self.config.inputs[0].image_conf
        image_conf.img_size = input_layer.width
        image_conf.img_size_y = input_layer.height
        # Floor division keeps the channel count an integer.
        image_conf.channels = input_layer.size // (input_layer.width *
                                                   input_layer.height)
        self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
                           image_conf.channels)
# Deprecated, use a new layer specific class instead
@config_func
def Layer(name, type, **xargs):
......
......@@ -217,6 +217,7 @@ class LayerType(object):
SMOOTH_L1 = 'smooth_l1'
PRELU = 'prelu'
PIXEL_SOFTMAX_LAYER = 'pixel_softmax'
@staticmethod
def is_layer_type(type_name):
......@@ -5853,3 +5854,40 @@ def prelu_layer(input,
layer_type=LayerType.PRELU,
parents=input,
size=l.config.size)
@layer_support()
@wrap_name_default('pixel_softmax')
def pixel_softmax_layer(input, name=None, layer_attr=None):
    """
    This layer calculates softmax in the image channel dimension.

    The example usage is:

    ..  code-block:: python

        softmax = pixel_softmax_layer(input=layer, name='softmax')

    :param input: The input layer.
    :type input: LayerOutput
    :param name: Name of this layer.
    :type name: basestring
    :param layer_attr: Extra layer attributes, converted to keyword
                       arguments with ExtraLayerAttribute.to_kwargs.
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    # A single LayerOutput or Projection is wrapped into a list; anything
    # else must already be a sequence.
    if isinstance(input, (LayerOutput, Projection)):
        input = [input]
    else:
        assert isinstance(input, collections.Sequence)

    input_names = [x.name for x in input]
    extra_kwargs = ExtraLayerAttribute.to_kwargs(layer_attr)
    l = Layer(
        inputs=input_names,
        name=name,
        type=LayerType.PIXEL_SOFTMAX_LAYER,
        **extra_kwargs)
    return LayerOutput(
        name=name,
        layer_type=LayerType.PIXEL_SOFTMAX_LAYER,
        parents=input,
        size=l.config.size)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册