提交 1cf9800f 编写于 作者: W whs 提交者: GitHub

Merge pull request #2788 from wanghaoshuang/pixel_softmax_layer

Add switch order layer for FCN model
......@@ -44,6 +44,7 @@ if(WITH_GPU)
add_simple_unittest(RowConvOpTest)
add_simple_unittest(BlockExpandOpTest)
add_simple_unittest(CropOpTest)
add_simple_unittest(SwitchOpTest)
endif()
add_simple_unittest(Im2ColTest)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "SwitchOp.h"
#include "paddle/math/Vector.h"
namespace paddle {
template <>
void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW,
const int argType) {
for (int n = 0; n < num; ++n) {
for (int c = 0; c < inC; ++c) {
for (int h = 0; h < inH; ++h) {
for (int w = 0; w < inW; ++w) {
if (argType == ADD_TO) {
outputs[((n * inH + h) * inW + w) * inC + c] += *(inputs++);
} else {
outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
}
}
}
}
}
}
template <>
void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
const real* inputs,
const int num,
const int inH,
const int inW,
const int inC,
const int argType) {
for (int n = 0; n < num; ++n) {
for (int h = 0; h < inH; ++h) {
for (int w = 0; w < inW; ++w) {
for (int c = 0; c < inC; ++c) {
if (argType == ADD_TO) {
outputs[((n * inC + c) * inH + h) * inW + w] += *(inputs++);
} else {
outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
}
}
}
}
}
}
/**
* \brief Switch dimension order of image input.
* The input and output is a 4D tensor. Switch order
* 'batch_size,channels, height, width' to
* order 'batch_size, height, width, channels'.
*
* Argument in this Function:
* \param inputs input data with order 'batch_size,channels, height, width'.
* \param outputs output data with order 'batch_size, height, width, channels'.
*/
template <DeviceType Device>
class NCHW2NHWCFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
size_t num = inputs[0].shape()[0];
size_t inC = inputs[0].shape()[1];
size_t inH = inputs[0].shape()[2];
size_t inW = inputs[0].shape()[3];
NCHW2NHWC<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
num,
inC,
inH,
inW,
outputs[0].getArgType());
}
};
/**
* \brief Switch dimension order of image input.
* The input and output is a 4D tensor. Switch order
* 'batch_size, height, width, channels' to
* order 'batch_size, channels, height, width'.
*
* Argument in this Function:
* \param inputs input data with order 'batch_size, height, width, channels'.
* \param outputs output data with order 'batch_size, channels, height, width'.
*/
template <DeviceType Device>
class NHWC2NCHWFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
size_t num = inputs[0].shape()[0];
size_t inH = inputs[0].shape()[1];
size_t inW = inputs[0].shape()[2];
size_t inC = inputs[0].shape()[3];
NHWC2NCHW<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
num,
inH,
inW,
inC,
outputs[0].getArgType());
}
};
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief This funtion switch dimension order of image input.
* The input and output is a 4D tensor. Switch order 'batch_size,
*channels, height, width' to
* order 'batch_size, height, width, channels'.
*
* \param[out] outputs save results.
* \param[in] inputs input data.
* \param[in] num batch size of input data.
* \param[in] inC channel number of input data.
* \param[in] inH height of input data.
* \param[in] inH with of input data.
* \param[in] argType type of output argument.
*/
template <DeviceType Device>
void NCHW2NHWC(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW,
const int argtype);
/**
* \brief This funtion switch dimension order of image input.
* The input and output is a 4D tensor. Switch order 'batch_size,
*height, width, channels' to
* order 'batch_size, channels, height, width'.
*
* \param[out] inGrad gradients of previous layer.
* \param[in] outGrad output gradients.
* \param[in] num batch size of input data.
* \param[in] inH height of input data.
* \param[in] inW with of input data.
* \param[in] inC channel number of input data.
* \param[in] argType type of output argument.
*/
template <DeviceType Device>
void NHWC2NCHW(real* inGrad,
const real* outGrad,
const int num,
const int inH,
const int inW,
const int inC,
const int argType);
} // namespace paddle
/* Copyright (c) 2016 Paddle
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "SwitchOp.h"
#include "hl_base.h"
namespace paddle {
__global__ void KeNCHW2NHWC(real* outputs,
const real* inputs,
int inC,
int inH,
int inW,
int nthreads,
int argType) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
const int h = (idx / inW) % inH;
const int c = (idx / inW / inH) % inC;
const int n = idx / inW / inH / inC;
const int off = ((n * inH + h) * inW + w) * inC + c;
if (argType == ADD_TO) {
outputs[off] += inputs[idx];
} else {
outputs[off] = inputs[idx];
}
}
}
template <>
void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
const real* inputs,
const int num,
const int inC,
const int inH,
const int inW,
const int argType) {
size_t nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024;
KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
outputs, inputs, inC, inH, inW, nth, argType);
CHECK_SYNC("NCHW2NHWC");
}
__global__ void KeNHWC2NCHW(real* outputs,
const real* inputs,
int inH,
int inW,
int inC,
int nthreads,
int argType) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int c = idx % inC;
const int w = (idx / inC) % inW;
const int h = (idx / inC / inW) % inH;
const int n = idx / inW / inH / inC;
const int off = ((n * inC + c) * inH + h) * inW + w;
if (argType == ADD_TO) {
outputs[off] += inputs[idx];
} else {
outputs[off] = inputs[idx];
}
}
}
template <>
void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
const real* inputs,
const int num,
const int inH,
const int inW,
const int inC,
const int argType) {
int nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024;
KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
outputs, inputs, inH, inW, inC, nth, argType);
CHECK_SYNC("NHWC2NCHW");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
TEST(Pad, real) {
for (size_t numSamples : {1, 4, 8, 16}) {
for (size_t channels : {1, 4, 8, 16}) {
for (size_t imgSizeH : {1, 4, 8, 16}) {
for (size_t imgSizeW : {1, 4, 8, 16}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
for (bool test_grad : {true, false}) {
CpuGpuFuncCompare compare(test_grad ? "NHWC2NCHW" : "NCHW2NHWC",
FuncConfig());
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
TensorShape outDims{numSamples, imgSizeH, imgSizeW, channels};
compare.addInputs(
BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
compare.addOutputs(BufferArg(
VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
compare.run();
}
}
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "SwitchOrderLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(switch_order, SwitchOrderLayer);
bool SwitchOrderLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
auto& reshape_conf = config_.reshape_conf();
for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
heightAxis_.push_back(reshape_conf.heightaxis(i));
}
for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
widthAxis_.push_back(reshape_conf.widthaxis(i));
}
createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
return true;
}
void SwitchOrderLayer::setOutDims() {
outDims_.setDim(0, inDims_[0]);
outDims_.setDim(1, inDims_[2]);
outDims_.setDim(2, inDims_[3]);
outDims_.setDim(3, inDims_[1]);
reshapeHeight_ = 1;
for (size_t i = 0; i < heightAxis_.size(); i++) {
reshapeHeight_ *= outDims_[heightAxis_[i]];
}
output_.setFrameHeight(reshapeHeight_);
reshapeWidth_ = 1;
for (size_t i = 0; i < widthAxis_.size(); i++) {
reshapeWidth_ *= outDims_[widthAxis_[i]];
}
output_.setFrameWidth(reshapeWidth_);
}
void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
int h = inputLayers_[0]->getOutput().getFrameHeight();
if (h != 0) inDims_.setDim(2, h);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
int channels = totalCount / (inDims_[0] * inDims_[2] * inDims_[3]);
if (channels != 0) inDims_.setDim(1, channels);
}
void SwitchOrderLayer::forward(PassType passType) {
Layer::forward(passType);
setInDims();
setOutDims();
resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
if (heightAxis_.size() > 0) {
getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
getOutputGrad()->reshape(reshapeHeight_, reshapeWidth_);
}
// switch NCHW to NHWC
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), inDims_);
outputs.addArg(*getOutputValue(), outDims_);
nchw2nhwc_[0]->calc(inputs, outputs);
forwardActivation();
}
void SwitchOrderLayer::backward(const UpdateCallback& callback) {
(void)callback;
backwardActivation();
// switch NHWC to NCHW
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outDims_);
outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
nhwc2nchw_[0]->calc(inputs, outputs);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
* \brief This layer calculate softmax in image channel dimension.
*/
class SwitchOrderLayer : public Layer {
public:
explicit SwitchOrderLayer(const LayerConfig& config) : Layer(config) {}
~SwitchOrderLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void setInDims();
void setOutDims();
protected:
std::vector<std::shared_ptr<FunctionBase>> nchw2nhwc_;
std::vector<std::shared_ptr<FunctionBase>> nhwc2nchw_;
TensorShape inDims_;
TensorShape outDims_;
std::vector<int> heightAxis_;
std::vector<int> widthAxis_;
size_t reshapeHeight_;
size_t reshapeWidth_;
};
} // namespace paddle
......@@ -2008,6 +2008,31 @@ TEST(Layer, CropLayer) {
}
}
TEST(Layer, SwitchOrderLayer) {
TestConfig config;
// config input_0
config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
ImageConfig* img = input->mutable_image_conf();
img->set_channels(4);
img->set_img_size(16);
img->set_img_size_y(16);
ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
reshape->add_heightaxis(0);
reshape->add_heightaxis(1);
reshape->add_heightaxis(2);
reshape->add_widthaxis(3);
// config softmax layer
config.layerConfig.set_type("switch_order");
config.layerConfig.set_name("switchOrderLayer");
for (auto useGpu : {false, true}) {
testLayerGrad(config, "switch_order", 100, false, useGpu, true);
}
}
vector<real> randSampling(real range, int n) {
CHECK_GE(range, n);
vector<real> num(range);
......
......@@ -1616,6 +1616,10 @@ public:
};
class CpuMatrix : public Matrix {
private:
MatrixPtr sftmaxSum_;
MatrixPtr sftmaxDot_;
public:
CpuMatrix(size_t height, size_t width, bool trans = false);
CpuMatrix(real* data, size_t height, size_t width, bool trans = false)
......
......@@ -287,6 +287,11 @@ message PadConfig {
repeated uint32 pad_w = 4;
}
message ReshapeConfig {
repeated uint32 heightAxis = 1;
repeated uint32 widthAxis = 2;
}
message MultiBoxLossConfig {
required uint32 num_classes = 1;
required float overlap_threshold = 2;
......@@ -339,7 +344,6 @@ message LayerInputConfig {
}
message LayerConfig {
required string name = 1;
required string type = 2;
optional uint64 size = 3;
......@@ -516,6 +520,9 @@ message LayerConfig {
optional double delta = 57 [ default = 1.0 ];
optional uint64 depth = 58 [ default = 1 ];
// for switch order layer
optional ReshapeConfig reshape_conf = 59;
}
message EvaluatorConfig {
......
......@@ -3670,6 +3670,15 @@ class RecurrentLayerGroup(LayerBase):
name, 'recurrent_layer_group', 0, inputs=[], device=device)
@config_layer('switch_order')
class SwitchOrderLayer(LayerBase):
def __init__(self, name, inputs, reshape, **xargs):
super(SwitchOrderLayer, self).__init__(
name, 'switch_order', 0, inputs=inputs, **xargs)
self.config.reshape_conf.heightAxis.extend(reshape['height'])
self.config.reshape_conf.widthAxis.extend(reshape['width'])
# Deprecated, use a new layer specific class instead
@config_func
def Layer(name, type, **xargs):
......
......@@ -131,6 +131,7 @@ __all__ = [
'row_conv_layer',
'dropout_layer',
'prelu_layer',
'switch_order_layer',
'gated_unit_layer',
'crop_layer',
'sub_nested_seq_layer',
......@@ -239,6 +240,7 @@ class LayerType(object):
SMOOTH_L1 = 'smooth_l1'
PRELU = 'prelu'
SWITCH_ORDER_LAYER = 'switch_order'
CROP_LAYER = 'crop'
SUB_NESTED_SEQ = 'sub_nested_seq'
CLIP_LAYER = 'clip'
......@@ -6404,6 +6406,48 @@ def gated_unit_layer(input,
layer_attr=layer_attr)
@layer_support()
@wrap_name_default('switch_order')
def switch_order_layer(input,
name=None,
reshape=None,
act=None,
layer_attr=None):
"""
This layer switch dimension order of image input.
From order "batchSize, channels, height, width"
to order "batchSize, height, width, channels".
The example usage is:
.. code-block:: python
reshape = {'height':[ 0, 1, 2], 'width':[3]}
switch = switch_order(input=layer, name='switch', reshape=reshape)
:param input: The input layer.
:type input: LayerOutput
:param name: Name of this layer.
:type name: basestring
:param reshape: reshape matrix by axises.
:type reshape: Dict
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
l = Layer(
name=name,
inputs=input.name,
reshape=reshape,
type=LayerType.SWITCH_ORDER_LAYER,
active_type=act.name,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name=name,
layer_type=LayerType.SWITCH_ORDER_LAYER,
parents=input,
size=l.config.size)
@wrap_name_default()
@layer_support()
def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册