diff --git a/doc/ui/api/trainer_config_helpers/layers.rst b/doc/ui/api/trainer_config_helpers/layers.rst index 55f5623b0faef5553064bfc07e4854bed251f623..5bb88b0615c12a44e1506e0bdbb974c16f5584ea 100644 --- a/doc/ui/api/trainer_config_helpers/layers.rst +++ b/doc/ui/api/trainer_config_helpers/layers.rst @@ -73,6 +73,12 @@ img_pool_layer :members: img_pool_layer :noindex: +maxout_layer +------------ +.. automodule:: paddle.trainer_config_helpers.layers + :members: maxout_layer + :noindex: + Norm Layer ========== diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 5d750333e1e35d6097d33d905a02d647c3919eb1..d19f4a4bb310a73d896bc8f4179f41b1a5752e54 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -169,7 +169,7 @@ extern void hl_avgpool_forward( * @brief Maximum pool backward. * * @param[in] frameCnt batch size of input image. - * @param[in] outGrad input data. + * @param[in] outGrad output grad data. * @param[in] channels number of channel. * @param[in] height image height. * @param[in] width image width. @@ -240,4 +240,34 @@ extern void hl_CMRNorm_backward( size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta); +/** + * @brief MaxOut forward. + * + * @param[in] inData input data. + * @param[out] outData output data. + * @param[out] idData output maxId. + * @param[in] batchSize batchSize. + * @param[in] size number of channels * image height * image width. + * @param[in] featLen feature length = image height * image width. + * @param[in] groups number of groups. + */ +extern void hl_maxout_forward( + const real* inData, real* outData, int* idData, + size_t batchSize, size_t size, size_t featLen, size_t groups); + +/** + * @brief MaxOut backward. + * + * @param[out] inGrad input grad data. + * @param[in] outGrad output grad data. + * @param[in] idData output maxId. + * @param[in] batchSize batchSize. + * @param[in] size number of channels * image height * image width. + * @param[in] featLen feature length = image height * image width. + * @param[in] groups number of groups. + */ +extern void hl_maxout_backward( + real* inGrad, const real* outGrad, const int* idData, + size_t batchSize, size_t size, size_t featLen, size_t groups); + #endif /* HL_CNN_H_ */ diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 38e359c3eb2f34e5874187f4b06280a3df901c8e..5f696986e3c8fa19e1f234b03d5ef758c95e3aaf 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -89,4 +89,12 @@ inline void hl_CMRNorm_backward( size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) {} +inline void hl_maxout_forward( + const real* inData, real* outData, int* idData, + size_t batchSize, size_t size, size_t featLen, size_t group) {} + +inline void hl_maxout_backward( + real* inGrad, const real* outGrad, const int* idData, + size_t batchSize, size_t size, size_t featLen, size_t group) {} + #endif // HL_CNN_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index abac83a3e04472fe25bdbe662427aea56c096ad4..baa2fb0d27d749197c10645ff976851ddc38c84f 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -531,3 +531,62 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, height, width, sizeX, alpha, beta, inDiff); CHECK_SYNC("hl_CMRNorm_backward"); } + +__global__ void maxoutFpCompute(size_t nthreads, const real * inData, + real * outData, int* idData, + size_t size, size_t featLen, size_t groups) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < nthreads) { + size_t batch_idx = index / size; + size_t i = index % size; + size_t channel_idx = i / featLen; + size_t feat_idx = i % featLen; + size_t data_idx = (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + real max = inData[data_idx]; + int maxId = 0; + for (size_t g = 1; g < groups; ++g) { + real tmp = inData[data_idx + g * featLen]; + if (tmp > max) { + max = tmp; + maxId = g; + } + } + outData[index] = max; + idData[index] = maxId; + } +} + +void hl_maxout_forward(const real* inData, real* outData, + int* idData, size_t batchSize, size_t size, + size_t featLen, size_t groups) { + int num_kernels = size * batchSize; + int blocks = (num_kernels + 1024 - 1) / 1024; + maxoutFpCompute<<< blocks, 1024, 0, STREAM_DEFAULT>>>( + num_kernels, inData, outData, idData, size, featLen, groups); + CHECK_SYNC("hl_maxout_forward failed"); +} + +__global__ void maxoutBpCompute(size_t nthreads, real* inGrad, + const real* outGrad, const int* idData, + size_t size, size_t featLen, size_t groups) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < nthreads) { + size_t batch_idx = index / size; + size_t i = index % size; + size_t channel_idx = i / featLen; + size_t feat_idx = i % featLen; + size_t newIndex = batch_idx * size; + size_t gradIdx = (channel_idx * groups + (idData + newIndex)[i]) * featLen + feat_idx; + (inGrad + newIndex * groups)[gradIdx] += (outGrad + newIndex)[i]; + } +} + +void hl_maxout_backward(real* inGrad, const real* outGrad, + const int* idData, size_t batchSize, size_t size, + size_t featLen, size_t groups) { + int num_kernels = size * batchSize; + int blocks = (num_kernels + 1024 - 1) / 1024; + maxoutBpCompute<<< blocks, 1024, 0, STREAM_DEFAULT >>>( + num_kernels, inGrad, outGrad, idData, size, featLen, groups); + CHECK_SYNC("hl_maxout_backward failed"); +} diff --git a/paddle/gserver/layers/MaxOutLayer.cpp b/paddle/gserver/layers/MaxOutLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..106ab26ba1aae60e0f8f71abd49f15cda031c83a --- /dev/null +++ b/paddle/gserver/layers/MaxOutLayer.cpp @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MaxOutLayer.h" +#include "hl_gpu.h" +#include "hl_cnn.h" + +namespace paddle { + +REGISTER_LAYER(maxout, MaxOutLayer); + +size_t MaxOutLayer::getSize() { + const MaxOutConfig& maxoutConf = config_.inputs(0).maxout_conf(); + imgSizeH_ = inputLayers_[0]->getOutput().getFrameHeight(); + imgSizeW_ = inputLayers_[0]->getOutput().getFrameWidth(); + if (imgSizeH_ == 0) { + imgSizeH_ = maxoutConf.img_size_y(); + } + if (imgSizeW_ == 0) { + imgSizeW_ = maxoutConf.img_size_x(); + } + + featLen_ = imgSizeH_ * imgSizeW_; + size_t layerSize = featLen_ * outputChannels_; + + getOutput().setFrameHeight(imgSizeH_); + getOutput().setFrameWidth(imgSizeW_); + + return layerSize; +} + +bool MaxOutLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + /* the size of inputs for maxout-layer is 1 */ + CHECK_EQ(config_.inputs_size(), 1UL); + + const MaxOutConfig& conf = config_.inputs(0).maxout_conf(); + groups_ = conf.groups(); + channels_ = conf.channels(); + CHECK_EQ(channels_ % groups_, 0UL); + outputChannels_ = channels_ / groups_; + + return true; +} + +void MaxOutLayer::forward(PassType passType) { + Layer::forward(passType); + + /* malloc memory for the output_ if necessary */ + /* note: one sample correspond to one column */ + size_t batchSize = getInput(0).getBatchSize(); + size_t size = getSize(); + resetOutput(batchSize, size); + MatrixPtr inputV = getInputValue(0); + MatrixPtr outV = getOutputValue(); + + IVector::resizeOrCreate(maxoutId_, size * batchSize, useGpu_); + outV->maxoutForward(*inputV, *maxoutId_, outputChannels_, groups_); +} + +void MaxOutLayer::backward(const UpdateCallback& callback) { + (void)callback; + + /* Do derivation */ + MatrixPtr inputG = getInputGrad(0); + MatrixPtr outG = getOutputGrad(); + + if (inputG) { + inputG->maxoutBackward(*outG, *maxoutId_, outputChannels_, groups_); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MaxOutLayer.h b/paddle/gserver/layers/MaxOutLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..9011a5c332b17a2f697380b1afb40ad9de504b91 --- /dev/null +++ b/paddle/gserver/layers/MaxOutLayer.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * A layer to do max out on conv layer output. + * Input: output of a conv layer. + * Output: feature map size same as input. Channel is (input channel) / groups. + * So the num of channels should be able to devided by groups. + * + * The config file api is maxout_layer. + */ + +class MaxOutLayer : public Layer { +protected: + size_t groups_; + size_t imgSizeH_, imgSizeW_; + /// outputChannels_ = channels_ / groups_ + size_t channels_, outputChannels_; + /// feature length = imgSizeH_ * imgSizeW_ + size_t featLen_; + IVectorPtr maxoutId_; + +public: + /// return imgSizeH_ * imgSizeW_ * outputChannels_; + size_t getSize(); + + explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {} + virtual ~MaxOutLayer() {} + + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + void backward(const UpdateCallback& callback = nullptr); +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/rnn_data_provider.py b/paddle/gserver/tests/rnn_data_provider.py index 5c3b062309c51f3549f0dde1c6aed3be94619ef5..321c78cb1741bcfcbd7df2fd83ff6ba5ba910971 100644 --- a/paddle/gserver/tests/rnn_data_provider.py +++ b/paddle/gserver/tests/rnn_data_provider.py @@ -14,12 +14,15 @@ from paddle.trainer.PyDataProvider2 import * +# Note that each config should has an independent provider +# in current design of PyDataProvider2. +####################################################### data = [ [[[1, 3, 2], [4, 5, 2]], 0], [[[0, 2], [2, 5], [0, 1, 2]], 1], ] - +# Used for sequence_nest_rnn.conf @provider(input_types=[integer_value_sub_sequence(10), integer_value(3)], should_shuffle=False) @@ -27,7 +30,7 @@ def process_subseq(settings, file_name): for d in data: yield d - +# Used for sequence_rnn.conf @provider(input_types=[integer_value_sequence(10), integer_value(3)], should_shuffle=False) @@ -38,11 +41,32 @@ def process_seq(settings, file_name): seq += subseq yield seq, d[1] +# Used for sequence_nest_rnn_multi_input.conf +@provider(input_types=[integer_value_sub_sequence(10), + integer_value(3)], + should_shuffle=False) +def process_subseq2(settings, file_name): + for d in data: + yield d + +# Used for sequence_rnn_multi_input.conf +@provider(input_types=[integer_value_sequence(10), + integer_value(3)], + should_shuffle=False) +def process_seq2(settings, file_name): + for d in data: + seq = [] + for subseq in d[0]: + seq += subseq + yield seq, d[1] + +########################################################### data2 = [ [[[1, 2], [4, 5, 2]], [[5, 4, 1], [3, 1]] ,0], [[[0, 2], [2, 5], [0, 1, 2]],[[1, 5], [4], [2, 3, 6, 1]], 1], ] +# Used for sequence_nest_rnn_multi_unequalength_inputs.conf @provider(input_types=[integer_value_sub_sequence(10), integer_value_sub_sequence(10), integer_value(2)], @@ -52,6 +76,7 @@ def process_unequalength_subseq(settings, file_name): yield d +# Used for sequence_rnn_multi_unequalength_inputs.conf @provider(input_types=[integer_value_sequence(10), integer_value_sequence(10), integer_value(2)], diff --git a/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf b/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf index e8222cef525a806a6201b7290f75138c94bd0aaf..0614958b4719ddb2098dc495c4a6c615f2628457 100644 --- a/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf +++ b/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf @@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import * define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list', test_list=None, module='rnn_data_provider', - obj='process_subseq') + obj='process_subseq2') settings(batch_size=2, learning_rate=0.01) diff --git a/paddle/gserver/tests/sequence_rnn_multi_input.conf b/paddle/gserver/tests/sequence_rnn_multi_input.conf index 968621cab59be9296ae5ee962a3a359fff59e022..51881e21d971bbebeceeab1a7c4954e50e3a5e60 100644 --- a/paddle/gserver/tests/sequence_rnn_multi_input.conf +++ b/paddle/gserver/tests/sequence_rnn_multi_input.conf @@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import * define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list', test_list=None, module='rnn_data_provider', - obj='process_seq') + obj='process_seq2') settings(batch_size=2, learning_rate=0.01) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index c5723f8574ab3d7a15bfe7c8db8a9d03951f08b1..eab9bf84141a27d957969a22beb70824659888d7 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -307,6 +307,24 @@ TEST(Layer, blockExpandLayer) { } } +TEST(Layer, maxoutLayer) { + TestConfig config; + config.biasSize = 0; + config.layerConfig.set_type("maxout"); + + config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + MaxOutConfig* maxout = input->mutable_maxout_conf(); + + maxout->set_img_size_x(32); + maxout->set_img_size_y(32); + maxout->set_channels(4); + maxout->set_groups(2); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "maxout", 10, false, useGpu); + } +} void testFcLayer(string format, size_t nnz) { TestConfig config; config.biasSize = 4096; diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 78519ce7aa8742192eb15e5c4705572a7df5dbdc..843eabc97d642fcfb5b5862c0a5bef035a7a2ccb 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -583,6 +583,42 @@ void GpuMatrix::colMax(Matrix& max) { max.maxCols(*this); } +void GpuMatrix::colMax(IVector& maxIds, Matrix& maxVal) { + LOG(FATAL) << "Is not supported"; +} + +void GpuMatrix::maxoutForward(Matrix& a, IVector& id, size_t channels, + size_t groups) { + CHECK(dynamic_cast(&a)); + CHECK(dynamic_cast(&id)); + CHECK_EQ(a.getHeight(), getHeight()); + + size_t size = getWidth(); + size_t batchSize = getHeight(); + const real* input = a.getData(); + real* output = getData(); + int* idForGpu = id.getData(); + + hl_maxout_forward(input, output, idForGpu, batchSize, size, + size / channels, groups); +} + +void GpuMatrix::maxoutBackward(Matrix& a, IVector& id, size_t channels, + size_t groups) { + CHECK(dynamic_cast(&a)); + CHECK(dynamic_cast(&id)); + CHECK_EQ(a.getHeight(), getHeight()); + + size_t size = a.getWidth(); + size_t batchSize = getHeight(); + real* input = getData(); + const real* output = a.getData(); + const int* idForGpu = id.getData(); + + hl_maxout_backward(input, output, idForGpu, batchSize, size, + size / channels, groups); +} + /*calulate the error of classification */ void GpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { GpuMatrixPtr output_ptr = std::dynamic_pointer_cast(output); @@ -2748,6 +2784,95 @@ void CpuMatrix::colMax(Matrix& max) { max.maxCols(*this); } +void CpuMatrix::colMax(IVector& maxIds, Matrix& maxVal) { + CHECK(isContiguous()); + CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal"; + size_t numSamples = getWidth(); + size_t beam = maxVal.getHeight(); + CHECK_EQ(maxIds.getSize(), numSamples * beam); + CHECK_EQ(maxVal.getWidth(), numSamples); + + real* a = getData(); + int* s = maxIds.getData(); + real* t = maxVal.getData(); + size_t dim = getHeight(); + for (size_t i = 0; i < numSamples; i++) { + std::vector> vec; + for (size_t j = 0; j < dim; j++) { + vec.push_back(std::pair(a[i + j * numSamples], j)); + } + + std::partial_sort( + vec.begin(), vec.begin() + beam, vec.end(), + [](const std::pair& l, const std::pair& r) { + return l.first > r.first; + }); + for (size_t j = 0; j < beam; j++) { + t[i + j * numSamples] = vec[j].first; + s[i + j * numSamples] = vec[j].second; + } + } +} + +void CpuMatrix::maxoutForward(Matrix& a, IVector& id, size_t channels, + size_t groups) { + CHECK(dynamic_cast(&a)); + CHECK(dynamic_cast(&id)); + CHECK_EQ(a.getHeight(), getHeight()); + + size_t size = getWidth(); + size_t batchSize = getHeight(); + size_t featLen = size / channels; + const real* input = a.getData(); + int* idForCpu = id.getData(); + + MatrixPtr maxInMat, maxOutMat; + Matrix::resizeOrCreate(maxInMat, groups, size, false, false); + Matrix::resizeOrCreate(maxOutMat, 1, size, false, false); + + for (size_t batch_idx = 0; batch_idx < batchSize; ++batch_idx) { + size_t newIndex = batch_idx * size; + IVectorPtr tmpId = IVector::create(idForCpu + newIndex, size, false); + + for (size_t i = 0; i < channels; ++i) { + size_t newFeatLen = i * featLen; + for (size_t j = 0; j < groups; ++j) { + maxInMat->subMatrix(j, j + 1, newFeatLen, newFeatLen + featLen) + ->copyFrom(input + (newIndex + newFeatLen) * groups + j * featLen, + featLen); + } + } + maxInMat->colMax(*tmpId, *maxOutMat); + this->subRowMatrix(batch_idx, batch_idx + 1)->copyFrom(*maxOutMat); + } +} + +void CpuMatrix::maxoutBackward(Matrix& a, IVector& id, size_t channels, + size_t groups) { + CHECK(dynamic_cast(&a)); + CHECK(dynamic_cast(&id)); + CHECK_EQ(a.getHeight(), getHeight()); + + size_t size = a.getWidth(); + size_t batchSize = getHeight(); + size_t featLen = size / channels; + size_t newFeatLen = groups * featLen; + real* inputG = getData(); + const real* outG = a.getData(); + int* idForCpu = id.getData(); + + for (size_t batch_idx = 0; batch_idx < batchSize; ++batch_idx) { + size_t newIndex = batch_idx * size; + int* idData = idForCpu + newIndex; + + for (size_t i = 0; i < size; ++i) { + int gradIdx = + idData[i] * featLen + (i / featLen) * newFeatLen + i % featLen; + (inputG + newIndex * groups)[gradIdx] += (outG + newIndex)[i]; + } + } +} + void CpuMatrix::rowNormalizeL1(Matrix& out) { CHECK(!out.useGpu()); diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 25104fe1c6d70afbf39ab47a17ce0bf21a121427..047c76a8604cc72cbdbcebb2fe333b55a1e1a9a3 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -493,16 +493,40 @@ public: LOG(FATAL) << "Not implemeted"; } + /** + * set the max of each column of this to mat + */ virtual void colMax(Matrix& max) { LOG(FATAL) << "not implemented"; } + /** + * @brief Get the top k elements of each column of this matrix. + * + * The row ids and values of these elements are stored in + * maxIds and max respectively. where k is the size of maxIds. + * And note that the top k elements are not sorted. + */ + virtual void colMax(IVector& maxIds, Matrix& maxVal) { + LOG(FATAL) << "not implemented"; + } + + virtual void maxoutForward(Matrix& a, IVector& id, size_t channels, + size_t groups) { + LOG(FATAL) << "not implemented"; + } + + virtual void maxoutBackward(Matrix& a, IVector& id, size_t channels, + size_t groups) { + LOG(FATAL) << "not implemented"; + } + virtual void rowMaxId(IVector& maxIds) { LOG(FATAL) << "Not implemented"; } /** * @brief Get the top k elements of each row of this matrix. * * The column ids and values of these elements are stored in - * maxIds and max respectively. Note that the top k - * elements are not sorted. + * maxIds and max respectively. where k is the size of maxIds. + * And note that the top k elements are not sorted. */ virtual void rowMax(IVector& maxIds, Matrix& max) { LOG(FATAL) << "Not implemented"; @@ -1085,6 +1109,9 @@ public: void rowMax(Matrix& max); void rowMax(IVector& maxIds, Matrix& max); void colMax(Matrix& max); + void colMax(IVector& maxIds, Matrix& max); + void maxoutForward(Matrix& a, IVector& id, size_t channels, size_t groups); + void maxoutBackward(Matrix& a, IVector& id, size_t channels, size_t groups); void oneHotCrossEntropy(Matrix& output, IVector& label); void oneHotCrossEntropyBp(Matrix& outputV, IVector& label); @@ -1395,6 +1422,9 @@ public: void rowMax(Matrix& max); void rowMax(IVector& maxIds, Matrix& maxVal); void colMax(Matrix& max); + void colMax(IVector& maxIds, Matrix& maxVal); + void maxoutForward(Matrix& a, IVector& id, size_t channels, size_t groups); + void maxoutBackward(Matrix& a, IVector& id, size_t channels, size_t groups); void rowNormalizeL1(Matrix& out); void oneHotCrossEntropy(Matrix& output, IVector& label); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index e1bda79a8acb16ffb9025ff92afa2bb24d76c4fe..ac160479a9dfcd2a6e5787207ef3fb95182a5692 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1999,6 +1999,78 @@ TEST(Matrix, PoolFwdBwd) { } } +void testMaxOutFwdBwd(int numSamples, int imgSizeH, int imgSizeW, + int channels, int groups) { + int inWidth = imgSizeH * imgSizeW * channels; + int outChannels = channels / groups; + int outWidth = imgSizeH * imgSizeW * outChannels; + + // forward + MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true); + + MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true); + MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false); + + IVectorPtr id = CpuIVector::create(numSamples * outWidth, false); + IVectorPtr idGpu = GpuIVector::create(numSamples * outWidth, true); + IVectorPtr idCheck = CpuIVector::create(numSamples * outWidth, false); + + input->randomizeUniform(); + inputGpu->copyFrom(*input); + + target->maxoutForward(*input, *id, outChannels, groups); + targetGpu->maxoutForward(*inputGpu, *idGpu, outChannels, groups); + + // check + targetCheck->copyFrom(*targetGpu); + MatrixCheckErr(*target, *targetCheck); + idCheck->copyFrom(*idGpu); + VectorCheckEqual(*id, *idCheck); + + // backward + MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true); + + MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpuGrad = GpuMatrix::create(numSamples, outWidth, false, + true); + MatrixPtr targetCheckGrad = CpuMatrix::create(numSamples, inWidth, false, + false); + + inputGrad->randomizeUniform(); + targetGrad->randomizeUniform(); + inputGpuGrad->copyFrom(*inputGrad); + targetGpuGrad->copyFrom(*targetGrad); + + inputGrad->maxoutBackward(*targetGrad, *id, outChannels, groups); + inputGpuGrad->maxoutBackward(*targetGpuGrad, *idGpu, outChannels, groups); + + // check + targetCheckGrad->copyFrom(*inputGpuGrad); + MatrixCheckErr(*inputGrad, *targetCheckGrad); +} + +TEST(Matrix, MaxOutFwdBwd) { + for (auto numSamples : {5, 10}) { + for (auto channels : {8, 16}) { + for (auto imgSizeH : {14, 28}) { + for (auto imgSizeW : {16, 30}) { + for (auto groups : {2, 4}) { + VLOG(3) << " numSamples=" << numSamples + << " channels=" << channels + << " imgSizeH=" << imgSizeH + << " imgSizeW=" << imgSizeW + << " groups=" << groups; + testMaxOutFwdBwd(numSamples, imgSizeH, imgSizeW, channels, groups); + } + } + } + } + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4 index 25e36f9c4c1687aec46ca7202d1ba8a6e0088fec..70c1f8d563238c2033b1992ec23ad5f73684ecbb 100644 --- a/proto/ModelConfig.proto.m4 +++ b/proto/ModelConfig.proto.m4 @@ -170,6 +170,15 @@ message BlockExpandConfig { required uint32 img_size_y = 11; } +message MaxOutConfig { + required uint32 channels = 1; + required uint32 groups = 2; + + // The size of input feature map. + required uint32 img_size_x = 3; + required uint32 img_size_y = 4; +} + message ProjectionConfig { required string type = 1; required string name = 2; @@ -225,6 +234,7 @@ message LayerInputConfig { // If the input layer has multi-output. // Set the argument name. optional string input_layer_argument = 9; + optional MaxOutConfig maxout_conf = 10; } message LayerConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index c1e74c7a2d8f7448429edcdbc2ec7c32f6cedd57..fe8a5e5d48767eefcc912e7754bb93a93a06fb69 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -469,6 +469,7 @@ class Input(Cfg): pool=None, image=None, block_expand=None, + maxout=None, format=None, nnz=None, is_static=None, @@ -785,6 +786,16 @@ class BlockExpand(Cfg): output_y = 0): self.add_keys(locals()) +@config_class +class MaxOut(Cfg): + def __init__( + self, + channels, + groups, + img_size_x = 0, + img_size_y = 0): + self.add_keys(locals()) + def DataBase(async_load_data=False, constant_slots=None, data_ratio=1, @@ -1082,6 +1093,12 @@ def parse_block_expand(block_expand, input_layer_name, block_expand_conf): int(math.ceil((2 * block_expand.padding_y + block_expand.img_size_y \ - block_expand.block_y) / float(block_expand.stride_y))) +def parse_maxout(maxout, input_layer_name, maxout_conf): + maxout_conf.channels = maxout.channels + maxout_conf.groups = maxout.groups + maxout_conf.img_size_x = maxout.img_size_x + maxout_conf.img_size_y = maxout.img_size_y + # Define an evaluator @config_func def Evaluator( @@ -1705,6 +1722,21 @@ class BlockExpandLayer(LayerBase): self.set_layer_size(block_expand_conf.block_x * block_expand_conf.block_y * block_expand_conf.channels) +@config_layer('maxout') +class MaxOutLayer(LayerBase): + def __init__( + self, + name, + inputs, + **xargs): + super(MaxOutLayer, self).__init__(name, 'maxout', 0, inputs=inputs, **xargs) + input_layer = self.get_input_layer(0) + parse_maxout(self.inputs[0].maxout, + input_layer.name, + self.config.inputs[0].maxout_conf) + maxout_conf = self.config.inputs[0].maxout_conf + self.set_layer_size(g_layer_map[input_layer.name].size / maxout_conf.groups) + # key: cost type # value: cost class g_cost_map = {} diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index d45a9b53dcc94f0cfe3d402795c1d0f889853783..c4e8fe4abc026b83a059e993f5469946b926725d 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -55,7 +55,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", 'multi_binary_label_cross_entropy', 'rank_cost', 'lambda_cost', 'huber_cost', # 'block_expand_layer', # TODO(yuyang18): this layer is not correct - 'out_prod_layer', 'print_layer' + 'maxout_layer', 'out_prod_layer', 'print_layer' ] @@ -110,6 +110,7 @@ class LayerType(object): SLOPE_INTERCEPT_LAYER = "slope_intercept" LINEAR_COMBINATION_LAYER = "convex_comb" BLOCK_EXPAND = "blockexpand" + MAXOUT = "maxout" PRINT_LAYER = "print" @@ -3362,6 +3363,73 @@ def block_expand_layer(input, return LayerOutput(name, LayerType.BLOCK_EXPAND, parents=[input]) +@wrap_name_default() +@layer_support() +def maxout_layer(input, + groups, + num_channels=None, + size_x=None, + size_y=None, + name=None, + layer_attr=None): + """ + A layer to do max out on conv layer output. + - Input: output of a conv layer. + - Output: feature map size same as input. Channel is (input channel) / groups. + + So groups should be larger than 1, and the num of channels should be able + to devided by groups. + + Please refer to Paper: + - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf + - Multi-digit Number Recognition from Street View \ + Imagery using Deep Convolutional Neural Networks: \ + https://arxiv.org/pdf/1312.6082v4.pdf + + The simple usage is: + + .. code-block:: python + + maxout = maxout_layer(input, + num_channels=128, + groups=4) + + :param input: The input layer. + :type input: LayerOutput + :param num_channels: The channel number of input layer. If None will be set + automatically from previous output. + :type num_channels: int|None + :param groups: The group number of input layer. + :type groups: int + :param size_x: conv output width. If None will be set + automatically from previous output. + :type size_x: int|None + :param size_y: conv output height. If None will be set + automatically from previous output. + :type size_y: int|None + :param name: The name of this layer, which can not specify. + :type name: None|basestring. + :param layer_attr: Extra Layer attribute. + :type layer_attr: ExtraLayerAttribute + :return: LayerOutput object. + :rtype: LayerOutput + """ + assert input.layer_type == LayerType.CONV_LAYER + assert isinstance(input.activation, LinearActivation) + assert groups > 1 + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + assert num_channels % groups == 0 + Layer(name=name, + inputs=Input(input.name, + maxout=MaxOut(channels=num_channels, + groups=groups)), + type=LayerType.MAXOUT, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput(name, LayerType.MAXOUT, parents=[input]) + + @wrap_name_default() @layer_support() def ctc_layer(input, label, size=None, name=None, norm_by_times=False, diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5 index 96bf3fb2e19d6323ed5822205cc08dda0dee0dfd..88ce5c129e552e12b89040855178db8864f7d559 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/check.md5 +++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5 @@ -12,6 +12,7 @@ a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr 8bb44e1e5072d0c261572307e7672bda test_grumemory_layer.protostr 1f3510672dce7a9ed25317fc58579ac7 test_hsigmoid.protostr d350bd91a0dc13e854b1364c3d9339c6 test_lstmemory_layer.protostr +6fa59551808ee7012bbd24f757e782d2 test_maxout.protostr 251a948ba41c1071afcd3d9cf9c233f7 test_ntm_layers.protostr e6ff04e70aea27c7b06d808cc49c9497 test_print_layer.protostr 2a75dd33b640c49a8821c2da6e574577 test_rnn_group.protostr diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index 7cdd682056fd486db2c7274636ba51b1d1e7ba5f..4b1d2d3d41d521e70be3ec77d2b98602f750ddf7 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -8,7 +8,8 @@ configs=(test_fc layer_activations projections test_print_layer test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers util_layers simple_rnn_layers unused_layers test_cost_layers -test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight) +test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight +test_maxout) for conf in ${configs[*]} diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py b/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py new file mode 100644 index 0000000000000000000000000000000000000000..079e2cf4c432060ae19d1ad70faa6423b687f99a --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py @@ -0,0 +1,30 @@ +from paddle.trainer_config_helpers import * + +settings( + batch_size=1000, + learning_rate=1e-5 +) + +data = data_layer(name='data', size=2304) + +conv = img_conv_layer(input=data, + filter_size = 3, + num_channels=1, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +maxout = maxout_layer(input=conv, + num_channels=16, + groups=2) + +pool = img_pool_layer(input=maxout, + num_channels=8, + pool_size=2, + stride=2, + pool_type=MaxPooling()) + +fc = fc_layer(input=pool, size=384, bias_attr=False) + +outputs(fc)