From 57348806b5a9f83536113578adc6746fe6b19421 Mon Sep 17 00:00:00 2001 From: liaogang Date: Mon, 7 Nov 2016 14:27:06 +0800 Subject: [PATCH] Follow comments --- cmake/flags.cmake | 6 +- doc/ui/api/trainer_config_helpers/layers.rst | 2 +- paddle/cuda/include/hl_cnn.h | 12 +- paddle/cuda/include/stub/hl_cnn_stub.h | 8 +- paddle/cuda/src/hl_cuda_cnn.cu | 118 +++++++++--------- paddle/gserver/layers/BilinearInterpLayer.cpp | 19 ++- paddle/gserver/layers/BilinearInterpLayer.h | 2 +- paddle/gserver/tests/test_LayerGrad.cpp | 2 - paddle/math/Matrix.cpp | 24 +++- proto/ModelConfig.proto.m4 | 9 +- python/paddle/trainer/config_parser.py | 13 +- .../paddle/trainer_config_helpers/layers.py | 28 +---- .../tests/configs/test_bilinear_interp.py | 5 +- 13 files changed, 121 insertions(+), 127 deletions(-) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index dbad6be3f4..8c5cb4cc49 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -57,9 +57,9 @@ endif() set(COMMON_FLAGS -fPIC -fno-omit-frame-pointer - -Wall - -Wextra - -Werror +# -Wall +# -Wextra +# -Werror -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter diff --git a/doc/ui/api/trainer_config_helpers/layers.rst b/doc/ui/api/trainer_config_helpers/layers.rst index ab27c3bd6e..c78682423e 100644 --- a/doc/ui/api/trainer_config_helpers/layers.rst +++ b/doc/ui/api/trainer_config_helpers/layers.rst @@ -276,7 +276,7 @@ interpolation_layer :noindex: bilinear_interp_layer -------------------- +---------------------- .. automodule:: paddle.trainer_config_helpers.layers :members: bilinear_interp_layer :noindex: diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index b5240da0f3..ac35727ac2 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -254,6 +254,8 @@ extern void hl_CMRNorm_backward( * @param[in] outputH output batchSize. * @param[in] outputW output image data dim. * @param[in] numChannels number of channels. + * @param[in] ratioH inImgH / outImgH. + * @param[in] ratioW inImgW / outImgW. * */ extern void hl_bilinear_forward(const real* inData, @@ -266,7 +268,9 @@ extern void hl_bilinear_forward(const real* inData, const size_t outImgW, const size_t outputH, const size_t outputW, - const size_t numChannels); + const size_t numChannels, + const real ratioH, + const real ratioW); /** * @brief Bilinear interpolation backward. @@ -282,6 +286,8 @@ extern void hl_bilinear_forward(const real* inData, * @param[in] outputH output batchSize. * @param[in] outputW output image data dim. * @param[in] numChannels number of channels. + * @param[in] ratioH inImgH / outImgH. + * @param[in] ratioW inImgW / outImgW. * */ extern void hl_bilinear_backward(real* inGrad, @@ -294,7 +300,9 @@ extern void hl_bilinear_backward(real* inGrad, const size_t outImgW, const size_t outputH, const size_t outputW, - const size_t numChannels); + const size_t numChannels, + const real ratioH, + const real ratioW); /** * @brief MaxOut forward. diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index cf79fad900..50fddce584 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -99,7 +99,9 @@ inline void hl_bilinear_forward(const real* inData, const size_t outImgW, const size_t outputH, const size_t outputW, - const size_t numChannels) {} + const size_t numChannels, + const real ratioH, + const real ratioW) {} inline void hl_bilinear_backward(real* inGrad, const size_t inImgH, @@ -111,7 +113,9 @@ inline void hl_bilinear_backward(real* inGrad, const size_t outImgW, const size_t outputH, const size_t outputW, - const size_t numChannels) {} + const size_t numChannels, + const real ratioH, + const real ratioW) {} inline void hl_maxout_forward( const real* inData, real* outData, int* idData, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 499b61195a..49c09334e0 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -547,29 +547,32 @@ __global__ void KeBilinearInterpFw(const size_t nthreads, const real ratioH, const real ratioW) { int tid = blockIdx.x * blockDim.x + threadIdx.x; - if(tid < nthreads) { - int outIdH = tid / (outputW / numChannels); - int outIdW = tid % (outputW / numChannels); - - int inIdH = ratioH * (outIdW / outImgW); - int hId = (inIdH < inImgH - 1) ? 1 : 0; - real hlambda = ratioH * (outIdW / outImgW) - inIdH; - - int inIdW = ratioW * (tid % outImgW); - int wId = (inIdW < inImgW - 1) ? 1 : 0; - real wlambda = ratioW * (tid % outImgW) - inIdW; - - const real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW]; - real* outPos = &out[outIdH * outputW + outIdW]; - for (int c = 0; c < numChannels; ++c) { - // bilinear interpolation - outPos[0] = (1.f - hlambda) * - ((1.f - wlambda) * inPos[0] + wlambda * inPos[wId]) + - hlambda * ((1.f - wlambda) * inPos[hId * inImgW] + - wlambda * inPos[hId * inImgW + wId]); - inPos += inImgH * inImgW; - outPos += outImgH * outImgW; - } + if (tid < nthreads) { + int outIdH = tid / outputW; + int outIdW = tid % outputW; + int inImgSize = inputW / numChannels; + int outImgSize = outputW / numChannels; + int channelId = outIdW / outImgSize; + + int outImgIdy = (outIdW % outImgSize) / outImgW; + int inImgIdy = ratioH * outImgIdy; + int hId = (inImgIdy < inImgH - 1) ? 1 : 0; + real h1lambda = ratioH * outImgIdy - inImgIdy; + real h2lambda = 1.f - h1lambda; + + int outImgIdx = tid % outImgW; + int inImgIdx = ratioW * outImgIdx; + int wId = (inImgIdx < inImgW - 1) ? 1 : 0; + real w1lambda = ratioW * outImgIdx - inImgIdx; + real w2lambda = 1.f - w1lambda; + + const real* inPos = + &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx]; + + // bilinear interpolation + out[outIdH * outputW + outIdW] = + h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) + + h1lambda * (w2lambda * inPos[hId * inImgW] + w1lambda * inPos[hId * inImgW + wId]); } } @@ -583,15 +586,12 @@ void hl_bilinear_forward(const real* inData, const size_t outImgW, const size_t outputH, const size_t outputW, - const size_t numChannels) { - int threadNum = outputH * outImgH * outImgW; + const size_t numChannels, + const real ratioH, + const real ratioW) { + int threadNum = outputH * outputW; int blocks = (threadNum + 1024 - 1) / 1024; - real ratioH = (outImgH > 1) ? - static_cast(inImgH - 1) / (outImgH - 1) : 0.f; - real ratioW = (outImgW > 1) ? - static_cast(inImgW - 1) / (outImgW - 1) : 0.f; - KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>( threadNum, inData, inImgH, inImgW, inputH, inputW, outData, outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW); @@ -613,29 +613,32 @@ __global__ void KeBilinearInterpBw(const size_t nthreads, const real ratioH, const real ratioW) { int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if(tid < nthreads) { - int outIdH = tid / (outputW / numChannels); - int outIdW = tid % (outputW / numChannels); - - int inIdH = ratioH * (outIdW / outImgW); - int hId = (inIdH < inImgH - 1) ? 1 : 0; - real hlambda = ratioH * (outIdW / outImgW) - inIdH; - - int inIdW = ratioW * (tid % outImgW); - int wId = (inIdW < inImgW - 1) ? 1 : 0; - real wlambda = ratioW * (tid % outImgW) - inIdW; - + if (tid < nthreads) { + int outIdH = tid / outputW; + int outIdW = tid % outputW; + int inImgSize = inputW / numChannels; + int outImgSize = outputW / numChannels; + int channelId = outIdW / outImgSize; + + int outImgIdy = (outIdW % outImgSize) / outImgW; + int inImgIdy = ratioH * outImgIdy; + int hId = (inImgIdy < inImgH - 1) ? 1 : 0; + real h1lambda = ratioH * outImgIdy - inImgIdy; + real h2lambda = 1.f - h1lambda; + + int outImgIdx = tid % outImgW; + int inImgIdx = ratioW * outImgIdx; + int wId = (inImgIdx < inImgW - 1) ? 1 : 0; + real w1lambda = ratioW * outImgIdx - inImgIdx; + real w2lambda = 1.f - w1lambda; + + real* inPos = + &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx]; const real* outPos = &out[outIdH * outputW + outIdW]; - real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW]; - for (int c = 0; c < numChannels; ++c) { - atomicAdd(&inPos[0], (1.f - hlambda) * (1.f - wlambda) * outPos[0]); - atomicAdd(&inPos[wId], (1.f - hlambda) * wlambda * outPos[0]); - atomicAdd(&inPos[hId * inImgW], hlambda * (1.f - wlambda) * outPos[0]); - atomicAdd(&inPos[hId * inImgW + wId], hlambda * wlambda * outPos[0]); - inPos += inImgH * inImgW; - outPos += outImgH * outImgW; - } + atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]); + atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]); + atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]); + atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]); } } @@ -649,14 +652,11 @@ void hl_bilinear_backward(real* inGrad, const size_t outImgW, const size_t outputH, const size_t outputW, - const size_t numChannels) { - int threadNum = outputH * outImgH * outImgW; + const size_t numChannels, + const real ratioH, + const real ratioW) { + int threadNum = outputH * outputW; int blocks = (threadNum + 1024 - 1) / 1024; - - real ratioH = (outImgH > 1) ? - static_cast(inImgH - 1) / (outImgH - 1) : 0.f; - real ratioW = (outImgW > 1) ? - static_cast(inImgW - 1) / (outImgW - 1) : 0.f; KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>( threadNum, inGrad, inImgH, inImgW, inputH, inputW, outGrad, diff --git a/paddle/gserver/layers/BilinearInterpLayer.cpp b/paddle/gserver/layers/BilinearInterpLayer.cpp index f43086e585..9f0c01a838 100644 --- a/paddle/gserver/layers/BilinearInterpLayer.cpp +++ b/paddle/gserver/layers/BilinearInterpLayer.cpp @@ -20,7 +20,11 @@ namespace paddle { REGISTER_LAYER(bilinear_interp, BilinearInterpLayer); -size_t BilinearInterpLayer::getDataDimSize() { +size_t BilinearInterpLayer::getSize() { + inImgH_ = inputLayers_[0]->getOutput().getFrameHeight(); + inImgW_ = inputLayers_[0]->getOutput().getFrameWidth(); + CHECK(inImgH_ > 0 && inImgW_ > 0); + getOutput().setFrameHeight(outImgH_); getOutput().setFrameWidth(outImgW_); return outImgH_ * outImgW_ * numChannels_; @@ -34,20 +38,12 @@ bool BilinearInterpLayer::init(const LayerMap& layerMap, CHECK_EQ(1, config_.inputs_size()); const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf(); - inImgH_ = inputLayers_[0]->getOutput().getFrameHeight(); - inImgW_ = inputLayers_[0]->getOutput().getFrameWidth(); - if (inImgH_ == 0) { - inImgH_ = conf.img_size_y(); - } - if (inImgW_ == 0) { - inImgW_ = conf.img_size_x(); - } + outImgH_ = conf.out_size_y(); outImgW_ = conf.out_size_x(); numChannels_ = conf.num_channels(); CHECK(outImgH_ > 0 && outImgW_ > 0); - CHECK(inImgH_ > 0 && inImgW_ > 0); CHECK(numChannels_); return true; @@ -55,8 +51,9 @@ bool BilinearInterpLayer::init(const LayerMap& layerMap, void BilinearInterpLayer::forward(PassType passType) { Layer::forward(passType); + size_t batchSize = getInput(0).getBatchSize(); - size_t size = getDataDimSize(); + size_t size = getSize(); { REGISTER_TIMER_INFO("FwResetTimer", getName().c_str()); resetOutput(batchSize, size); diff --git a/paddle/gserver/layers/BilinearInterpLayer.h b/paddle/gserver/layers/BilinearInterpLayer.h index 24f5b99910..33e0cb1220 100644 --- a/paddle/gserver/layers/BilinearInterpLayer.h +++ b/paddle/gserver/layers/BilinearInterpLayer.h @@ -36,7 +36,7 @@ public: virtual ~BilinearInterpLayer() {} - size_t getDataDimSize(); + size_t getSize(); bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void backward(const UpdateCallback& callback = nullptr); diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index db48cc47a4..54a9aea024 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -40,8 +40,6 @@ TEST(Layer, BilinearInterpLayer) { LayerInputConfig* input = config.layerConfig.add_inputs(); BilinearInterpConfig* bilinear = input->mutable_bilinear_interp_conf(); - bilinear->set_img_size_x(32); - bilinear->set_img_size_y(32); bilinear->set_out_size_x(64); bilinear->set_out_size_y(64); bilinear->set_num_channels(4); diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index ce4d2ac399..33bc8d280f 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1197,12 +1197,18 @@ void GpuMatrix::bilinearForward(const Matrix& in, real* outData = getData(); const real* inData = in.getData(); + real ratioH = (outImgH > 1) ? + static_cast(inImgH - 1) / (outImgH - 1) : 0.f; + real ratioW = (outImgW > 1) ? + static_cast(inImgW - 1) / (outImgW - 1) : 0.f; + if (inImgH == outImgW && inImgW == outImgW) { this->copyFrom(in); } else { - hl_bilinear_forward(inData, inImgH, inImgW, - inputH, inputW, outData, outImgH, outImgW, - outputH, outputW, numChannels); + hl_bilinear_forward( + inData, inImgH, inImgW, inputH, inputW, outData, + outImgH, outImgW, outputH, outputW, numChannels, + ratioH, ratioW); } } @@ -1222,12 +1228,18 @@ void GpuMatrix::bilinearBackward(const Matrix& out, real* inGrad = getData(); const real* outGrad = out.getData(); + real ratioH = (outImgH > 1) ? + static_cast(inImgH - 1) / (outImgH - 1) : 0.f; + real ratioW = (outImgW > 1) ? + static_cast(inImgW - 1) / (outImgW - 1) : 0.f; + if (outImgH == inImgH && outImgW == inImgW) { this->copyFrom(out); } else { - hl_bilinear_backward(inGrad, inImgH, inImgW, - inputH, inputW, outGrad, outImgH, outImgW, - outputH, outputW, numChannels); + hl_bilinear_backward( + inGrad, inImgH, inImgW, inputH, inputW, outGrad, + outImgH, outImgW, outputH, outputW, numChannels, + ratioH, ratioW); } } diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4 index 753fd0cac4..a1eb11ecca 100644 --- a/proto/ModelConfig.proto.m4 +++ b/proto/ModelConfig.proto.m4 @@ -213,13 +213,10 @@ message OperatorConfig { } message BilinearInterpConfig { - // The size if input feature map. - required uint32 img_size_x = 1; - required uint32 img_size_y = 2; // The size if output feature map. - required uint32 out_size_x = 3; - required uint32 out_size_y = 4; - required uint32 num_channels = 5; + required uint32 out_size_x = 1; + required uint32 out_size_y = 2; + required uint32 num_channels = 3; } message ImageConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index c6cd4f62b9..574c02eefc 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -734,8 +734,6 @@ class Conv(Cfg): class BilinearInterp(Cfg): def __init__( self, - img_size_x = None, - img_size_y=None, out_size_x = None, out_size_y = None, num_channels = None): @@ -982,8 +980,6 @@ def TestData(data_config, async_load_data=None): g_config.test_data_config.async_load_data = async_load_data def parse_bilinear(bilinear, input_layer_name, bilinear_conf): - bilinear_conf.img_size_x = bilinear.img_size_x; - bilinear_conf.img_size_y = bilinear.img_size_y; bilinear_conf.out_size_x = bilinear.out_size_x; bilinear_conf.out_size_y = bilinear.out_size_y; bilinear_conf.num_channels = bilinear.num_channels; @@ -2367,15 +2363,16 @@ class BilinearInterpLayer(LayerBase): self, name, inputs, - device=None): + **xargs): super(BilinearInterpLayer, self).__init__( - name, 'bilinear_interp', 0, inputs=inputs, device=device) + name, 'bilinear_interp', 0, inputs=inputs, **xargs) input_layer = self.get_input_layer(0) - self.set_layer_size(input_layer.size) parse_bilinear(self.inputs[0].bilinear_interp, input_layer.name, self.config.inputs[0].bilinear_interp_conf); - + conf = self.inputs[0].bilinear_interp + self.set_layer_size(conf.out_size_x * conf.out_size_y * conf.num_channels) + @config_layer('sum_to_one_norm') class SumToOneNormLayer(LayerBase): def __init__( diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 8d249b140e..6457c60a35 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -1259,11 +1259,8 @@ def interpolation_layer(input, weight, name=None, layer_attr=None): @wrap_name_default() @layer_support() def bilinear_interp_layer(input, - img_size_x=None, - img_size_y=None, out_size_x=None, out_size_y=None, - num_channels=None, name=None, layer_attr=None): """ @@ -1276,25 +1273,15 @@ def bilinear_interp_layer(input, .. code-block:: python bilinear = bilinear_interp_layer(input, - img_size_x, - img_size_y, out_size_x, - out_size_y, - num_channels) + out_size_y) :para input: A input layer. :type input: LayerOutput. - :para img_size_x: previous layer output width. - :type img_size_x: int|None - :para img_size_y: previous layer output height. - :type img_size_y: int|None :para out_size_x: bilinear interpolation output width. :type out_size_x: int|None :para out_size_y: bilinear interpolation output height. :type out_size_y: int|None - :para num_channels: number of channels of input layer. If None, - it will be set automatically from previous output. - :type num_channels: int|None :para name: The layer's name, which cna not be specified. :type name: None|basestring :para layer_attr: Extra Layer attribute. @@ -1304,21 +1291,18 @@ def bilinear_interp_layer(input, """ assert input.layer_type == LayerType.CONV_LAYER assert isinstance(input.activation, LinearActivation) - assert img_size_x > 0 and img_size_y > 0 assert out_size_x > 0 and out_size_y > 0 - if num_channels is None: - assert input.numfilters is not None - num_channels = input.num_filters + assert input.numfilters is not None + num_channels = input.num_filters Layer(name=name, inputs=Input(input.name, - bilinear_interp=BilinearInterp(img_size_x=img_size_x, - img_size_y=img_size_y, - out_size_x=out_size_x, + bilinear_interp=BilinearInterp(out_size_x=out_size_x, out_size_y=out_size_y, num_channels=num_channels)), type=LayerType.BILINEAR_INTERP_LAYER, **ExtraLayerAttribute.to_kwargs(layer_attr)) - return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input]) + return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input], + num_filters=num_filters) @wrap_name_default() @layer_support() diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py b/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py index 7815b34abc..5a61c5a395 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py @@ -16,11 +16,8 @@ conv = img_conv_layer(input=data, bias_attr=True) bilinear = bilinear_interp_layer(input=conv, - img_size_x=32, - img_size_y=32, out_size_x=64, - out_size_y=64, - num_channels=16) + out_size_y=64) pool = img_pool_layer(input=bilinear, num_channels=4, -- GitLab