# Patch notes (free-form text ahead of the first "diff --git" header).
#
# paddle/gserver/layers/CrossChannelNormLayer.cpp
#   * Adds CrossChannelNormLayer::createSampleMatrix(data, iter, spatialDim):
#     returns Matrix::create(...) over data->getData() + iter * channels_ * spatialDim,
#     passing (channels_, spatialDim) as the two dimension arguments.
#   * Adds CrossChannelNormLayer::createSpatialMatrix(data, iter, spatialDim):
#     returns Matrix::create(...) over data->getData() + iter * spatialDim,
#     passing (1, spatialDim) as the two dimension arguments.
#   * forward() / backward(): the per-sample inline Matrix::create(...) calls inside the
#     batch loops are replaced by the two helpers, and the loop locals are renamed
#     (inTmp -> inVTmp, outTmp -> outVTmp, inGradTmp -> inGTmp, ...).
#   * forward(): spatialBuffer_->sumCols(*dataTmp, 1, 1) becomes sumCols(*dataTmp, 1, 0).
#   * backward(): channelBuffer_->sumRows(*dataBufferTmp, 1, 1) becomes sumRows(*dataTmp, 1, 0).
#
# paddle/gserver/layers/NormLayer.h
#   * Declares createSampleMatrix / createSpatialMatrix in the section that precedes
#     `protected:` and drops the blank line that followed init().
#
# NOTE(review): backward() still contains the unchanged context line
#   spatialBuffer_->sumCols(*sampleBuffer_, 1., 1.);
# while the analogous sumCols/sumRows calls touched by this patch switch their last
# argument from 1 to 0. Presumably that argument scales the destination's previous
# contents; if so, this call keeps accumulating whatever spatialBuffer_ held before
# (it is overwritten with normTmp * normTmp later in each iteration) -- confirm
# against Matrix::sumCols and change to 0 if accumulation is not intended.
# NOTE(review): both helpers take MatrixPtr by value; if MatrixPtr is a shared_ptr
# alias (not visible here), `const MatrixPtr&` would avoid a copy per call -- confirm.
# NOTE(review): the helpers are only used by forward()/backward() in this patch;
# consider declaring them under `protected:` instead -- confirm no external callers.
# NOTE(review): line breaks and indentation of the patch text below appear collapsed;
# the original line structure must be restored before the patch can be applied.
#
diff --git a/paddle/gserver/layers/CrossChannelNormLayer.cpp b/paddle/gserver/layers/CrossChannelNormLayer.cpp index 0c8156ae7736d923bc55df696808fad09bfa1796..4c952742934dfde8b37a09f9152fa4ceb9a04f02 100644 --- a/paddle/gserver/layers/CrossChannelNormLayer.cpp +++ b/paddle/gserver/layers/CrossChannelNormLayer.cpp @@ -19,6 +19,23 @@ limitations under the License. */ namespace paddle { +MatrixPtr CrossChannelNormLayer::createSampleMatrix(MatrixPtr data, + size_t iter, + size_t spatialDim) { + return Matrix::create(data->getData() + iter * channels_ * spatialDim, + channels_, + spatialDim, + false, + useGpu_); +} + +MatrixPtr CrossChannelNormLayer::createSpatialMatrix(MatrixPtr data, + size_t iter, + size_t spatialDim) { + return Matrix::create( + data->getData() + iter * spatialDim, 1, spatialDim, false, useGpu_); +} + void CrossChannelNormLayer::forward(PassType passType) { Layer::forward(passType); MatrixPtr inV = getInputValue(0); @@ -40,25 +57,19 @@ void CrossChannelNormLayer::forward(PassType passType) { normBuffer_->addScalar(*normBuffer_, 1e-6); inV->square2(*dataBuffer_); for (size_t i = 0; i < batchSize; i++) { - MatrixPtr inTmp = Matrix::create( - inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_); - MatrixPtr dataTmp = Matrix::create(dataBuffer_->getData() + i * dataDim, - channels_, - spatialDim, - false, - useGpu_); - MatrixPtr outTmp = Matrix::create( - outV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_); - MatrixPtr normTmp = Matrix::create( - normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_); + const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim); + const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim); + MatrixPtr outVTmp = createSampleMatrix(outV, i, spatialDim); + MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim); + // compute norm. 
- spatialBuffer_->sumCols(*dataTmp, 1, 1); + spatialBuffer_->sumCols(*dataTmp, 1, 0); spatialBuffer_->sqrt2(*spatialBuffer_); normTmp->copyFrom(*spatialBuffer_); - outTmp->copyFrom(*inTmp); - outTmp->divRowVector(*spatialBuffer_); + outVTmp->copyFrom(*inVTmp); + outVTmp->divRowVector(*spatialBuffer_); // scale the layer. - outTmp->mulColVector(*scale_->getW()); + outVTmp->mulColVector(*scale_->getW()); } } @@ -78,40 +89,31 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) { Matrix::resizeOrCreate(sampleBuffer_, channels_, spatialDim, false, useGpu_); scaleDiff_->zeroMem(); for (size_t i = 0; i < batchSize; i++) { - // propagate to param. - MatrixPtr dataBufferTmp = - Matrix::create(dataBuffer_->getData() + i * dataDim, - channels_, - spatialDim, - false, - useGpu_); - const MatrixPtr inValueTmp = Matrix::create( - inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_); - const MatrixPtr outGradTmp = Matrix::create( - outG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_); - MatrixPtr inGradTmp = Matrix::create( - inG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_); - const MatrixPtr normTmp = Matrix::create( - normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_); - channelBuffer_->sumRows(*dataBufferTmp, 1, 1); + MatrixPtr outGTmp = createSampleMatrix(outG, i, spatialDim); + const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim); + const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim); + const MatrixPtr inGTmp = createSampleMatrix(inG, i, spatialDim); + const MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim); + + channelBuffer_->sumRows(*dataTmp, 1, 0); channelBuffer_->dotDiv(*channelBuffer_, *(scale_->getW())); // store a / scale[i] in scaleDiff_ temporary scaleDiff_->add(*channelBuffer_, 1.); - sampleBuffer_->dotMul(*inValueTmp, *outGradTmp); + sampleBuffer_->dotMul(*inVTmp, *outGTmp); spatialBuffer_->sumCols(*sampleBuffer_, 1., 
1.); // scale the grad - inGradTmp->copyFrom(*inValueTmp); - inGradTmp->mulRowVector(*spatialBuffer_); + inGTmp->copyFrom(*inVTmp); + inGTmp->mulRowVector(*spatialBuffer_); // divide by square of norm spatialBuffer_->dotMul(*normTmp, *normTmp); - inGradTmp->divRowVector(*spatialBuffer_); + inGTmp->divRowVector(*spatialBuffer_); // subtract - inGradTmp->add(*outGradTmp, -1, 1); + inGTmp->add(*outGTmp, -1, 1); // divide by norm - inGradTmp->divRowVector(*normTmp); + inGTmp->divRowVector(*normTmp); // scale the diff - inGradTmp->mulColVector(*scale_->getW()); + inGTmp->mulColVector(*scale_->getW()); } // updata scale if (scale_->getWGrad()) scale_->getWGrad()->copyFrom(*scaleDiff_); diff --git a/paddle/gserver/layers/NormLayer.h b/paddle/gserver/layers/NormLayer.h index f490f506a9028580f590c013e25d3f54d2ae470c..7c238ac944e52c3a83c2aa5deac18de3aff6db61 100644 --- a/paddle/gserver/layers/NormLayer.h +++ b/paddle/gserver/layers/NormLayer.h @@ -80,9 +80,10 @@ public: explicit CrossChannelNormLayer(const LayerConfig& config) : NormLayer(config) {} bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - void forward(PassType passType); void backward(const UpdateCallback& callback); + MatrixPtr createSampleMatrix(MatrixPtr data, size_t iter, size_t spatialDim); + MatrixPtr createSpatialMatrix(MatrixPtr data, size_t iter, size_t spatialDim); protected: size_t channels_;