Commit 1cdf149b authored by: W wanghaoshuang

1. Delete PixelSoftmaxLayer and add SwitchOrderLayer
2. Make SwitchOrderLayer support softmax activation
3. Fix bugs
Parent 475dd708
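For context, a minimal usage sketch of the new layer, assembled from the docstring example added in this commit (layer_out is an assumed upstream image-shaped layer in NCHW order; the helper name and reshape keys come from the new Python API further below):

reshape = {'height': [0, 1, 2], 'width': [3]}
switch = switch_order_layer(input=layer_out, name='switch', reshape=reshape)
# The layer emits the data in NHWC order and, with these axes, reshapes the
# output matrix to (batch * height * width) rows by channels columns, which
# makes it possible to apply a softmax activation across channels per pixel.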
......@@ -13,7 +13,7 @@
# limitations under the License
cmake_minimum_required(VERSION 3.0)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldl -lpthread")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
......
......@@ -23,12 +23,17 @@ void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
const int num,
const int inC,
const int inH,
const int inW) {
const int inW,
const int argType) {
for (int n = 0; n < num; ++n) {
for (int c = 0; c < inC; ++c) {
for (int h = 0; h < inH; ++h) {
for (int w = 0; w < inW; ++w) {
outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
if (argType == ADD_TO) {
outputs[((n * inH + h) * inW + w) * inC + c] += *(inputs++);
} else {
outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
}
}
}
}
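// Note on the mapping above: the NCHW input is read sequentially, and element
// (n, c, h, w) is written to NHWC offset ((n * inH + h) * inW + w) * inC + c;
// when argType == ADD_TO the value is accumulated into the existing output
// instead of overwriting it. NHWC2NCHW below applies the symmetric mapping.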
......@@ -41,12 +46,17 @@ void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
const int num,
const int inH,
const int inW,
const int inC) {
const int inC,
const int argType) {
for (int n = 0; n < num; ++n) {
for (int h = 0; h < inH; ++h) {
for (int w = 0; w < inW; ++w) {
for (int c = 0; c < inC; ++c) {
outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
if (argType == ADD_TO) {
outputs[((n * inC + c) * inH + h) * inW + w] += *(inputs++);
} else {
outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
}
}
}
}
......@@ -54,23 +64,15 @@ void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
}
/**
* \brief Padding zeros to input according to the specify dimension.
* The struct pad_ contains the padding size in each dimension.
* The input and output is a 4D tensor. In PadFunc, we only
* pad zeros to the 2nd to 4th dimension.
* \brief Switches the dimension order of the image input.
* The input and output are 4D tensors. Switches the order
* 'batch_size, channels, height, width' to
* 'batch_size, height, width, channels'.
*
* Argument in this Function:
* \param pad_ A struct object contains the padding size in each dimension.
* It has six integers. The channelStart and channelEnd indicate
* how many zeros to add before and after the input in channel
* dimension. And the heightStart and heightEnd indicate padding
* in height dimension. The widthStart and widthEnd indicate the
* padding in width dimension.
* \param inputs A 4D tensor, only one input.
* \param outputs A 4D tensor, the output value after padding.
*
* \param inputs input data with order 'batch_size, channels, height, width'.
* \param outputs output data with order 'batch_size, height, width, channels'.
*/
template <DeviceType Device>
class NCHW2NHWCFunc : public FunctionBase {
public:
......@@ -84,25 +86,26 @@ public:
size_t inC = inputs[0].shape()[1];
size_t inH = inputs[0].shape()[2];
size_t inW = inputs[0].shape()[3];
typename Tensor<real, Device>::Vector vec(outputs[0].shape().getElements(),
outputs[0].data<real>());
vec.zero();
NCHW2NHWC<Device>(
outputs[0].data<real>(), inputs[0].data<real>(), num, inC, inH, inW);
NCHW2NHWC<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
num,
inC,
inH,
inW,
outputs[0].getArgType());
}
};
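// Note: the previous implementation zero-filled the output vector before the
// transform; with the new argType argument the kernel either assigns directly
// (overwriting the buffer) or accumulates via ADD_TO, so the explicit
// zero-fill is no longer needed.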
/**
* \brief The backward propagation of padding Function. Remove the elements
* in the padding positions of forward.
* \brief Switches the dimension order of the image input.
* The input and output are 4D tensors. Switches the order
* 'batch_size, height, width, channels' to
* 'batch_size, channels, height, width'.
*
* Argument in this Function:
* \param pad_ The same meaning as it in PadFunc.
* \param inputs The gradient with respect to the output value of PadFunc.
* \param outputs The gradient with respect to the input value of PadFunc.
* \param inputs input data with order 'batch_size, height, width, channels'.
* \param outputs output data with order 'batch_size, channels, height, width'.
*/
template <DeviceType Device>
class NHWC2NCHWFunc : public FunctionBase {
public:
......@@ -117,8 +120,13 @@ public:
size_t inW = inputs[0].shape()[2];
size_t inC = inputs[0].shape()[3];
NHWC2NCHW<Device>(
outputs[0].data<real>(), inputs[0].data<real>(), num, inH, inW, inC);
NHWC2NCHW<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
num,
inH,
inW,
inC,
outputs[0].getArgType());
}
};
......
......@@ -30,6 +30,7 @@ namespace paddle {
* \param[in] inC channel number of input data.
* \param[in] inH height of input data.
* \param[in] inW width of input data.
* \param[in] argType type of output argument.
*/
template <DeviceType Device>
void NCHW2NHWC(real* outputs,
......@@ -37,7 +38,8 @@ void NCHW2NHWC(real* outputs,
const int num,
const int inC,
const int inH,
const int inW);
const int inW,
const int argType);
/**
* \brief This function switches the dimension order of image input.
......@@ -51,6 +53,7 @@ void NCHW2NHWC(real* outputs,
* \param[in] inH height of input data.
* \param[in] inW width of input data.
* \param[in] inC channel number of input data.
* \param[in] argType type of output argument.
*/
template <DeviceType Device>
void NHWC2NCHW(real* inGrad,
......@@ -58,5 +61,6 @@ void NHWC2NCHW(real* inGrad,
const int num,
const int inH,
const int inW,
const int inC);
const int inC,
const int argType);
} // namespace paddle
......@@ -19,7 +19,7 @@ namespace paddle {
__global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
int inC, int inH, int inW,
int nthreads) {
int nthreads, int argType) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
......@@ -28,7 +28,11 @@ __global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
const int n = idx / inW / inH / inC;
const int off = ((n * inH + h) * inW + w) * inC + c;
outputs[off] = inputs[idx];
if (argType == ADD_TO) {
outputs[off] += inputs[idx];
} else {
outputs[off] = inputs[idx];
}
}
}
......@@ -38,18 +42,19 @@ void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
const int num,
const int inC,
const int inH,
const int inW) {
const int inW,
const int argType) {
size_t nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024;
KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(outputs, inputs, inC, inH, inW, nth);
(outputs, inputs, inC, inH, inW, nth, argType);
CHECK_SYNC("NCHW2NHWC");
}
__global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
int inH, int inW, int inC,
int nthreads) {
int nthreads, int argType) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int c = idx % inC;
......@@ -58,7 +63,11 @@ __global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
const int n = idx / inW / inH / inC;
const int off = ((n * inC + c) * inH + h) * inW + w;
outputs[off] = inputs[idx];
if (argType == ADD_TO) {
outputs[off] += inputs[idx];
} else {
outputs[off] = inputs[idx];
}
}
}
......@@ -68,12 +77,13 @@ void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
const int num,
const int inH,
const int inW,
const int inC) {
const int inC,
const int argType) {
int nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024;
KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(outputs, inputs, inH, inW, inC, nth);
(outputs, inputs, inH, inW, inC, nth, argType);
CHECK_SYNC("NHWC2NCHW");
}
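// Launch-configuration note: both kernels use one thread per element
// (nth = num * inC * inH * inW) with blockSize 1024 and
// gridSize = (nth + 1024 - 1) / 1024; each thread decodes (n, h, w, c) from
// its linear index and assigns or accumulates depending on argType.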
......
......@@ -12,78 +12,101 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PixelSoftmaxLayer.h"
#include "SwitchOrderLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(pixel_softmax, PixelSoftmaxLayer);
REGISTER_LAYER(switch_order, SwitchOrderLayer);
bool PixelSoftmaxLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
bool SwitchOrderLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
inH_ =
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
inW_ = img_conf.img_size();
inC_ = img_conf.channels();
createFunction(forward_, "NCHW2NHWC", FuncConfig());
createFunction(backward_, "NHWC2NCHW", FuncConfig());
inDims_ = TensorShape({0, inH_, inW_, inC_});
outDims_ = TensorShape({0, inC_, inH_, inW_});
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
auto& reshape_conf = config_.reshape_conf();
for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
LOG(INFO) << "reshape height axis: " << reshape_conf.heightaxis(i);
heightAxis_.push_back(reshape_conf.heightaxis(i));
}
for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
LOG(INFO) << "reshape width axis: " << reshape_conf.widthaxis(i);
widthAxis_.push_back(reshape_conf.widthaxis(i));
}
createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
return true;
}
void PixelSoftmaxLayer::forward(PassType passType) {
Layer::forward(passType);
void SwitchOrderLayer::setOutDims() {
outDims_.setDim(0, inDims_[0]);
outDims_.setDim(1, inDims_[2]);
outDims_.setDim(2, inDims_[3]);
outDims_.setDim(3, inDims_[1]);
reshapeHeight_ = 1;
for (size_t i = 0; i < heightAxis_.size(); i++) {
reshapeHeight_ *= outDims_[heightAxis_[i]];
}
output_.setFrameHeight(reshapeHeight_);
reshapeWidth_ = 1;
for (size_t i = 0; i < widthAxis_.size(); i++) {
reshapeWidth_ *= outDims_[widthAxis_[i]];
}
output_.setFrameWidth(reshapeWidth_);
LOG(INFO) << "outDims: " << outDims_[0] << "; " << outDims_[1] << ";"
<< outDims_[2] << ";" << outDims_[3];
}
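// Note: setOutDims() permutes the shape from NCHW to NHWC (indices
// {0, 2, 3, 1} of inDims_), and reshapeHeight_/reshapeWidth_ are the products
// of the output dimensions listed in heightAxis_/widthAxis_. For example,
// height axes {0, 1, 2} and width axis {3} reshape the output value to an
// (N * H * W) x C matrix, so a channel-wise softmax can be applied per pixel.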
void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
// cout<<"useGpu:"<<useGpu(deviceId_)<<endl;
Matrix::resizeOrCreate(
tmpInput_, batchSize * inH_ * inW_, inC_, false, useGpu_);
Matrix::resizeOrCreate(
tmpOutput_, batchSize * inH_ * inW_, inC_, false, useGpu_);
tmpOutput_->zeroMem();
resetOutput(batchSize, inH_ * inW_ * inC_);
inDims_.setDim(0, batchSize);
outDims_.setDim(0, batchSize);
int h = inputLayers_[0]->getOutput().getFrameHeight();
if (h != 0) inDims_.setDim(2, h);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
int channels = totalCount / (inDims_[0] * inDims_[2] * inDims_[3]);
if (channels != 0) inDims_.setDim(1, channels);
LOG(INFO) << "inDims: " << inDims_[0] << "; " << inDims_[1] << ";"
<< inDims_[2] << ";" << inDims_[3];
}
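// Note: setInDims() refreshes the runtime shape from the first input: the
// batch size comes from the matrix height, height/width from the frame size
// when set, and the channel count is inferred as
// elementCount / (batch * height * width).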
void SwitchOrderLayer::forward(PassType passType) {
Layer::forward(passType);
setInDims();
setOutDims();
resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
if (heightAxis_.size() > 0) {
getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
}
// switch NCHW to NHWC
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), inDims_);
outputs.addArg(*tmpInput_, outDims_);
forward_[0]->calc(inputs, outputs);
// softmax forward and save softmax result into tmpMatrix_
tmpInput_->softmax(*tmpOutput_);
// switch NHWC to NCHW
BufferArgs inputs_1;
BufferArgs outputs_1;
inputs_1.addArg(*tmpOutput_, outDims_);
outputs_1.addArg(*getOutputValue(), inDims_);
backward_[0]->calc(inputs_1, outputs_1);
outputs.addArg(*getOutputValue(), outDims_);
nchw2nhwc_[0]->calc(inputs, outputs);
// forwardActivation();
}
void PixelSoftmaxLayer::backward(const UpdateCallback& callback) {
void SwitchOrderLayer::backward(const UpdateCallback& callback) {
(void)callback;
REGISTER_TIMER_INFO("PixelSoftmaxBackward", getName().c_str());
// backwardActivation();
// switch NCHW to NHWC
// switch NHWC to NCHW
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), inDims_);
outputs.addArg(*tmpInput_, outDims_);
forward_[0]->calc(inputs, outputs);
// softmax backward and save grad result into tmpOutput_
tmpInput_->softmaxBackward(*tmpOutput_);
// switch NHWC to NCHW
BufferArgs inputs_1;
BufferArgs outputs_1;
inputs_1.addArg(*tmpInput_, outDims_);
outputs_1.addArg(*getInputGrad(0), inDims_);
backward_[0]->calc(inputs_1, outputs_1);
inputs.addArg(*getOutputGrad(), outDims_);
outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
nhwc2nchw_[0]->calc(inputs, outputs);
}
} // namespace paddle
......@@ -21,24 +21,27 @@ namespace paddle {
/**
* \brief This layer calculates softmax in the image channel dimension.
*/
class PixelSoftmaxLayer : public Layer {
class SwitchOrderLayer : public Layer {
public:
explicit PixelSoftmaxLayer(const LayerConfig& config) : Layer(config) {}
explicit SwitchOrderLayer(const LayerConfig& config) : Layer(config) {}
~PixelSoftmaxLayer() {}
~SwitchOrderLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void setInDims();
void setOutDims();
protected:
uint32_t inC_;
uint32_t inH_;
uint32_t inW_;
std::vector<std::shared_ptr<FunctionBase>> nchw2nhwc_;
std::vector<std::shared_ptr<FunctionBase>> nhwc2nchw_;
TensorShape inDims_;
TensorShape outDims_;
MatrixPtr tmpInput_;
MatrixPtr tmpOutput_;
std::vector<int> heightAxis_;
std::vector<int> widthAxis_;
size_t reshapeHeight_;
size_t reshapeWidth_;
};
} // namespace paddle
......@@ -1802,7 +1802,7 @@ TEST(Layer, RowConvLayer) {
}
}
TEST(Layer, PixelSoftmaxLayer) {
TEST(Layer, SwitchOrderLayer) {
TestConfig config;
// config input_0
config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
......@@ -1812,12 +1812,18 @@ TEST(Layer, PixelSoftmaxLayer) {
img->set_img_size(16);
img->set_img_size_y(16);
ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
reshape->add_heightaxis(0);
reshape->add_heightaxis(1);
reshape->add_heightaxis(2);
reshape->add_widthaxis(3);
// config softmax layer
config.layerConfig.set_type("pixel_softmax");
config.layerConfig.set_name("pixelSofrmaxLayer");
config.layerConfig.set_type("switch_order");
config.layerConfig.set_name("switchOrderLayer");
for (auto useGpu : {false, true}) {
testLayerGrad(config, "pixel_softmax", 100, false, useGpu, true, 2);
testLayerGrad(config, "switch_order", 100, false, useGpu, true);
}
}
......
......@@ -3385,27 +3385,6 @@ void CpuMatrix::oneHotCrossEntropyWithSelfNormBp(Matrix& output,
real* out = output.getData(); \
for (size_t i = 0; i < numSamples; ++i, grad += dim, out += dim)
void CpuMatrix::softmaxBackward(Matrix& outputV) {
CHECK(!outputV.useGpu()) << "Matrix type are not equal";
size_t height = getHeight();
size_t width = getWidth();
CHECK(height == outputV.getHeight() && width == outputV.getWidth())
<< "Matrix dimensions are not equal";
Matrix::resizeOrCreate(sftmaxDot_,
height_,
width_,
/* trans */ false,
useGpu_);
Matrix::resizeOrCreate(sftmaxSum_,
height_,
1,
/* trans */ false,
useGpu_);
sftmaxDot_->dotMul(*this, outputV);
sftmaxSum_->colMerge(*sftmaxDot_);
softmaxDerivative(outputV, *sftmaxSum_);
}
void CpuMatrix::softmax(Matrix& output) {
CHECK(!output.useGpu());
......
......@@ -1732,7 +1732,6 @@ public:
Matrix& prevGrad2);
void softmax(Matrix& output);
void softmaxBackward(Matrix& outputV);
void sequenceSoftmax(Matrix& output, const IVector& index);
void softmaxDerivative(Matrix& output, Matrix& sftmaxSum);
......
......@@ -266,6 +266,11 @@ message PadConfig {
repeated uint32 pad_w = 4;
}
message ReshapeConfig {
repeated uint32 heightAxis = 1;
repeated uint32 widthAxis = 2;
}
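// For example, the Python call
// switch_order_layer(..., reshape={'height': [0, 1, 2], 'width': [3]})
// fills heightAxis with [0, 1, 2] and widthAxis with [3].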
message MultiBoxLossConfig {
required uint32 num_classes = 1;
required float overlap_threshold = 2;
......@@ -476,6 +481,9 @@ message LayerConfig {
// controls the scope of pooling operation. can be set > 0.
// leave empty or set to -1 to disable this stride pooling.
optional int32 seq_pool_stride = 53 [default = -1];
// for switch order layer
optional ReshapeConfig reshape_conf = 54;
}
message EvaluatorConfig {
......
......@@ -3174,20 +3174,13 @@ class RecurrentLayerGroup(LayerBase):
name, 'recurrent_layer_group', 0, inputs=[], device=device)
@config_layer('pixel_softmax')
class PixelSoftmaxLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
super(PixelSoftmaxLayer, self).__init__(
name, 'pixel_softmax', 0, inputs=inputs, **xargs)
input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].image_conf
image_conf.img_size = input_layer.width
image_conf.img_size_y = input_layer.height
image_conf.channels = input_layer.size / (input_layer.width *
input_layer.height)
self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
image_conf.channels)
@config_layer('switch_order')
class SwitchOrderLayer(LayerBase):
def __init__(self, name, inputs, reshape, **xargs):
super(SwitchOrderLayer, self).__init__(
name, 'switch_order', 0, inputs=inputs, **xargs)
self.config.reshape_conf.heightAxis.extend(reshape['height'])
self.config.reshape_conf.widthAxis.extend(reshape['width'])
# Deprecated, use a new layer specific class instead
......
......@@ -126,7 +126,7 @@ __all__ = [
'row_conv_layer',
'dropout_layer',
'prelu_layer',
'pixel_softmax_layer',
'switch_order_layer',
]
......@@ -218,7 +218,7 @@ class LayerType(object):
SMOOTH_L1 = 'smooth_l1'
PRELU = 'prelu'
PIXEL_SOFTMAX_LAYER = 'pixel_softmax'
SWITCH_ORDER_LAYER = 'switch_order'
@staticmethod
def is_layer_type(type_name):
......@@ -5881,37 +5881,37 @@ def prelu_layer(input,
@layer_support()
@wrap_name_default('pixel_softmax')
def pixel_softmax_layer(input, name=None, layer_attr=None):
@wrap_name_default('switch_order')
def switch_order_layer(input, name=None, reshape=None, layer_attr=None):
"""
This layer calculate softmax in image channel dimension
This layer switches the dimension order of the image input,
from order "batchSize, channels, height, width"
to order "batchSize, height, width, channels".
The example usage is:
.. code-block:: python
reshape = {'height': [0, 1, 2], 'width': [3]}
switch = switch_order_layer(input=layer, name='switch', reshape=reshape)
prelu = pixel_softmax(input=layer, name='softmax')
:param name: Name of this layer.
:type name: basestring
:param input: The input layer.
:type input: LayerOutput
:param name: Name of this layer.
:type name: basestring
:param reshape: reshape the output matrix by the given dimension axes.
:type reshape: Dict
:return: LayerOutput object.
:rtype: LayerOutput
"""
if isinstance(input, LayerOutput):
input = [input]
elif isinstance(input, Projection):
input = [input]
else:
assert isinstance(input, collections.Sequence)
assert isinstance(input, LayerOutput)
l = Layer(
name=name,
inputs=[x.name for x in input],
type=LayerType.PIXEL_SOFTMAX_LAYER,
inputs=input,
reshape=reshape,
type=LayerType.SWITCH_ORDER_LAYER,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name=name,
layer_type=LayerType.PIXEL_SOFTMAX_LAYER,
layer_type=LayerType.SWITCH_ORDER_LAYER,
parents=input,
size=l.config.size)