From 6da00da7b5b2449d6668a84728708e43ec030433 Mon Sep 17 00:00:00 2001
From: xzl
Date: Mon, 20 Nov 2017 11:58:42 +0800
Subject: [PATCH] code format check

---
 paddle/cuda/include/hl_cnn.h                |  34 +--
 paddle/cuda/include/stub/hl_cnn_stub.h      |  36 +--
 paddle/gserver/layers/UpsampleLayer.cpp     |   1 +
 paddle/gserver/layers/UpsampleLayer.h       |   1 -
 paddle/math/Matrix.cpp                      | 220 +++++++++---------
 paddle/math/Matrix.h                        |  72 +++---
 python/paddle/trainer/config_parser.py      |   8 +-
 .../paddle/trainer_config_helpers/layers.py |   2 +
 8 files changed, 192 insertions(+), 182 deletions(-)

diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h
index c8dd3eb91e1..8d0822471b6 100644
--- a/paddle/cuda/include/hl_cnn.h
+++ b/paddle/cuda/include/hl_cnn.h
@@ -378,14 +378,15 @@ extern void hl_maxout_backward(real* inGrad,
  * @param[in] outputW the output width.
  * @param[out] outputData output data.
  */
-extern void hl_upsample_forward(real *inputData, real *maskData,
-                                size_t batchSize,
-                                size_t imgSizeH,
-                                size_t imgSizeW,
-                                size_t channels,
-                                size_t outputH,
-                                size_t outputW,
-                                real *outputData);
+extern void hl_upsample_forward(real* inputData,
+                                real* maskData,
+                                size_t batchSize,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW,
+                                real* outputData);
 
 /**
  * @brief Upsample backward.
@@ -399,13 +400,14 @@ extern void hl_upsample_forward(real *inputData, real *maskData,
  * @param[in] outputW the output width.
  * @param[out] inputGradData the input grad data.
  */
-extern void hl_upsample_backward(real *outputGradData, real *maskData,
-                                 size_t batchSize,
-                                 size_t imgSizeH,
-                                 size_t imgSizeW,
-                                 size_t channels,
-                                 size_t outputH,
-                                 size_t outputW,
-                                 real *inputGradData);
+extern void hl_upsample_backward(real* outputGradData,
+                                 real* maskData,
+                                 size_t batchSize,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW,
+                                 real* inputGradData);
 
 #endif  // HL_CNN_H_
diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h
index ef1f67980eb..e83db71bb7f 100644
--- a/paddle/cuda/include/stub/hl_cnn_stub.h
+++ b/paddle/cuda/include/stub/hl_cnn_stub.h
@@ -222,22 +222,24 @@ inline void hl_maxout_backward(real* inGrad,
                                size_t featLen,
                                size_t group) {}
 
-inline void hl_upsample_forward(real *inputData, real *maskData,
-                                size_t batchSize,
-                                size_t imgSizeH,
-                                size_t imgSizeW,
-                                size_t channels,
-                                size_t outputH,
-                                size_t outputW,
-                                real *outputData) {}
-
-inline void hl_upsample_backward(real *outputGradData, real *maskData,
-                                 size_t batchSize,
-                                 size_t imgSizeH,
-                                 size_t imgSizeW,
-                                 size_t channels,
-                                 size_t outputH,
-                                 size_t outputW,
-                                 real *inputGradData) {}
+inline void hl_upsample_forward(real* inputData,
+                                real* maskData,
+                                size_t batchSize,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW,
+                                real* outputData) {}
+
+inline void hl_upsample_backward(real* outputGradData,
+                                 real* maskData,
+                                 size_t batchSize,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW,
+                                 real* inputGradData) {}
 
 #endif  // HL_CNN_STUB_H_
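A note on the kernel contract declared above: inputData, maskData, and outputData are flat row-major buffers of shape (batchSize, channels * height * width), and each mask entry stores the flattened position, within its channel's outputH x outputW plane, that the corresponding input element is written to. A minimal NumPy sketch of the forward contract; the helper name and array layout are illustrative assumptions, not code from this patch:

import numpy as np

# Illustrative reference for hl_upsample_forward (names assumed):
# mask_data[i] is the flat index inside the out_h*out_w plane that
# input element i scatters to.
def upsample_forward_ref(input_data, mask_data, channels,
                         img_h, img_w, out_h, out_w):
    batch = input_data.shape[0]
    in_len, out_len = img_h * img_w, out_h * out_w
    out = np.zeros((batch, channels * out_len), dtype=input_data.dtype)
    for k in range(batch):
        for c in range(channels):
            src = input_data[k, c * in_len:(c + 1) * in_len]
            idx = mask_data[k, c * in_len:(c + 1) * in_len].astype(int)
            out[k, c * out_len + idx] = src  # scatter by mask index
    return out
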
diff --git a/paddle/gserver/layers/UpsampleLayer.cpp b/paddle/gserver/layers/UpsampleLayer.cpp
index 300bb82d688..3ff5332e640 100644
--- a/paddle/gserver/layers/UpsampleLayer.cpp
+++ b/paddle/gserver/layers/UpsampleLayer.cpp
@@ -30,6 +30,7 @@ size_t UpsampleLayer::getOutputSize() {
 bool UpsampleLayer::init(const LayerMap& layerMap,
                          const ParameterMap& parameterMap) {
   Layer::init(layerMap, parameterMap);
+
   CHECK_EQ(inputLayers_.size(), 2U);
   CHECK_EQ(config_.inputs_size(), 2);
   const auto& conf = config_.inputs(0).upsample_conf();
diff --git a/paddle/gserver/layers/UpsampleLayer.h b/paddle/gserver/layers/UpsampleLayer.h
index 2ae9363439e..25efbac5e9e 100644
--- a/paddle/gserver/layers/UpsampleLayer.h
+++ b/paddle/gserver/layers/UpsampleLayer.h
@@ -32,7 +32,6 @@ namespace paddle {
 class UpsampleLayer : public Layer {
 public:
   explicit UpsampleLayer(const LayerConfig& config) : Layer(config) {}
-
   ~UpsampleLayer() {}
 
   bool init(const LayerMap& layerMap,
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 1f6458a2880..ad9a73a2bf8 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1024,61 +1024,63 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) {
 }
 
 void GpuMatrix::upsampleForward(Matrix& input,
-                                    Matrix& mask,
-                                    size_t imgSizeH,
-                                    size_t imgSizeW,
-                                    size_t channels,
-                                    size_t outputH,
-                                    size_t outputW) {
-    CHECK(input.useGpu_ == true) << "Matrix type are not equal";
-    CHECK(mask.useGpu_ == true) << "Matrix type are not equal";
-
-    real *inputData = input.getData();
-    real *maskData = mask.getData();
-    real *outData = data_;
-
-    size_t batch = input.getHeight();
-
-    CHECK(imgSizeH * imgSizeW * channels == input.getWidth());
-    CHECK(imgSizeH * imgSizeW * channels == mask.getWidth());
-    CHECK_EQ(batch, this->getHeight());
-    CHECK(width_ == outputH * outputW * channels);
-    hl_upsample_forward(inputData, maskData,
-                        batch,
-                        imgSizeH,
-                        imgSizeW,
-                        channels,
-                        outputH,
-                        outputW,
-                        outData);
+                                Matrix& mask,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW) {
+  CHECK(input.useGpu_ == true) << "Matrix type are not equal";
+  CHECK(mask.useGpu_ == true) << "Matrix type are not equal";
+
+  real* inputData = input.getData();
+  real* maskData = mask.getData();
+  real* outData = data_;
+
+  size_t batch = input.getHeight();
+
+  CHECK(imgSizeH * imgSizeW * channels == input.getWidth());
+  CHECK(imgSizeH * imgSizeW * channels == mask.getWidth());
+  CHECK_EQ(batch, this->getHeight());
+  CHECK(width_ == outputH * outputW * channels);
+  hl_upsample_forward(inputData,
+                      maskData,
+                      batch,
+                      imgSizeH,
+                      imgSizeW,
+                      channels,
+                      outputH,
+                      outputW,
+                      outData);
 }
 
 void GpuMatrix::upsampleBackward(Matrix& outputGrad,
-                                     Matrix& mask,
-                                     size_t imgSizeH,
-                                     size_t imgSizeW,
-                                     size_t channels,
-                                     size_t outputH,
-                                     size_t outputW) {
-    CHECK(outputGrad.useGpu_ == true) << "Matrix type are not equal";
-    CHECK(mask.useGpu_ == true) << "Matrix type are not equal";
-
-    real *outputGradData = outputGrad.getData();
-    real *maskData = mask.getData();
-    real *inputGradData = data_;
-    size_t batch = outputGrad.getHeight();
-
-    CHECK(imgSizeH * imgSizeW == this->getWidth()/channels);
-    CHECK_EQ(batch, this->getHeight());
-    CHECK_EQ(channels * outputH * outputW, outputGrad.getWidth());
-    hl_upsample_backward(outputGradData, maskData,
-                         batch,
-                         imgSizeH,
-                         imgSizeW,
-                         channels,
-                         outputH,
-                         outputW,
-                         inputGradData);
+                                 Matrix& mask,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW) {
+  CHECK(outputGrad.useGpu_ == true) << "Matrix type are not equal";
+  CHECK(mask.useGpu_ == true) << "Matrix type are not equal";
+
+  real* outputGradData = outputGrad.getData();
+  real* maskData = mask.getData();
+  real* inputGradData = data_;
+  size_t batch = outputGrad.getHeight();
+
+  CHECK(imgSizeH * imgSizeW == this->getWidth() / channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outputH * outputW, outputGrad.getWidth());
+  hl_upsample_backward(outputGradData,
+                       maskData,
+                       batch,
+                       imgSizeH,
+                       imgSizeW,
+                       channels,
+                       outputH,
+                       outputW,
+                       inputGradData);
 }
 
 void GpuMatrix::maxPoolForward(Matrix& inputMat,
@@ -2040,71 +2042,69 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
 }
 
 void CpuMatrix::upsampleForward(Matrix& input,
-                                    Matrix& mask,
-                                    size_t imgSizeH,
-                                    size_t imgSizeW,
-                                    size_t channels,
-                                    size_t outputH,
-                                    size_t outputW) {
-    real *inputData = input.getData();
-    real *maskData = mask.getData();
-    real *outData = data_;
-    size_t inLength = imgSizeH * imgSizeW;
-    size_t outLength = outputH * outputW;
-    size_t batch = input.getHeight();
-    CHECK(inLength == input.getWidth() / channels);
-    CHECK_EQ(batch, this->getHeight());
-    CHECK_EQ(channels * outLength, this->getWidth());
-
-    for (size_t k = 0; k < batch; k++) {
-        for (size_t c = 0; c < channels; c++) {
-            for (size_t i = 0; i < inLength; i++) {
-                size_t out_index = static_cast<int>(maskData[i]);
-                if (out_index >= outLength) {
-                    LOG(FATAL) << "upsample index " << out_index
-                               << " out of range.";
-                }
-                outData[out_index] = inputData[i];
-            }
-            inputData += inLength;
-            maskData += inLength;
-            outData += outLength;
-        }
-    }
+                                Matrix& mask,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW) {
+  real* inputData = input.getData();
+  real* maskData = mask.getData();
+  real* outData = data_;
+  size_t inLength = imgSizeH * imgSizeW;
+  size_t outLength = outputH * outputW;
+  size_t batch = input.getHeight();
+  CHECK(inLength == input.getWidth() / channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outLength, this->getWidth());
+
+  for (size_t k = 0; k < batch; k++) {
+    for (size_t c = 0; c < channels; c++) {
+      for (size_t i = 0; i < inLength; i++) {
+        size_t out_index = static_cast<int>(maskData[i]);
+        if (out_index >= outLength) {
+          LOG(FATAL) << "upsample index " << out_index << " out of range.";
+        }
+        outData[out_index] = inputData[i];
+      }
+      inputData += inLength;
+      maskData += inLength;
+      outData += outLength;
+    }
+  }
 }
 
 void CpuMatrix::upsampleBackward(Matrix& outputGrad,
-                                     Matrix& mask,
-                                     size_t imgSizeH,
-                                     size_t imgSizeW,
-                                     size_t channels,
-                                     size_t outputH,
-                                     size_t outputW) {
-    real *outputGradData = outputGrad.getData();
-    real *maskData = mask.getData();
-    real *inputGradData = data_;
-    size_t inLength = imgSizeH * imgSizeW;
-    size_t outLength = outputH * outputW;
-    size_t batch = outputGrad.getHeight();
-    CHECK(inLength == this->getWidth()/channels);
-    CHECK_EQ(batch, this->getHeight());
-    CHECK_EQ(channels * outLength, outputGrad.getWidth());
-
-    for (size_t k = 0; k < batch; k++) {
-        for (size_t c = 0; c < channels; c++) {
-            for (size_t i = 0; i < inLength; i++) {
-                size_t out_index = static_cast<int>(maskData[i]);
-                if (out_index >= outLength) {
-                    LOG(FATAL) << "upsample index " << out_index
-                               << " out of range.";
-                }
-                inputGradData[i] = outputGradData[out_index];
-            }
-            inputGradData += inLength;
-            maskData += inLength;
-            outputGradData += outLength;
-        }
-    }
+                                 Matrix& mask,
+                                 size_t imgSizeH,
+                                 size_t imgSizeW,
+                                 size_t channels,
+                                 size_t outputH,
+                                 size_t outputW) {
+  real* outputGradData = outputGrad.getData();
+  real* maskData = mask.getData();
+  real* inputGradData = data_;
+  size_t inLength = imgSizeH * imgSizeW;
+  size_t outLength = outputH * outputW;
+  size_t batch = outputGrad.getHeight();
+  CHECK(inLength == this->getWidth() / channels);
+  CHECK_EQ(batch, this->getHeight());
+  CHECK_EQ(channels * outLength, outputGrad.getWidth());
+
+  for (size_t k = 0; k < batch; k++) {
+    for (size_t c = 0; c < channels; c++) {
+      for (size_t i = 0; i < inLength; i++) {
+        size_t out_index = static_cast<int>(maskData[i]);
+        if (out_index >= outLength) {
+          LOG(FATAL) << "upsample index " << out_index << " out of range.";
+        }
+        inputGradData[i] = outputGradData[out_index];
+      }
+      inputGradData += inLength;
+      maskData += inLength;
+      outputGradData += outLength;
+    }
+  }
 }
 
 void CpuMatrix::maxPoolForward(Matrix& inputMat,
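The CPU loops above make the technique explicit: the forward pass scatters each pooled value to the output slot named by its mask entry, and the backward pass gathers gradients back through the same indices, so only the argmax positions receive gradient. A matching NumPy sketch of the backward routing, mirroring the CpuMatrix::upsampleBackward loops (the helper name and layout are assumptions, not code from this patch):

import numpy as np

# Illustrative mirror of CpuMatrix::upsampleBackward: the gradient is
# gathered from the position each value was scattered to in the forward pass.
def upsample_backward_ref(out_grad, mask_data, channels,
                          img_h, img_w, out_h, out_w):
    batch = out_grad.shape[0]
    in_len, out_len = img_h * img_w, out_h * out_w
    in_grad = np.zeros((batch, channels * in_len), dtype=out_grad.dtype)
    for k in range(batch):
        for c in range(channels):
            idx = mask_data[k, c * in_len:(c + 1) * in_len].astype(int)
            # gather by mask index
            in_grad[k, c * in_len:(c + 1) * in_len] = \
                out_grad[k, c * out_len + idx]
    return in_grad
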
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index b4fcf05cb26..6e9ea04d669 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -860,22 +860,22 @@ public:
   }
 
   virtual void upsampleForward(Matrix& input,
-                           Matrix& mask,
-                           size_t imgSizeH,
-                           size_t imgSizeW,
-                           size_t channels,
-                           size_t outputH,
-                           size_t outputW) {
+                               Matrix& mask,
+                               size_t imgSizeH,
+                               size_t imgSizeW,
+                               size_t channels,
+                               size_t outputH,
+                               size_t outputW) {
     LOG(FATAL) << "Not implemented";
   }
 
   virtual void upsampleBackward(Matrix& outputGrad,
-                            Matrix& mask,
-                            size_t imgSizeH,
-                            size_t imgSizeW,
-                            size_t channels,
-                            size_t outputH,
-                            size_t outputW) {
+                                Matrix& mask,
+                                size_t imgSizeH,
+                                size_t imgSizeW,
+                                size_t channels,
+                                size_t outputH,
+                                size_t outputW) {
     LOG(FATAL) << "Not implemented";
   }
 
@@ -1438,20 +1438,20 @@ public:
   void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
 
   void upsampleForward(Matrix& input,
-                   Matrix& mask,
-                   size_t imgSizeH,
-                   size_t imgSizeW,
-                   size_t channels,
-                   size_t outputH,
-                   size_t outputW);
+                       Matrix& mask,
+                       size_t imgSizeH,
+                       size_t imgSizeW,
+                       size_t channels,
+                       size_t outputH,
+                       size_t outputW);
 
   void upsampleBackward(Matrix& outputGrad,
-                    Matrix& mask,
-                    size_t imgSizeH,
-                    size_t imgSizeW,
-                    size_t channels,
-                    size_t outputH,
-                    size_t outputW);
+                        Matrix& mask,
+                        size_t imgSizeH,
+                        size_t imgSizeW,
+                        size_t channels,
+                        size_t outputH,
+                        size_t outputW);
 
   void maxPoolForward(Matrix& inputMat,
                       size_t imgSizeH,
@@ -1726,20 +1726,20 @@ public:
   MatrixPtr clone(size_t height, size_t width, bool useGpu = false);
 
   void upsampleForward(Matrix& input,
-                   Matrix& mask,
-                   size_t imgSizeH,
-                   size_t imgSizeW,
-                   size_t channels,
-                   size_t outputH,
-                   size_t outputW);
+                       Matrix& mask,
+                       size_t imgSizeH,
+                       size_t imgSizeW,
+                       size_t channels,
+                       size_t outputH,
+                       size_t outputW);
 
   void upsampleBackward(Matrix& outputGrad,
-                    Matrix& mask,
-                    size_t imgSizeH,
-                    size_t imgSizeW,
-                    size_t channels,
-                    size_t outputH,
-                    size_t outputW);
+                        Matrix& mask,
+                        size_t imgSizeH,
+                        size_t imgSizeW,
+                        size_t channels,
+                        size_t outputH,
+                        size_t outputW);
 
   void maxPoolForward(Matrix& inputMat,
                       size_t imgSizeH,
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 067ca21d323..7563368ad7a 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -978,12 +978,14 @@ class Pad(Cfg):
     def __init__(self, channels, pad_c, pad_h, pad_w):
         self.add_keys(locals())
 
+
 @config_class
 class Upsample(Cfg):
     def __init__(self, scale, scale_y, pad_out_x, pad_out_y, upsample_size,
                  upsample_size_y):
         self.add_keys(locals())
 
+
 @config_class
 class Norm(Cfg):
     def __init__(self,
@@ -2393,6 +2395,7 @@ class SpatialPyramidPoolLayer(LayerBase):
         output_x = (pow(4, spp_conf.pyramid_height) - 1) / (4 - 1)
         self.set_cnn_layer(name, 1, output_x, spp_conf.image_conf.channels)
 
+
 @config_layer('upsample')
 class UpsampleLayer(LayerBase):
     def __init__(self, name, inputs, **xargs):
@@ -2407,9 +2410,10 @@ class UpsampleLayer(LayerBase):
                                 input_layer.height)
 
         upsample = self.inputs[0].upsample
-        output_x = 0 
+        output_x = 0
         output_y = 0
         output_size = 0
+
         if upsample.scale:
             self.config.inputs[0].upsample_conf.scale = upsample.scale
             self.config.inputs[0].upsample_conf.scale_y = upsample.scale_y
@@ -2427,11 +2431,11 @@ class UpsampleLayer(LayerBase):
 
         output_size = image_conf.channels * output_x * output_y
 
-        self.set_layer_height_width(output_y, output_x)
         self.set_layer_depth(input_layer.depth)
         self.set_layer_size(output_size)
 
+
 @config_layer('pad')
 class PadLayer(LayerBase):
     def __init__(self, name, inputs, **xargs):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 95369000bbd..1ce603389dc 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2881,6 +2881,7 @@ def img_pool3d_layer(input,
         num_filters=num_channels,
         size=l.config.size)
 
+
 @wrap_name_default("upsample")
 @layer_support()
 def upsample_layer(input,
@@ -2930,6 +2931,7 @@ def upsample_layer(input,
         'scale or upsample_size, there must be one to be designated'
 
     assert len(input) == 2, 'layer input size must be 2'
+
     assert input[1].layer_type == LayerType.POOL_LAYER, \
         'the second input should be the MaxPoolWithMaskLayer'
-- 
GitLab
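For context, a hypothetical trainer-config sketch of SegNet-style unpooling wired through this API, inferred from the asserts above (two inputs, the second being the mask-producing pool layer, and exactly one of scale/upsample_size set). The layer names, sizes, and the MaxWithMaskPooling pool type are assumptions for illustration, not part of this patch:

from paddle.trainer_config_helpers import *

# Hypothetical usage sketch: unpool with the argmax mask recorded by a
# max-pool-with-mask layer, SegNet style.
data = data_layer(name="image", size=1 * 32 * 32)
conv = img_conv_layer(input=data, filter_size=3, num_channels=1,
                      num_filters=8, padding=1, act=ReluActivation())
# Assumed: MaxWithMaskPooling makes the pool layer emit argmax indices.
pool = img_pool_layer(input=conv, pool_size=2, stride=2,
                      pool_type=MaxWithMaskPooling())
decode = img_conv_layer(input=pool, filter_size=3, num_filters=8,
                        padding=1, act=ReluActivation())
# Two inputs: the feature map to enlarge and the mask-producing pool layer;
# scale=2 doubles height and width (exactly one of scale/upsample_size).
up = upsample_layer(input=[decode, pool], scale=2)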