提交 3bce32ba 编写于 作者: G gaoyuan

Add create matrix pointer funtion

上级 17c697c7
......@@ -19,6 +19,23 @@ limitations under the License. */
namespace paddle {
MatrixPtr CrossChannelNormLayer::createSampleMatrix(MatrixPtr data,
size_t iter,
size_t spatialDim) {
return Matrix::create(data->getData() + iter * channels_ * spatialDim,
channels_,
spatialDim,
false,
useGpu_);
}
MatrixPtr CrossChannelNormLayer::createSpatialMatrix(MatrixPtr data,
size_t iter,
size_t spatialDim) {
return Matrix::create(
data->getData() + iter * spatialDim, 1, spatialDim, false, useGpu_);
}
void CrossChannelNormLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV = getInputValue(0);
......@@ -40,25 +57,19 @@ void CrossChannelNormLayer::forward(PassType passType) {
normBuffer_->addScalar(*normBuffer_, 1e-6);
inV->square2(*dataBuffer_);
for (size_t i = 0; i < batchSize; i++) {
MatrixPtr inTmp = Matrix::create(
inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
MatrixPtr dataTmp = Matrix::create(dataBuffer_->getData() + i * dataDim,
channels_,
spatialDim,
false,
useGpu_);
MatrixPtr outTmp = Matrix::create(
outV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
MatrixPtr normTmp = Matrix::create(
normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_);
const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim);
const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim);
MatrixPtr outVTmp = createSampleMatrix(outV, i, spatialDim);
MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim);
// compute norm.
spatialBuffer_->sumCols(*dataTmp, 1, 1);
spatialBuffer_->sumCols(*dataTmp, 1, 0);
spatialBuffer_->sqrt2(*spatialBuffer_);
normTmp->copyFrom(*spatialBuffer_);
outTmp->copyFrom(*inTmp);
outTmp->divRowVector(*spatialBuffer_);
outVTmp->copyFrom(*inVTmp);
outVTmp->divRowVector(*spatialBuffer_);
// scale the layer.
outTmp->mulColVector(*scale_->getW());
outVTmp->mulColVector(*scale_->getW());
}
}
......@@ -78,40 +89,31 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
Matrix::resizeOrCreate(sampleBuffer_, channels_, spatialDim, false, useGpu_);
scaleDiff_->zeroMem();
for (size_t i = 0; i < batchSize; i++) {
// propagate to param.
MatrixPtr dataBufferTmp =
Matrix::create(dataBuffer_->getData() + i * dataDim,
channels_,
spatialDim,
false,
useGpu_);
const MatrixPtr inValueTmp = Matrix::create(
inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
const MatrixPtr outGradTmp = Matrix::create(
outG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
MatrixPtr inGradTmp = Matrix::create(
inG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
const MatrixPtr normTmp = Matrix::create(
normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_);
channelBuffer_->sumRows(*dataBufferTmp, 1, 1);
MatrixPtr outGTmp = createSampleMatrix(outG, i, spatialDim);
const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim);
const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim);
const MatrixPtr inGTmp = createSampleMatrix(inG, i, spatialDim);
const MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim);
channelBuffer_->sumRows(*dataTmp, 1, 0);
channelBuffer_->dotDiv(*channelBuffer_, *(scale_->getW()));
// store a / scale[i] in scaleDiff_ temporary
scaleDiff_->add(*channelBuffer_, 1.);
sampleBuffer_->dotMul(*inValueTmp, *outGradTmp);
sampleBuffer_->dotMul(*inVTmp, *outGTmp);
spatialBuffer_->sumCols(*sampleBuffer_, 1., 1.);
// scale the grad
inGradTmp->copyFrom(*inValueTmp);
inGradTmp->mulRowVector(*spatialBuffer_);
inGTmp->copyFrom(*inVTmp);
inGTmp->mulRowVector(*spatialBuffer_);
// divide by square of norm
spatialBuffer_->dotMul(*normTmp, *normTmp);
inGradTmp->divRowVector(*spatialBuffer_);
inGTmp->divRowVector(*spatialBuffer_);
// subtract
inGradTmp->add(*outGradTmp, -1, 1);
inGTmp->add(*outGTmp, -1, 1);
// divide by norm
inGradTmp->divRowVector(*normTmp);
inGTmp->divRowVector(*normTmp);
// scale the diff
inGradTmp->mulColVector(*scale_->getW());
inGTmp->mulColVector(*scale_->getW());
}
// updata scale
if (scale_->getWGrad()) scale_->getWGrad()->copyFrom(*scaleDiff_);
......
......@@ -80,9 +80,10 @@ public:
explicit CrossChannelNormLayer(const LayerConfig& config)
: NormLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback);
MatrixPtr createSampleMatrix(MatrixPtr data, size_t iter, size_t spatialDim);
MatrixPtr createSpatialMatrix(MatrixPtr data, size_t iter, size_t spatialDim);
protected:
size_t channels_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册