From 4ebb3eb759903bf95968b578eec99b1364d3bd10 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 11:55:35 +0800 Subject: [PATCH] imporve Function --- paddle/gserver/layers/NormProjectionLayer.cpp | 60 +++++++++++---- paddle/gserver/layers/NormProjectionLayer.h | 4 + paddle/math/Function.cpp | 6 +- paddle/math/Function.h | 14 ++-- paddle/math/cross_map_normal_op.cpp | 75 ++++++++++--------- paddle/math/cross_map_normal_op.h | 13 ++++ paddle/math/cross_map_normal_op_gpu.cu | 46 ++++-------- paddle/math/tests/test_matrixCompare.cpp | 21 +++++- 8 files changed, 147 insertions(+), 92 deletions(-) diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index ea301292e..5dda7ee20 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" +#include "paddle/math/cross_map_normal_op.h" #include "NormProjectionLayer.h" namespace paddle { @@ -45,6 +46,16 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, /* the size of inputs for norm-layer is 1 */ CHECK_EQ(config_.inputs_size(), 1); + if (useGpu_) { + normal_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormal, GPU)); + } else { + normal_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormal, CPU)); + } + normal_->init( + FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); + return true; } @@ -62,10 +73,14 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - denoms_->zeroMem(); - - outV->crossMapNormalFwd( - *input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_); + Dims dims{(size_t)batchSize, + (size_t)channels_, + (size_t)imgSizeH_, + (size_t)imgSizeW_}; + normal_->calc( + {Tensor(input->getData(), dims)}, + {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)}, + {}); } void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { @@ -80,15 +95,32 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - preOutGrad->crossMapNormalBwd(*localGrad, - *denoms_, - *preOutV, - *localOutV, - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); + if (useGpu_) { + CrossMapNormalGrad crossGrad; + crossGrad(dynamic_cast(*preOutGrad), + dynamic_cast(*preOutV), + dynamic_cast(*localGrad), + dynamic_cast(*localOutV), + dynamic_cast(*denoms_), + channels_, + imgSizeH_, + imgSizeW_, + size_, + scale_, + pow_); + } else { + CrossMapNormalGrad crossGrad; + crossGrad(dynamic_cast(*preOutGrad), + dynamic_cast(*preOutV), + dynamic_cast(*localGrad), + dynamic_cast(*localOutV), + dynamic_cast(*denoms_), + channels_, + imgSizeH_, + imgSizeW_, + size_, + scale_, + pow_); + } } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 0db8e2551..ea44669be 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "NormLayer.h" #include "paddle/math/Matrix.h" +#include "paddle/math/Function.h" #include namespace paddle { @@ -39,5 +40,8 @@ public: bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void backward(const UpdateCallback& callback = nullptr); + +protected: + FunctionBase* normal_; }; } // namespace paddle diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp index 21d271917..02880e5ea 100644 --- a/paddle/math/Function.cpp +++ b/paddle/math/Function.cpp @@ -31,15 +31,17 @@ real FuncConfig::get(const std::string& key) const { } template <> -void FuncConfig::set(const std::string& key, size_t v) { +FuncConfig& FuncConfig::set(const std::string& key, size_t v) { CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; valueMap_[key].s = v; + return *this; } template <> -void FuncConfig::set(const std::string& key, real v) { +FuncConfig& FuncConfig::set(const std::string& key, real v) { CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; valueMap_[key].r = v; + return *this; } ClassRegistrar FunctionBase::funcRegistrar_; diff --git a/paddle/math/Function.h b/paddle/math/Function.h index 539759782..f8fab972a 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -46,6 +46,8 @@ class Tensor { public: Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + real* getData() const { return buf_; } + real* buf_; Dims dims_; }; @@ -63,7 +65,7 @@ public: T get(const std::string& key) const; template - void set(const std::string& key, T v); + FuncConfig& set(const std::string& key, T v); protected: std::map valueMap_; @@ -84,11 +86,11 @@ public: #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName -#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ - static InitFunction __reg_type_##typeName([]() { \ - FunctionBase::funcRegistrar_ \ - .registerClass>( \ - FUNC_NAME(typeName, deviceName)); \ +#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ + static InitFunction __reg_type_##typeName##deviceName([]() { \ + FunctionBase::funcRegistrar_ \ + .registerClass>( \ + FUNC_NAME(typeName, deviceName)); \ }) } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index d55bd78c6..e520351d2 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -18,45 +18,41 @@ namespace paddle { // NCHW template <> -void CrossMapNormal::operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(outputs.isContiguous()); - CHECK(inputs.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(outputs.getHeight(), inputs.getHeight()); - CHECK_EQ(outputs.getWidth(), inputs.getWidth()); - CHECK_EQ(outputs.getHeight(), denoms.getHeight()); - CHECK_EQ(outputs.getWidth(), denoms.getWidth()); - - size_t numSample = inputs.getHeight(); - size_t numCols = inputs.getWidth(); - size_t imageSize = imgSizeH * imgSizeW; - CHECK(imageSize * channels == numCols); - - denoms = denoms.constant(1.0); - const int start = -((int)sizeX - 1) / 2; - const int end = (int)sizeX + start; - for (size_t i = 0; i < numSample; i++) { - real* denomsData = denoms.getData() + i * numCols; - real* inputData = inputs.getData() + i * numCols; +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t oneImage = height * width; + size_t oneSample = channels * oneImage; + + CpuVector outputsV(numSamples * oneSample, outputs); + CpuVector inputsV(numSamples * oneSample, inputs); + CpuVector denomsV(numSamples * oneSample, denoms); + + denomsV = denomsV.constant(1.0); + const int start = -((int)size - 1) / 2; + const int end = (int)size + start; + for (size_t i = 0; i < numSamples; i++) { + real* oneDenom = denoms + i * oneSample; + real* oneInput = inputs + i * oneSample; for (int c = 0; c < (int)channels; c++) { - CpuVector denom(imageSize, denomsData + c * imageSize); + CpuVector denom(oneImage, oneDenom + c * oneImage); for (int s = start; s < end; s++) { if (c + s >= 0 && c + s < (int)channels) { - CpuVector input(imageSize, inputData + (c + s) * imageSize); + CpuVector input(oneImage, oneInput + (c + s) * oneImage); denom += input.square() * scale; } } } } - outputs = inputs * denoms.pow(-pow); + + outputsV = inputsV * denomsV.pow(-pow); } template <> @@ -154,13 +150,17 @@ public: size_t channels = inputs[0].dims_[1]; size_t height = inputs[0].dims_[2]; size_t width = inputs[0].dims_[3]; - size_t imageSize = channels * height * width; - CpuMatrix input(inputs[0].buf_, samples, imageSize); - CpuMatrix output(outputs[0].buf_, samples, imageSize); - CpuMatrix denom(outputs[1].buf_, samples, imageSize); - CrossMapNormal cross; - cross(output, denom, input, channels, height, width, size_, scale_, pow_); + CrossMapNormal(outputs[0].getData(), + outputs[1].getData(), + inputs[0].getData(), + samples, + channels, + height, + width, + size_, + scale_, + pow_); } private: @@ -170,5 +170,6 @@ private: }; REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); +REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index 86f54abde..ef9533485 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -19,6 +19,18 @@ limitations under the License. */ namespace paddle { +template +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow); +#if 0 template struct CrossMapNormal { void operator()(typename MatrixT::type& outputs, @@ -31,6 +43,7 @@ struct CrossMapNormal { real scale, real pow); }; +#endif template struct CrossMapNormalGrad { diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu index 0a154d97a..9b9297434 100644 --- a/paddle/math/cross_map_normal_op_gpu.cu +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -61,45 +61,29 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in, } template <> -void CrossMapNormal::operator()(GpuMatrix& outputs, - GpuMatrix& denoms, - GpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(outputs.isContiguous()); - CHECK(inputs.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(outputs.getHeight(), inputs.getHeight()); - CHECK_EQ(outputs.getWidth(), inputs.getWidth()); - CHECK_EQ(outputs.getHeight(), denoms.getHeight()); - CHECK_EQ(outputs.getWidth(), denoms.getWidth()); - - size_t numSample = inputs.getHeight(); - size_t numCols = inputs.getWidth(); - CHECK(imgSizeH * imgSizeW * channels == numCols); - - real* inputsData = inputs.getData(); - real* denomsData = denoms.getData(); - real* outputsData = outputs.getData(); - - size_t imageSize = numSample * imgSizeH * imgSizeW; +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t imageSize = numSamples * height * width; int blockSize = 1024; int gridSize = (imageSize + 1024 - 1) / 1024; KeCMRNormFillScale<<>> - (imageSize, inputsData, denomsData, - channels, imgSizeH, imgSizeW, sizeX, scale); + (imageSize, inputs, denoms, channels, height, width, size, scale); - size_t inputSize = numSample * imgSizeH * imgSizeW *channels; + size_t inputSize = numSamples * height * width *channels; blockSize = 1024; gridSize = (inputSize + 1024 - 1) / 1024; KeCMRNormOutput<<>> - (inputSize, inputsData, denomsData, -pow, outputsData); + (inputSize, inputs, denoms, -pow, outputs); - CHECK_SYNC("CrossMapNormalFwd"); + CHECK_SYNC("CrossMapNormal"); } __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index cd34ea18a..aac3f7579 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1281,24 +1281,40 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); +#if 0 FuncConfig config; config.set("size", (size_t)sizeX); config.set("scale", scale); config.set("pow", pow); +#endif FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); - cpu->init(config); + FunctionBase* gpu = + FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU)); + cpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + gpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); Dims dims{ (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; cpu->calc({Tensor(inputs.getData(), dims)}, {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, {}); + + gpu->calc( + {Tensor(inputsGpu.getData(), dims)}, + {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, + {}); #if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); -#endif + CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu, @@ -1309,6 +1325,7 @@ void testCrossMapNormalFwd( sizeX, scale, pow); +#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); -- GitLab