diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index d19f4a4bb310a73d896bc8f4179f41b1a5752e54..4bd9d5e7c9e9038d7b20d105829f69f929501179 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -91,6 +91,7 @@ extern void hl_expand_feature2col( * @param[in] paddingH padding height. * @param[in] paddingW padding width. * @param[out] tgtData output data. + * @param[in] tgtStride output data stride. * */ extern void hl_maxpool_forward( @@ -100,7 +101,8 @@ extern void hl_maxpool_forward( const int pooledH, const int pooledW, const int sizeX, const int sizeY, const int strideH, const int strideW, - const int paddingH, const int paddingW, real* tgtData); + const int paddingH, const int paddingW, + real* tgtData, const int tgtStride); /** * @brief Maximum pool backward. @@ -123,6 +125,7 @@ extern void hl_maxpool_forward( * @param[in] paddingH padding height. * @param[in] paddingW padding width. * @param[out] targetGrad output grad. + * @param[in] outStride output grad data stride. * */ extern void hl_maxpool_backward( @@ -135,7 +138,7 @@ extern void hl_maxpool_backward( const int strideH, const int strideW, const int paddingH, const int paddingW, real scaleA, real scaleB, - real* targetGrad); + real* targetGrad, const int outStride); /** * @brief Averge pool forward. @@ -154,6 +157,7 @@ extern void hl_maxpool_backward( * @param[in] paddingH padding height. * @param[in] paddingW padding width. * @param[out] tgtData output data. + * @param[in] tgtStride output data stride. * */ extern void hl_avgpool_forward( @@ -163,7 +167,8 @@ extern void hl_avgpool_forward( const int pooledH, const int pooledW, const int sizeX, const int sizeY, const int strideH, const int strideW, - const int paddingH, const int paddingW, real* tgtData); + const int paddingH, const int paddingW, + real* tgtData, const int tgtStride); /** * @brief Maximum pool backward. @@ -184,6 +189,7 @@ extern void hl_avgpool_forward( * @param[in] scaleA scale. * @param[in] scaleB scale. * @param[out] backGrad output grad. + * @param[in] outStride output grad data stride. * */ extern void hl_avgpool_backward( @@ -195,7 +201,7 @@ extern void hl_avgpool_backward( const int strideH, const int strideW, int paddingH, int paddingW, real scaleA, real scaleB, - real* backGrad); + real* backGrad, const int outStride); /** * @brief Cross-map-respose normalize forward. diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 5f696986e3c8fa19e1f234b03d5ef758c95e3aaf..4342c30376eeb1de1c8968be3f64a0511c3dde07 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -44,7 +44,8 @@ inline void hl_maxpool_forward( const int pooledH, const int pooledW, const int sizeX, const int sizeY, const int strideH, const int strideW, - const int paddingH, const int paddingW, real* tgtData) {} + const int paddingH, const int paddingW, + real* tgtData, const int tgtStride) {} inline void hl_maxpool_backward( const int frameCnt, const real* inputData, @@ -56,7 +57,7 @@ inline void hl_maxpool_backward( const int strideH, const int strideW, const int paddingH, const int paddingW, real scaleA, real scaleB, - real* targetGrad) {} + real* targetGrad, const int outStride) {} inline void hl_avgpool_forward( const int frameCnt, const real* inputData, @@ -65,7 +66,8 @@ inline void hl_avgpool_forward( const int pooledH, const int pooledW, const int sizeX, const int sizeY, const int strideH, const int strideW, - const int paddingH, const int paddingW, real* tgtData) {} + const int paddingH, const int paddingW, + real* tgtData, const int tgtStride) {} inline void hl_avgpool_backward( const int frameCnt, const real* outGrad, @@ -76,7 +78,7 @@ inline void hl_avgpool_backward( const int strideH, const int strideW, int paddingH, int paddingW, real scaleA, real scaleB, - real* backGrad) {} + real* backGrad, const int outStride) {} inline void hl_CMRNorm_forward( size_t frameCnt, const real* in, real* scale, real* out, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index baa2fb0d27d749197c10645ff976851ddc38c84f..fcef6a4436b5cbd52fe28bb9a2338bc254227fd7 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -152,7 +152,7 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData, const int ksizeW, const int ksizeH, const int strideH, const int strideW, const int offsetH, const int offsetW, - real* tgtData) { + real* tgtData, const int tgtStride) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { int pw = index % pooledW; @@ -173,7 +173,9 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData, maxval = inputData[h * width + w]; } } - tgtData[index] = maxval; + int tgtIndex = index % (pooledW * pooledH * channels) + + frameNum * tgtStride; + tgtData[tgtIndex] = maxval; } } @@ -184,7 +186,7 @@ void hl_maxpool_forward(const int frameCnt, const real* inputData, const int sizeX, const int sizeY, const int strideH, const int strideW, const int paddingH, const int paddingW, - real* tgtData) { + real* tgtData, const int tgtStride) { int num_kernels = pooledH * pooledW * channels * frameCnt; int blocks = (num_kernels + 1024 - 1) / 1024; @@ -194,7 +196,7 @@ void hl_maxpool_forward(const int frameCnt, const real* inputData, KeMaxPoolForward<<< grid, threads, 0, STREAM_DEFAULT >>> (num_kernels, inputData, channels, height, width, pooledH, pooledW, sizeX, sizeY, strideH, strideW, - paddingH, paddingW, tgtData); + paddingH, paddingW, tgtData, tgtStride); CHECK_SYNC("hl_maxpool_forward failed"); } @@ -207,7 +209,7 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData, const int strideH, const int strideW, const int padH, const int padW, real scaleA, real scaleB, - real* targetGrad) { + real* targetGrad, const int outStride) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { // find out the local index @@ -223,8 +225,8 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData, int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0; real gradient = 0; real input = inputData[index]; - outData += (frameNum * channels + offsetC) * pooledH * pooledW; - outGrad += (frameNum * channels + offsetC) * pooledH * pooledW; + outData += (frameNum * outStride + offsetC * pooledH * pooledW); + outGrad += (frameNum * outStride + offsetC * pooledH * pooledW); for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { if (input == outData[ph * pooledW + pw]) { @@ -246,7 +248,7 @@ void hl_maxpool_backward(const int frameCnt, const real* inputData, const int strideH, const int strideW, const int paddingH, const int paddingW, real scaleA, real scaleB, - real* targetGrad) { + real* targetGrad, const int outStride) { int num_kernels = height * width * channels * frameCnt; int blocks = (num_kernels + 1024 - 1) / 1024; @@ -257,7 +259,7 @@ void hl_maxpool_backward(const int frameCnt, const real* inputData, strideH, strideW, paddingH, paddingW, scaleA, scaleB, - targetGrad); + targetGrad, outStride); CHECK_SYNC("hl_maxpool_backward"); } @@ -268,7 +270,7 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData, const int sizeX, const int sizeY, const int strideH, const int strideW, const int padH, const int padW, - real* tgtData) { + real* tgtData, const int tgtStride) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { int pw = index % pooledW; @@ -293,7 +295,9 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData, aveval += inputData[h * width + w]; } } - tgtData[index] = aveval / pool_size; + int tgtIndex = index % (pooledW * pooledH * channels) + + frameNum * tgtStride; + tgtData[tgtIndex] = aveval / pool_size; } } @@ -303,14 +307,15 @@ void hl_avgpool_forward(const int frameCnt, const real* inputData, const int pooledH, const int pooledW, const int sizeX, const int sizeY, const int strideH, const int strideW, - const int paddingH, const int paddingW, real* tgtData) { + const int paddingH, const int paddingW, + real* tgtData, const int tgtStride) { int num_kernels = pooledH * pooledW * channels * frameCnt; int blocks = (num_kernels + 1024 - 1) / 1024; KeAvgPoolForward<<< blocks, 1024, 0, STREAM_DEFAULT >>> (num_kernels, inputData, channels, height, width, pooledH, pooledW, sizeX, sizeY, strideH, strideW, - paddingH, paddingW, tgtData); + paddingH, paddingW, tgtData, tgtStride); CHECK_SYNC("hl_avgpool_forward failed"); } @@ -322,7 +327,7 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad, const int strideH, const int strideW, const int padH, const int padW, real scaleA, real scaleB, - real* tgtGrad) { + real* tgtGrad, const int outStride) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { int offsetW = index % width + padW; @@ -335,7 +340,8 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad, int phend = offsetH >= 0 ? min(offsetH / strideH + 1, pooledH) : 0; int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0; real gradient = 0; - outGrad += (frameNum * channels + offsetC) * pooledH * pooledW; + outGrad += (frameNum * outStride + offsetC * pooledH * pooledW); + for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { @@ -360,7 +366,7 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, const int strideH, const int strideW, const int paddingH, const int paddingW, real scaleA, real scaleB, - real* backGrad) { + real* backGrad, const int outStride) { int num_kernels = height * width * channels * frameCnt; int blocks = (num_kernels + 1024 - 1) / 1024; @@ -370,7 +376,7 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, strideH, strideW, paddingH, paddingW, scaleA, scaleB, - backGrad); + backGrad, outStride); CHECK_SYNC("hl_avgpool_backward failed"); } diff --git a/paddle/gserver/layers/PoolProjection.cpp b/paddle/gserver/layers/PoolProjection.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50059ee04d39b0dc69000ffe5cf421d12fb873f8 --- /dev/null +++ b/paddle/gserver/layers/PoolProjection.cpp @@ -0,0 +1,81 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "PoolProjection.h" + +namespace paddle { + +REGISTER_PROJECTION_CREATE_FUNC(pool2, &PoolProjection::create); + +PoolProjection* PoolProjection::create(const ProjectionConfig& config, + ParameterPtr parameter, bool useGpu) { + const std::string& pool = config.pool_conf().pool_type(); + if (pool == "max") { + return new MaxPoolProjection(config, parameter, useGpu); + } else if (pool == "avg") { + return new AvgPoolProjection(config, parameter, useGpu); + } else { + LOG(FATAL) << "Unknown pool type: " << pool; + return nullptr; + } +} + +void MaxPoolProjection::forward() { + MatrixPtr inputV = in_->value; + MatrixPtr outV = out_->value; + outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, channels_, + sizeX_, sizeY_, strideY_, stride_, + outputY_, outputX_, confPaddingY_, confPadding_); +} + +void MaxPoolProjection::backward(const UpdateCallback& callback) { + (void)callback; + MatrixPtr outGrad = out_->grad; + MatrixPtr inputV = in_->value; + MatrixPtr outV = out_->value; + MatrixPtr inputGrad = in_->grad; + + if (NULL == inputGrad) { + return; + } + inputGrad->maxPoolBackward(*inputV, imgSizeY_, imgSize_, *outGrad, *outV, + sizeX_, sizeY_, + strideY_, stride_, outputY_, outputX_, 1, 1, + confPaddingY_, confPadding_); +} + +void AvgPoolProjection::forward() { + MatrixPtr inputV = in_->value; + MatrixPtr outV = out_->value; + outV->avgPoolForward(*inputV, imgSizeY_, imgSize_, channels_, + sizeX_, sizeY_, strideY_, stride_, + outputY_, outputX_, confPaddingY_, confPadding_); +} + +void AvgPoolProjection::backward(const UpdateCallback& callback) { + (void)callback; + + MatrixPtr outputGrad = out_->grad; + MatrixPtr inputGrad = in_->grad; + + if (NULL == inputGrad) { + return; + } + + inputGrad->avgPoolBackward(*outputGrad, imgSizeY_, imgSize_, + sizeX_, sizeY_, strideY_, stride_, + outputY_, outputX_, 1, 1, + confPaddingY_, confPadding_); +} +} // namespace paddle diff --git a/paddle/gserver/layers/PoolProjection.h b/paddle/gserver/layers/PoolProjection.h new file mode 100644 index 0000000000000000000000000000000000000000..73d8a41aefabe625484b6d49f8e9cd9d23460973 --- /dev/null +++ b/paddle/gserver/layers/PoolProjection.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Projection.h" + +namespace paddle { + +class PoolProjection : public Projection { +protected: + size_t imgSizeY_, imgSize_; + size_t outputY_, outputX_; + size_t strideY_, stride_; + size_t sizeY_, sizeX_; + int confPaddingY_, confPadding_; + size_t channels_; + std::string poolType_; + +public: + PoolProjection(const ProjectionConfig& config, ParameterPtr parameter, + bool useGpu) + : Projection(config, parameter, useGpu) { + const PoolConfig& conf = config_.pool_conf(); + poolType_ = conf.pool_type(); + channels_ = conf.channels(); + sizeX_ = conf.size_x(); + stride_ = conf.stride(); + outputX_ = conf.output_x(); + imgSize_ = conf.img_size(); + confPadding_ = conf.padding(); + + sizeY_ = conf.has_size_y() ? conf.size_y() : conf.size_x(); + imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size(); + strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); + confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); + outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); + } + static PoolProjection* create(const ProjectionConfig& config, + ParameterPtr parameter, bool useGpu); + const std::string& getPoolType() const { return poolType_; } +}; + +class MaxPoolProjection : public PoolProjection { +public: + MaxPoolProjection(const ProjectionConfig& config, ParameterPtr parameter, + bool useGpu) + : PoolProjection(config, parameter, useGpu) {} + virtual void forward(); + virtual void backward(const UpdateCallback& callback = nullptr); +}; + +class AvgPoolProjection : public PoolProjection { +public: + AvgPoolProjection(const ProjectionConfig& config, ParameterPtr parameter, + bool useGpu) + : PoolProjection(config, parameter, useGpu) {} + virtual void forward(); + virtual void backward(const UpdateCallback& callback = nullptr); +}; +} // namespace paddle diff --git a/paddle/gserver/layers/Projection.h b/paddle/gserver/layers/Projection.h index 3fa3a0cc230ac4c8616abe0eb2c8ac41bde52d53..203edc5396a53cf72dcad6308335ba4731ba49bc 100644 --- a/paddle/gserver/layers/Projection.h +++ b/paddle/gserver/layers/Projection.h @@ -12,12 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once -#include "paddle/parameter/Parameter.h" -#include "ModelConfig.pb.h" #include "Layer.h" +#include "ModelConfig.pb.h" +#include "paddle/parameter/Parameter.h" namespace paddle { @@ -28,6 +27,11 @@ namespace paddle { Projection::registrar_.registerClass<__class_name>(#__type_name); \ }) +#define REGISTER_PROJECTION_CREATE_FUNC(__type_name, createFunction) \ + static InitFunction __reg_type_##__type_name([]() { \ + Projection::registrar_.registerClass(#__type_name, createFunction); \ + }) + /** * A projection takes one Argument as input, calculate the result and add it * to output Argument. @@ -50,7 +54,8 @@ public: registrar_; /** - * Forward propagation. If backward() will be called, in and out must be kept valid until then. + * Forward propagation. If backward() will be called, in and out must be kept + * valid until then. * @param in input of projection * @param out output of projection * @param passType PASS_TRAIN of PASS_TEST diff --git a/paddle/gserver/layers/SpatialPyramidPoolLayer.cpp b/paddle/gserver/layers/SpatialPyramidPoolLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bcdba5c15117546358a66136ca93e4b8915d0cd9 --- /dev/null +++ b/paddle/gserver/layers/SpatialPyramidPoolLayer.cpp @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "SpatialPyramidPoolLayer.h" + +namespace paddle { + +REGISTER_LAYER(spp, SpatialPyramidPoolLayer); + +ProjectionConfig SpatialPyramidPoolLayer::getConfig(size_t imgSizeW, + size_t imgSizeH, + size_t channels, + size_t pyramidLevel, + std::string& poolType) { + ProjectionConfig config; + config.set_type("pool2"); + PoolConfig* conf = config.mutable_pool_conf(); + conf->set_channels(channels); + conf->set_img_size(imgSizeW); + conf->set_img_size_y(imgSizeH); + conf->set_pool_type(poolType); + + int numBins = std::pow(2, pyramidLevel); + + int sizeH = std::ceil(imgSizeH / static_cast(numBins)); + int remainderH = sizeH * numBins - imgSizeH; + int paddingH = (remainderH + 1) / 2; + int outSizeH = outputSize(imgSizeH, sizeH, paddingH, sizeH); + + int sizeW = std::ceil(imgSizeW / static_cast(numBins)); + int remainderW = sizeW * numBins - imgSizeW; + int paddingW = (remainderW + 1) / 2; + int outSizeW = outputSize(imgSizeW, sizeW, paddingW, sizeW); + + conf->set_stride(sizeW); + conf->set_stride_y(sizeH); + conf->set_size_x(sizeW); + conf->set_size_y(sizeH); + conf->set_padding(paddingW); + conf->set_padding_y(paddingH); + conf->set_output_x(outSizeW); + conf->set_output_y(outSizeH); + config.set_output_size(outSizeH * outSizeW * channels); + return config; +} + +void SpatialPyramidPoolLayer::splitInput(Argument& input, size_t height, + size_t width, bool useGpu) { + input.value = getInput(0).value; + if (passType_ != PASS_TEST && needGradient()) { + Matrix::resizeOrCreate(input.grad, height, width, /* trans */ false, + useGpu); + input.grad->zeroMem(); + } +} + +bool SpatialPyramidPoolLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK_EQ(config_.inputs_size(), 1); + + const SppConfig& sppConf = config_.inputs(0).spp_conf(); + pyramidHeight_ = sppConf.pyramid_height(); + poolType_ = sppConf.pool_type(); + + channels_ = sppConf.channels(); + imgSizeW_ = sppConf.img_size(); + imgSizeH_ = sppConf.has_img_size_y() ? sppConf.img_size_y() : imgSizeW_; + poolProjections_.reserve(pyramidHeight_); + projCol_.reserve(pyramidHeight_); + projInput_.reserve(pyramidHeight_); + projOutput_.resize(pyramidHeight_); + + size_t startCol = 0; + size_t endCol = 0; + for (size_t i = 0; i < pyramidHeight_; i++) { + poolProjections_.emplace_back(PoolProjection::create( + getConfig(imgSizeW_, imgSizeH_, channels_, i, poolType_), + nullptr, useGpu_)); + endCol += poolProjections_[i]->getOutputSize(); + projCol_.push_back(std::make_pair(startCol, endCol)); + startCol = endCol; + projInput_.emplace_back(Argument()); + } + outputSize_ = endCol; + return true; +} + +void SpatialPyramidPoolLayer::forward(PassType passType) { + Layer::forward(passType); + + int batchSize = getInput(0).getBatchSize(); + resetOutput(batchSize, outputSize_); + for (size_t i = 0; i < pyramidHeight_; i++) { + size_t startCol = projCol_[i].first; + size_t endCol = projCol_[i].second; + projOutput_[i].value = output_.value->subColMatrix(startCol, endCol); + projOutput_[i].grad = output_.grad->subColMatrix(startCol, endCol); + splitInput(projInput_[i], getInput(0).value->getHeight(), + getInput(0).value->getWidth(), useGpu_); + } + for (size_t i = 0; i < pyramidHeight_; i++) { + poolProjections_[i]->forward(&projInput_[i], &projOutput_[i], passType); + } +} + +void SpatialPyramidPoolLayer::backward(const UpdateCallback& callback) { + for (size_t i = 0; i < pyramidHeight_; i++) { + if (poolProjections_[i]) { + poolProjections_[i]->backward(callback); + getInput(0).grad->add(*projInput_[i].grad); + } + } +} + +} // namespace paddle + diff --git a/paddle/gserver/layers/SpatialPyramidPoolLayer.h b/paddle/gserver/layers/SpatialPyramidPoolLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..de1fd4da07dd897931516b23f45feef981fc450e --- /dev/null +++ b/paddle/gserver/layers/SpatialPyramidPoolLayer.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#pragma once + +#include "Layer.h" +#include "PoolProjection.h" +#include "paddle/utils/Logging.h" + +namespace paddle { + +class SpatialPyramidPoolLayer : public Layer { +protected: + size_t channels_; + size_t imgSizeW_; + size_t imgSizeH_; + size_t pyramidHeight_; + size_t outputSize_; + std::string poolType_; + + std::vector> poolProjections_; + std::vector projInput_; + std::vector projOutput_; + std::vector> projCol_; + +public: + explicit SpatialPyramidPoolLayer(const LayerConfig& config) : Layer(config) {} + ~SpatialPyramidPoolLayer() {} + + virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + ProjectionConfig getConfig(size_t sizeX_, size_t sizeY_, size_t channels, + size_t pyamidLevel_, std::string& poolType_); + + int outputSize(int imageSize, int windowSize, int padding, int stride) { + return (imageSize - windowSize + 2 * padding) / stride + 1; + } + + virtual void forward(PassType passType); + virtual void backward(const UpdateCallback& callback = nullptr); + void splitInput(Argument& input, size_t height, size_t width, bool useGpu); +}; +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index eab9bf84141a27d957969a22beb70824659888d7..3d633f4b72797192c6fe92c3ad5b935c2d92b2c1 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include #include -#include "paddle/gserver/layers/DataLayer.h" +#include #include "ModelConfig.pb.h" +#include "paddle/gserver/layers/DataLayer.h" #include "paddle/trainer/Trainer.h" -#include "TestUtil.h" #include "LayerGradUtil.h" +#include "TestUtil.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -880,6 +880,32 @@ TEST(Layer, PoolLayer) { #endif } +void testSppLayer(const string& poolType, const int pyramidHeight, bool trans, + bool useGpu) { + TestConfig config; + config.layerConfig.set_type("spp"); + config.inputDefs.push_back({INPUT_DATA, "layer_0", 3200, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + SppConfig* sppConfig = input->mutable_spp_conf(); + sppConfig->set_pool_type(poolType); + sppConfig->set_pyramid_height(pyramidHeight); + sppConfig->set_channels(16); + sppConfig->set_img_size(10); + sppConfig->set_img_size_y(20); + testLayerGrad(config, "spp", 100, trans, useGpu); +} + +TEST(Layer, SpatialPyramidPoolLayer) { + for (auto useGpu : {false, true}) { + testSppLayer("avg", 1, false, useGpu); + testSppLayer("avg", 3, false, useGpu); + testSppLayer("avg", 5, false, useGpu); + testSppLayer("max", 1, false, useGpu); + testSppLayer("max", 3, false, useGpu); + testSppLayer("avg", 5, false, useGpu); + } +} + TEST(Layer, rankCostLayer) { TestConfig config; config.layerConfig.set_type("rank-cost"); diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 843eabc97d642fcfb5b5862c0a5bef035a7a2ccb..ddf99f6f2974ca0cc079b0eadac30b9a5039bfd6 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "Matrix.h" +#include "MathFunctions.h" #include "SparseMatrix.h" #include "SparseRowMatrix.h" -#include "MathFunctions.h" -#include #include #include +#include -#include "paddle/utils/Logging.h" #include #include "hl_gpu.h" #include "hl_table_apply.h" #include "hl_top_k.h" +#include "paddle/utils/Logging.h" #include "paddle/utils/ThreadLocal.h" @@ -42,9 +42,9 @@ inline real _safelog(real a) { return a > 0.0f ? std::log(a) : -40.0f; } Matrix::Matrix(MemoryHandlePtr memHandle, size_t height, size_t width, bool trans, bool use_gpu) : BaseMatrix( - height, width, - memHandle ? (reinterpret_cast(memHandle->getBuf())) : nullptr, - trans, use_gpu) { + height, width, + memHandle ? (reinterpret_cast(memHandle->getBuf())) : nullptr, + trans, use_gpu) { elementCnt_ = width * height; memoryHandle_ = memHandle; } @@ -95,7 +95,7 @@ MatrixPtr Matrix::create(MemoryHandlePtr memHandle, size_t height, size_t width, if (auto gpuHandle = std::dynamic_pointer_cast(memHandle)) { return std::make_shared(gpuHandle, height, width, trans); } else if (auto cpuHandle = - std::dynamic_pointer_cast(memHandle)) { + std::dynamic_pointer_cast(memHandle)) { return std::make_shared(cpuHandle, height, width, trans); } else { LOG(FATAL) << "Wrong"; @@ -343,19 +343,17 @@ void GpuMatrix::addBias(Matrix& b, real scale) { void GpuMatrix::collectBias(Matrix& a, real scale) { CHECK_EQ(getHeight(), (size_t)1); CHECK_EQ(width_, a.getWidth()); - GpuSparseMatrix* sMatPtr = dynamic_cast(&a); + GpuSparseMatrix* sMatPtr = dynamic_cast(&a); if (!sMatPtr) { sumCols(a, scale); } else { real* data = getData(); hl_sparse_matrix_s A_d = sMatPtr->sMatrix_.get(); - hl_sparse_matrix_column_sum(data, A_d, sMatPtr->getHeight(), - width_, scale); + hl_sparse_matrix_column_sum(data, A_d, sMatPtr->getHeight(), width_, scale); } } -void GpuMatrix::sequenceAvgForward(Matrix& a, - const IVector& startsPos, +void GpuMatrix::sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { size_t height = getHeight(); size_t width = getWidth(); @@ -401,8 +399,8 @@ void GpuMatrix::mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, hl_trans_op_t transa = !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T; hl_trans_op_t transb = !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T; - hl_matrix_mul(A_d, transa, B_d, transb, C_d, dimM, dimN, dimK, - scaleAB, scaleT, lda, ldb, ldc); + hl_matrix_mul(A_d, transa, B_d, transb, C_d, dimM, dimN, dimK, scaleAB, + scaleT, lda, ldb, ldc); } void GpuMatrix::mul(const GpuSparseMatrix& a, const GpuMatrix& b, real scaleAB, @@ -423,8 +421,8 @@ void GpuMatrix::mul(const GpuSparseMatrix& a, const GpuMatrix& b, real scaleAB, hl_sparse_matrix_s A_d = a.sMatrix_.get(); real* B_d = b.data_; real* C_d = data_; - hl_matrix_csr_mul_dense(A_d, transA, B_d, HPPL_OP_N, C_d, height_, - width_, b.height_, scaleAB, scaleT); + hl_matrix_csr_mul_dense(A_d, transA, B_d, HPPL_OP_N, C_d, height_, width_, + b.height_, scaleAB, scaleT); } void GpuMatrix::mul(const GpuMatrix& a, const GpuSparseMatrix& b, real scaleAB, @@ -445,11 +443,11 @@ void GpuMatrix::mul(const GpuMatrix& a, const GpuSparseMatrix& b, real scaleAB, << "Matrix dimensions are not equal"; } if (b.format_ == SPARSE_CSC) { - hl_matrix_dense_mul_csc(A_d, HPPL_OP_N, B_d, transB, C_d, height_, - width_, a.width_, scaleAB, scaleT); + hl_matrix_dense_mul_csc(A_d, HPPL_OP_N, B_d, transB, C_d, height_, width_, + a.width_, scaleAB, scaleT); } else { - hl_matrix_dense_mul_csr(A_d, HPPL_OP_N, B_d, transB, C_d, height_, - width_, a.width_, scaleAB, scaleT); + hl_matrix_dense_mul_csr(A_d, HPPL_OP_N, B_d, transB, C_d, height_, width_, + a.width_, scaleAB, scaleT); } } @@ -511,8 +509,8 @@ void GpuMatrix::selectRows(Matrix& table, IVector& ids) { size_t tableSize = table.getHeight(); int* index = ids.getData(); - hl_matrix_select_rows(a, stride_, table.getData(), table.stride_, - index, numSamples, tableSize, dim); + hl_matrix_select_rows(a, stride_, table.getData(), table.stride_, index, + numSamples, tableSize, dim); #endif } @@ -529,8 +527,8 @@ void GpuMatrix::addToRows(Matrix& table, IVector& ids) { size_t tableSize = table.getHeight(); int* index = ids.getData(); - hl_matrix_add_to_rows(table.getData(), table.stride_, a, stride_, - index, numSamples, tableSize, dim); + hl_matrix_add_to_rows(table.getData(), table.stride_, a, stride_, index, + numSamples, tableSize, dim); #endif } @@ -565,13 +563,8 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxVal.getHeight(), numSamples); - hl_matrix_top_k(maxVal.getData(), - maxVal.getStride(), - maxIds.getData(), - this->getData(), - this->getStride(), - this->getWidth(), - beam, + hl_matrix_top_k(maxVal.getData(), maxVal.getStride(), maxIds.getData(), + this->getData(), this->getStride(), this->getWidth(), beam, numSamples); #endif } @@ -595,12 +588,12 @@ void GpuMatrix::maxoutForward(Matrix& a, IVector& id, size_t channels, size_t size = getWidth(); size_t batchSize = getHeight(); - const real* input = a.getData(); + const real* input = a.getData(); real* output = getData(); int* idForGpu = id.getData(); - hl_maxout_forward(input, output, idForGpu, batchSize, size, - size / channels, groups); + hl_maxout_forward(input, output, idForGpu, batchSize, size, size / channels, + groups); } void GpuMatrix::maxoutBackward(Matrix& a, IVector& id, size_t channels, @@ -611,12 +604,12 @@ void GpuMatrix::maxoutBackward(Matrix& a, IVector& id, size_t channels, size_t size = a.getWidth(); size_t batchSize = getHeight(); - real* input = getData(); + real* input = getData(); const real* output = a.getData(); const int* idForGpu = id.getData(); - hl_maxout_backward(input, output, idForGpu, batchSize, size, - size / channels, groups); + hl_maxout_backward(input, output, idForGpu, batchSize, size, size / channels, + groups); } /*calulate the error of classification */ @@ -632,8 +625,8 @@ void GpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { real* recResult_d = data_; int* label_d = label_ptr->getData(); - hl_matrix_classification_error(output_d, label_d, recResult_d, - height_, output_ptr->width_); + hl_matrix_classification_error(output_d, label_d, recResult_d, height_, + output_ptr->width_); } /* copy -log(output[i * width + label]) to this->data[i] */ @@ -702,8 +695,7 @@ void GpuMatrix::sequenceSoftmax(Matrix& output, const IVector& index) { real* outputData = output.getData(); auto starts = index.getData(); int numSequences = index.getSize() - 1; - hl_sequence_softmax_forward(inputData, outputData, - starts, numSequences); + hl_sequence_softmax_forward(inputData, outputData, starts, numSequences); } void GpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) { @@ -717,8 +709,7 @@ void GpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) { real* output_d = output.data_; real* sftmaxSum_d = sftmaxSum.data_; real* grad_d = data_; - hl_matrix_softmax_derivative(grad_d, output_d, sftmaxSum_d, height_, - width_); + hl_matrix_softmax_derivative(grad_d, output_d, sftmaxSum_d, height_, width_); } void GpuMatrix::softmaxBackward(Matrix& outputV) { @@ -769,7 +760,7 @@ void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) { } void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) { CHECK(output1.useGpu_ == true && output2.useGpu_ == true) - << "Matrix type are not equal"; + << "Matrix type are not equal"; size_t numSamples = getHeight(); size_t dim = output1.getWidth(); CHECK_EQ(getWidth(), 1UL); @@ -778,15 +769,15 @@ void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) { real* out = getData(); real* x = output1.getData(); real* y = output2.getData(); - hl_cossim(out, x, y, - dim, output1.getHeight(), output2.getHeight(), scale); + hl_cossim(out, x, y, dim, output1.getHeight(), output2.getHeight(), scale); } void GpuMatrix::cosSimDerivative(Matrix& output, Matrix& prevOut1, Matrix& prevOut2, Matrix& prevGrad1, Matrix& prevGrad2, real scale) { CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true && prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true && - prevGrad2.useGpu_ == true) << "Matrix type are not equal"; + prevGrad2.useGpu_ == true) + << "Matrix type are not equal"; CHECK_EQ(getWidth(), 1UL); CHECK_EQ(output.getWidth(), 1UL); @@ -806,9 +797,8 @@ void GpuMatrix::cosSimDerivative(Matrix& output, Matrix& prevOut1, real* prevOutY = prevOut2.getData(); real* prevGradX = prevGrad1.getData(); real* prevGradY = prevGrad2.getData(); - hl_cossim_derivative(grad, out, prevOutX, prevOutY, - prevGradX, prevGradY, dim, - prevOut1.getHeight(), prevOut2.getHeight(), scale); + hl_cossim_derivative(grad, out, prevOutX, prevOutY, prevGradX, prevGradY, dim, + prevOut1.getHeight(), prevOut2.getHeight(), scale); } void GpuMatrix::randomizeUniform() { @@ -859,8 +849,8 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) { void GpuMatrix::convExpand(Matrix& feature, int feaImgHeight, int feaImgWidth, int channels, int blockH, int blockW, int strideH, - int strideW, int paddingH, int paddingW, - int outputH, int outputW) { + int strideW, int paddingH, int paddingW, int outputH, + int outputW) { CHECK(feature.useGpu_ == true) << "Matrix type are not equal"; CHECK_EQ(size_t(feaImgHeight * feaImgWidth * channels), @@ -870,17 +860,16 @@ void GpuMatrix::convExpand(Matrix& feature, int feaImgHeight, int feaImgWidth, size_t elemCnt = outputH * outputW * blockH * blockW * channels; CHECK_EQ(elemCnt, height_ * width_) << "Matrix dimensions are not equal"; - hl_expand_feature2col(feature.getData(), channels, feaImgHeight, - feaImgWidth, blockH, blockW, strideH, strideW, - paddingH, paddingW, outputH, outputW, - getData()); + hl_expand_feature2col(feature.getData(), channels, feaImgHeight, feaImgWidth, + blockH, blockW, strideH, strideW, paddingH, paddingW, + outputH, outputW, getData()); } void GpuMatrix::convShrink(Matrix& expandFeat, int thisImgHeight, int thisImgWidth, int channels, int blockH, int blockW, int strideH, int strideW, int paddingH, - int paddingW, int outputH, int outputW, - real alpha, real beta) { + int paddingW, int outputH, int outputW, real alpha, + real beta) { CHECK(expandFeat.useGpu_ == true) << "Matrix type are not equal"; CHECK_EQ(size_t(thisImgHeight * thisImgWidth * channels), getHeight() * getWidth()) @@ -889,18 +878,17 @@ void GpuMatrix::convShrink(Matrix& expandFeat, int thisImgHeight, size_t elemCnt = outputH * outputW * blockW * blockH * channels; CHECK(elemCnt == expandFeat.getHeight() * expandFeat.getWidth()) << "Matrix dimensions are not equal"; - hl_shrink_col2feature( - expandFeat.getData(), channels, thisImgHeight, thisImgWidth, blockH, - blockW, strideH, strideW, paddingH, paddingW, outputH, outputW, - getData(), alpha, beta); + hl_shrink_col2feature(expandFeat.getData(), channels, thisImgHeight, + thisImgWidth, blockH, blockW, strideH, strideW, + paddingH, paddingW, outputH, outputW, getData(), alpha, + beta); } void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, - size_t imgSizeW, size_t channels, - size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, - size_t paddingH, size_t paddingW) { + size_t imgSizeW, size_t channels, size_t sizeX, + size_t sizeY, size_t strideH, size_t strideW, + size_t outputH, size_t outputW, size_t paddingH, + size_t paddingW) { CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; real* inputData = inputMat.getData(); @@ -911,16 +899,15 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); - hl_maxpool_forward(frameNum, inputData, channels, height, width, - outputH, outputW, sizeX, sizeY, strideH, strideW, - paddingH, paddingW, data_); + hl_maxpool_forward(frameNum, inputData, channels, height, width, outputH, + outputW, sizeX, sizeY, strideH, strideW, paddingH, + paddingW, data_, getStride()); } void GpuMatrix::maxPoolBackward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, Matrix& outGrad, Matrix& outV, - size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, + size_t sizeX, size_t sizeY, size_t strideH, + size_t strideW, size_t outputH, size_t outputW, real scaleTargets, real scaleOutput, size_t paddingH, size_t paddingW) { CHECK(inputMat.useGpu_ == true && outGrad.useGpu_ == true && @@ -940,19 +927,17 @@ void GpuMatrix::maxPoolBackward(Matrix& inputMat, size_t imgSizeH, CHECK(outGrad.getHeight() == outV.getHeight() && outGrad.getWidth() == outV.getWidth()); - - hl_maxpool_backward(frameNum, inputData, outData, outDiff, channels, - height, width, outputH, outputW, sizeX, sizeY, - strideH, strideW, paddingH, paddingW, - scaleTargets, scaleOutput, data_); + hl_maxpool_backward(frameNum, inputData, outData, outDiff, channels, height, + width, outputH, outputW, sizeX, sizeY, strideH, strideW, + paddingH, paddingW, scaleTargets, scaleOutput, data_, + outGrad.getStride()); } void GpuMatrix::avgPoolForward(Matrix& inputMat, size_t imgSizeH, - size_t imgSizeW, size_t channels, - size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, - size_t paddingH, size_t paddingW) { + size_t imgSizeW, size_t channels, size_t sizeX, + size_t sizeY, size_t strideH, size_t strideW, + size_t outputH, size_t outputW, size_t paddingH, + size_t paddingW) { CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; real* inputData = inputMat.getData(); @@ -963,18 +948,17 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat, size_t imgSizeH, CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); - hl_avgpool_forward(frameNum, inputData, channels, height, width, - outputH, outputW, sizeX, sizeY, - strideH, strideW, - paddingH, paddingW, data_); + hl_avgpool_forward(frameNum, inputData, channels, height, width, outputH, + outputW, sizeX, sizeY, strideH, strideW, paddingH, + paddingW, data_, getStride()); } void GpuMatrix::avgPoolBackward(Matrix& outGrad, size_t imgSizeH, size_t imgSizeW, size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, - real scaleTargets, real scaleOutput, - size_t paddingH, size_t paddingW) { + size_t strideH, size_t strideW, size_t outputH, + size_t outputW, real scaleTargets, + real scaleOutput, size_t paddingH, + size_t paddingW) { CHECK(outGrad.useGpu_ == true) << "Matrix type are not equal"; real* outDiff = outGrad.getData(); @@ -986,11 +970,10 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, size_t imgSizeH, CHECK(height_ == outGrad.getHeight()); CHECK(outGrad.getWidth() == outputH * outputW * channels); - hl_avgpool_backward(frameNum, outDiff, channels, height, width, - outputH, outputW, sizeX, sizeY, - strideH, strideW, paddingH, paddingW, - scaleTargets, scaleOutput, - data_); + hl_avgpool_backward(frameNum, outDiff, channels, height, width, outputH, + outputW, sizeX, sizeY, strideH, strideW, paddingH, + paddingW, scaleTargets, scaleOutput, data_, + outGrad.getStride()); } void GpuMatrix::crossMapNormalFwd(Matrix& input, size_t imgSizeH, @@ -1005,8 +988,8 @@ void GpuMatrix::crossMapNormalFwd(Matrix& input, size_t imgSizeH, CHECK(denoms.getHeight() == input.getHeight() && denoms.getWidth() == input.getWidth() && input.getHeight() == height_ && input.getWidth() == width_); - hl_CMRNorm_forward(num, input.getData(), denoms.getData(), data_, - channels, height, width, sizeX, scale, -pow); + hl_CMRNorm_forward(num, input.getData(), denoms.getData(), data_, channels, + height, width, sizeX, scale, -pow); } void GpuMatrix::crossMapNormalBwd(Matrix& localGrad, Matrix& denoms, @@ -1026,13 +1009,11 @@ void GpuMatrix::crossMapNormalBwd(Matrix& localGrad, Matrix& denoms, denoms.getWidth() == localGrad.getWidth()); hl_CMRNorm_backward(num, preOutV.getData(), denoms.getData(), - localOutV.getData(), localGrad.getData(), data_, - channels, height, width, sizeX, -pow, - 2.0f * pow * scale); + localOutV.getData(), localGrad.getData(), data_, channels, + height, width, sizeX, -pow, 2.0f * pow * scale); } -void GpuMatrix::maxSequenceForward(Matrix& input, - const IVector& sequence, +void GpuMatrix::maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { CHECK(dynamic_cast(&input)); CHECK(dynamic_cast(&sequence)); @@ -1049,12 +1030,11 @@ void GpuMatrix::maxSequenceForward(Matrix& input, CHECK_EQ(numSequences, sequence.getSize() - 1); CHECK_EQ(numSequences * dim, index.getSize()); - hl_max_sequence_forward(inputData, starts, outData, maxIndex, - numSequences, dim); + hl_max_sequence_forward(inputData, starts, outData, maxIndex, numSequences, + dim); } -void GpuMatrix::maxSequenceBackward(Matrix& outputGrad, - const IVector& sequence, +void GpuMatrix::maxSequenceBackward(Matrix& outputGrad, const IVector& sequence, IVector& index) { CHECK(dynamic_cast(&outputGrad)); CHECK(dynamic_cast(&sequence)); @@ -1111,9 +1091,8 @@ void GpuMatrix::contextProjectionBackwardData(MatrixPtr inputGrad, real* inGrad = inputGrad->getData(); const int* starts = sequence.getData(); - hl_context_projection_backward_data(outGrad, starts, inGrad, - numSequences, inputDim, - contextLength, contextStart); + hl_context_projection_backward_data(outGrad, starts, inGrad, numSequences, + inputDim, contextLength, contextStart); } void GpuMatrix::contextProjectionBackwardWeight(MatrixPtr weightGrad, @@ -1133,9 +1112,9 @@ void GpuMatrix::contextProjectionBackwardWeight(MatrixPtr weightGrad, real* wtGrad = weightGrad->getData(); const int* starts = sequence.getData(); - hl_context_projection_backward_weight( - outGrad, starts, wtGrad, numSequences, weightDim, totalPad, contextLength, - contextStart, beginPad); + hl_context_projection_backward_weight(outGrad, starts, wtGrad, numSequences, + weightDim, totalPad, contextLength, + contextStart, beginPad); } void GpuMatrix::paramReluForward(Matrix& data, Matrix& W) { @@ -1147,8 +1126,7 @@ void GpuMatrix::paramReluForward(Matrix& data, Matrix& W) { size_t numSamples = data.getHeight(); size_t partial_sum = numElements / (W.getHeight() * W.getWidth()); real* output = getData(); - hl_param_relu_forward(output, input, w, numElements, numSamples, - partial_sum); + hl_param_relu_forward(output, input, w, numElements, numSamples, partial_sum); } void GpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) { @@ -1160,8 +1138,8 @@ void GpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) { size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); size_t partial_sum = numElements / (this->getHeight() * this->getWidth()); - hl_param_relu_backward_w(wgrad, ograd, input, - numElements, numSamples, partial_sum); + hl_param_relu_backward_w(wgrad, ograd, input, numElements, numSamples, + partial_sum); } void GpuMatrix::paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) { @@ -1172,8 +1150,8 @@ void GpuMatrix::paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) { size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); size_t partial_sum = numElements / (W.getHeight() * W.getWidth()); - hl_param_relu_backward_diff(ograd, input, w, diff, - numElements, numSamples, partial_sum); + hl_param_relu_backward_diff(ograd, input, w, diff, numElements, numSamples, + partial_sum); } void GpuMatrix::addColumnVector(const Matrix& b) { @@ -1422,8 +1400,8 @@ void CpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) { void CpuMatrix::convExpand(Matrix& feature, int feaImgHeight, int feaImgWidth, int channels, int blockH, int blockW, int strideH, - int strideW, int paddingH, int paddingW, - int outputH, int outputW) { + int strideW, int paddingH, int paddingW, int outputH, + int outputW) { CHECK(feature.useGpu_ == false) << "Matrix type are not equal"; CHECK_EQ(size_t(feaImgHeight * feaImgWidth * channels), @@ -1463,8 +1441,8 @@ void CpuMatrix::convExpand(Matrix& feature, int feaImgHeight, int feaImgWidth, void CpuMatrix::convShrink(Matrix& expandFeat, int thisImgHeight, int thisImgWidth, int channels, int blockH, int blockW, int strideH, int strideW, int paddingH, - int paddingW, int outputH, int outputW, - real alpha, real beta) { + int paddingW, int outputH, int outputW, real alpha, + real beta) { CHECK(expandFeat.useGpu_ == false) << "Matrix type are not equal"; CHECK_EQ(size_t(thisImgHeight * thisImgWidth * channels), getHeight() * getWidth()) @@ -1501,11 +1479,10 @@ void CpuMatrix::convShrink(Matrix& expandFeat, int thisImgHeight, } void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, - size_t imgSizeW, size_t channels, - size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, - size_t paddingH, size_t paddingW) { + size_t imgSizeW, size_t channels, size_t sizeX, + size_t sizeY, size_t strideH, size_t strideW, + size_t outputH, size_t outputW, size_t paddingH, + size_t paddingW) { real* inputData = inputMat.getData(); real* outData = data_; size_t num = inputMat.getHeight(); @@ -1513,15 +1490,20 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t inHeight = imgSizeH; CHECK(inHeight * inWidth == inputMat.getWidth() / channels); CHECK_EQ(num, this->getHeight()); - CHECK_EQ(channels*outputH*outputW, this->getWidth()); + CHECK_EQ(channels * outputH * outputW, this->getWidth()); /* initialize the data_ */ - for (size_t i = 0; i < height_ * width_; i++) { - outData[i] = -(real)FLT_MAX; + for (size_t i = 0; i < height_; i++) { + for (size_t j = 0; j < width_; j++) { + outData[i * getStride() + j] = -(real)FLT_MAX; + } } /* pool max one by one */ - for (size_t n = 0; n < num; ++n) { // frame by frame + for (size_t n = 0; n < num; ++n) { // frame by frame + if (!isContiguous()) { + outData = data_ + n * getStride(); + } for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t ph = 0; ph < outputH; ++ph) { for (size_t pw = 0; pw < outputW; ++pw) { @@ -1564,6 +1546,10 @@ void CpuMatrix::maxPoolBackward(Matrix& image, size_t imgSizeH, size_t imgSizeW, real* otData = outV.getData(); real* otGrad = outGrad.getData(); for (size_t n = 0; n < num; ++n) { + if (!outV.isContiguous()) { + otData = outV.getData() + n * outV.getStride(); + otGrad = outGrad.getData() + n * outGrad.getStride(); + } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { for (size_t pw = 0; pw < outputW; ++pw) { @@ -1594,9 +1580,9 @@ void CpuMatrix::maxPoolBackward(Matrix& image, size_t imgSizeH, size_t imgSizeW, void CpuMatrix::avgPoolForward(Matrix& input, size_t imgSizeH, size_t imgSizeW, size_t channels, size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, - size_t paddingH, size_t paddingW) { + size_t strideH, size_t strideW, size_t outputH, + size_t outputW, size_t paddingH, + size_t paddingW) { // The main loop size_t num = input.getHeight(); size_t inHeight = imgSizeH; @@ -1607,6 +1593,9 @@ void CpuMatrix::avgPoolForward(Matrix& input, size_t imgSizeH, size_t imgSizeW, real* inData = input.getData(); for (size_t n = 0; n < num; ++n) { + if (!isContiguous()) { + tgtData = data_ + n * getStride(); + } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { for (size_t pw = 0; pw < outputW; ++pw) { @@ -1638,9 +1627,8 @@ void CpuMatrix::avgPoolForward(Matrix& input, size_t imgSizeH, size_t imgSizeW, } void CpuMatrix::avgPoolBackward(Matrix& input, size_t imgSizeH, size_t imgSizeW, - size_t sizeX, size_t sizeY, - size_t strideH, size_t strideW, - size_t outputH, size_t outputW, + size_t sizeX, size_t sizeY, size_t strideH, + size_t strideW, size_t outputH, size_t outputW, real scaleTargets, real scaleOutput, size_t paddingH, size_t paddingW) { size_t num = input.getHeight(); @@ -1650,6 +1638,9 @@ void CpuMatrix::avgPoolBackward(Matrix& input, size_t imgSizeH, size_t imgSizeW, real* outData = getData(); for (size_t n = 0; n < num; ++n) { + if (!input.isContiguous()) { + inData = input.getData() + n * input.getStride(); + } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { for (size_t pw = 0; pw < outputW; ++pw) { @@ -1752,8 +1743,7 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, Matrix& denoms, * Output: output size is the number of input sequences (NOT input instances). * output[i] is set to max_{for each instance in this sequence}{input[i]} */ -void CpuMatrix::maxSequenceForward(Matrix& input, - const IVector& sequence, +void CpuMatrix::maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { CHECK(dynamic_cast(&input)); CHECK(dynamic_cast(&sequence)); @@ -1794,8 +1784,7 @@ void CpuMatrix::maxSequenceForward(Matrix& input, } } -void CpuMatrix::maxSequenceBackward(Matrix& outputGrad, - const IVector& sequence, +void CpuMatrix::maxSequenceBackward(Matrix& outputGrad, const IVector& sequence, IVector& index) { CHECK(dynamic_cast(&outputGrad)); CHECK(dynamic_cast(&sequence)); @@ -2000,8 +1989,7 @@ void CpuMatrix::collectBias(Matrix& a, real scale) { } } -void CpuMatrix::sequenceAvgForward(Matrix& a, - const IVector& startsPos, +void CpuMatrix::sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { size_t height = getHeight(); size_t width = getWidth(); @@ -2592,7 +2580,7 @@ void SharedCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, blockSeq.push_back(k); } std::shuffle(blockSeq.begin(), blockSeq.end(), - ThreadLocalRandomEngine::get()); + ThreadLocalRandomEngine::get()); } std::vector& localBufRows = *localBufRows_; int* cols = a->getCols(); @@ -2823,7 +2811,7 @@ void CpuMatrix::maxoutForward(Matrix& a, IVector& id, size_t channels, size_t size = getWidth(); size_t batchSize = getHeight(); size_t featLen = size / channels; - const real* input = a.getData(); + const real* input = a.getData(); int* idForCpu = id.getData(); MatrixPtr maxInMat, maxOutMat; @@ -2857,8 +2845,8 @@ void CpuMatrix::maxoutBackward(Matrix& a, IVector& id, size_t channels, size_t batchSize = getHeight(); size_t featLen = size / channels; size_t newFeatLen = groups * featLen; - real* inputG = getData(); - const real* outG = a.getData(); + real* inputG = getData(); + const real* outG = a.getData(); int* idForCpu = id.getData(); for (size_t batch_idx = 0; batch_idx < batchSize; ++batch_idx) { @@ -3082,9 +3070,9 @@ void CpuMatrix::sequenceSoftmax(Matrix& output, const IVector& index) { CHECK(isContiguous()); MatrixPtr inTmp = Matrix::create(nullptr, /* height= */ 1, 1, - /* trans= */ false, false); + /* trans= */ false, false); MatrixPtr outTmp = Matrix::create(nullptr, /* height= */ 1, 1, - /* trans= */ false, false); + /* trans= */ false, false); size_t numSequences = index.getSize() - 1; auto starts = index.getData(); for (size_t i = 0; i < numSequences; ++i) { diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4 index 70c1f8d563238c2033b1992ec23ad5f73684ecbb..5dac2f82041905cd8f460a1ffb2356ee4941d5e9 100644 --- a/proto/ModelConfig.proto.m4 +++ b/proto/ModelConfig.proto.m4 @@ -120,6 +120,14 @@ message PoolConfig { optional uint32 padding_y = 13 [default = 0]; } +message SppConfig { + required string pool_type = 1; + required uint32 pyramid_height = 2; + required uint32 channels = 3; + required uint32 img_size = 4; + optional uint32 img_size_y = 5; +} + message NormConfig { // rnorm or cmrnorm required string norm_type = 1; @@ -194,6 +202,9 @@ message ProjectionConfig { optional ConvConfig conv_conf = 8; optional int32 num_filters = 9; + // For pool + optional PoolConfig pool_conf = 10; + // For IdentityOffsetProjection optional uint64 offset = 11 [default = 0]; } @@ -235,6 +246,7 @@ message LayerInputConfig { // Set the argument name. optional string input_layer_argument = 9; optional MaxOutConfig maxout_conf = 10; + optional SppConfig spp_conf = 11; } message LayerConfig {