From 95035908b4f47e61bad12d0ed49bf62a1734b2cf Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 14:27:42 +0800 Subject: [PATCH] add CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 129 +++++++++++++++++++++ paddle/math/cross_map_normal_op.h | 47 ++++++++ paddle/math/tests/test_matrixCompare.cpp | 137 ++++++++++++----------- 3 files changed, 248 insertions(+), 65 deletions(-) create mode 100644 paddle/math/cross_map_normal_op.cpp create mode 100644 paddle/math/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp new file mode 100644 index 00000000000..3eb51b5998f --- /dev/null +++ b/paddle/math/cross_map_normal_op.cpp @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "cross_map_normal_op.h" + +namespace paddle { + +// NCHW +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + denoms = denoms.constant(1.0); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + real* denomsData = denoms.getData() + i * numCols; + real* inputData = inputs.getData() + i * numCols; + for (int c = 0; c < (int)channels; c++) { + CpuVector denom(imageSize, denomsData + c * imageSize); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + CpuVector input(imageSize, inputData + (c + s) * imageSize); + denom += input.square() * scale; + } + } + } + } + outputs = inputs * denoms.pow(-pow); +} + +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + std::function oneImage = [=](real* data, + size_t offset) { + return CpuVector(imageSize, data + offset); + }; + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + size_t sOffset = i * numCols; + real* inputGradData = inputsGrad.getData() + sOffset; + real* inputData = inputsValue.getData() + sOffset; + real* denomData = denoms.getData() + sOffset; + real* outputGradData = outputsGrad.getData() + sOffset; + real* outputData = outputsValue.getData() + sOffset; + + for (int c = 0; c < (int)channels; c++) { + size_t cOffset = c * imageSize; + CpuVector inputGrad = oneImage(inputGradData, cOffset); + CpuVector inputValue = oneImage(inputData, cOffset); + CpuVector denom = oneImage(denomData, cOffset); + CpuVector outputGrad = oneImage(outputGradData, cOffset); + + inputGrad = inputGrad + denom.pow(-pow) * outputGrad; + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + size_t offset = (c + s) * imageSize; + CpuVector output = oneImage(outputData, offset); + CpuVector outputGrad = oneImage(outputGradData, offset); + CpuVector denom = oneImage(denomData, offset); + + inputGrad += ((outputGrad * output * ratio) / denom) * inputValue; + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h new file mode 100644 index 00000000000..2f996072528 --- /dev/null +++ b/paddle/math/cross_map_normal_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/math/Matrix.h" + +namespace paddle { + +struct CrossMapNormal { + void operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +struct CrossMapNormalGrad { + void operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 5233a9af401..9bb1fdbdab8 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/gserver/tests/TestUtil.h" #include "paddle/utils/Stat.h" #include "TensorCheck.h" +#include "paddle/math/cross_map_normal_op.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1261,30 +1262,32 @@ TEST(Matrix, MaxOutFwdBwd) { } } } + void testCrossMapNormalFwd( int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { float scale = 1.5; float pow = 0.5; int width = imgSizeH * imgSizeW * channels; - MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); - - input->randomizeUniform(); - target->randomizeUniform(); - inputGpu->copyFrom(*input); - targetGpu->copyFrom(*target); - - target->crossMapNormalFwd( - *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); - targetGpu->crossMapNormalFwd( - *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); - - TensorCheckErr(*target, *targetGpu); - TensorCheckErr(*denorms, *denormsGpu); + CpuMatrix inputs(numSamples, width); + CpuMatrix denoms(numSamples, width); + CpuMatrix outputs(numSamples, width); + GpuMatrix inputsGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + GpuMatrix outputsGpu(numSamples, width); + + inputs.randomizeUniform(); + outputs.randomizeUniform(); + inputsGpu.copyFrom(inputs); + outputsGpu.copyFrom(outputs); + + CrossMapNormal cross; + cross( + outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); + outputsGpu.crossMapNormalFwd( + inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(outputs, outputsGpu); + TensorCheckErr(denoms, denomsGpu); } TEST(Matrix, crossMapNormalFwd) { @@ -1310,53 +1313,57 @@ void testCrossMapNormalBwd( float scale = 1.5; float pow = 0.5; size_t width = imgSizeH * imgSizeW * channels; - MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); - - localGrad->randomizeUniform(); - denoms->randomizeUniform(); - preOutV->randomizeUniform(); - localOutV->randomizeUniform(); - output->randomizeUniform(); - denoms->add(0.01); - - MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); - - localGradGpu->copyFrom(*localGrad); - denomsGpu->copyFrom(*denoms); - preOutVGpu->copyFrom(*preOutV); - localOutVGpu->copyFrom(*localOutV); - outputGpu->copyFrom(*output); - output->crossMapNormalBwd(*localGrad, - *denoms, - *preOutV, - *localOutV, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - outputGpu->crossMapNormalBwd(*localGradGpu, - *denomsGpu, - *preOutVGpu, - *localOutVGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - TensorCheckErr(*output, *outputGpu); + CpuMatrix inputsGrad(numSamples, width); + CpuMatrix inputsValue(numSamples, width); + CpuMatrix outputsGrad(numSamples, width); + CpuMatrix outputsValue(numSamples, width); + CpuMatrix denoms(numSamples, width); + + outputsGrad.randomizeUniform(); + denoms.randomizeUniform(); + inputsValue.randomizeUniform(); + outputsValue.randomizeUniform(); + inputsGrad.randomizeUniform(); + denoms.add(0.01); + + GpuMatrix inputsGradGpu(numSamples, width); + GpuMatrix inputsValueGpu(numSamples, width); + GpuMatrix outputsGradGpu(numSamples, width); + GpuMatrix outputsValueGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + + outputsGradGpu.copyFrom(outputsGrad); + denomsGpu.copyFrom(denoms); + inputsValueGpu.copyFrom(inputsValue); + outputsValueGpu.copyFrom(outputsValue); + inputsGradGpu.copyFrom(inputsGrad); + + CrossMapNormalGrad cross; + cross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + inputsGradGpu.crossMapNormalBwd(outputsGradGpu, + denomsGpu, + inputsValueGpu, + outputsValueGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(inputsGrad, inputsGradGpu); } TEST(Matrix, crossMapNormalBwd) { -- GitLab