Commit 3bce32ba authored by gaoyuan

Add create matrix pointer function

Parent 17c697c7
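The refactor below replaces four near-identical Matrix::create calls per loop iteration with two helpers that build zero-copy views into a batch-major buffer. As a minimal sketch of the offset arithmetic they wrap (plain C++ with illustrative names, not the Paddle Matrix API):

#include <cstddef>

struct MatrixView {
  float* data;    // non-owning pointer into the shared batch buffer
  size_t height;  // rows
  size_t width;   // columns
};

// Mirrors createSampleMatrix: sample `iter` is a channels x spatialDim block
// starting at data + iter * channels * spatialDim; no data is copied.
MatrixView sampleView(float* data, size_t iter, size_t channels,
                      size_t spatialDim) {
  return {data + iter * channels * spatialDim, channels, spatialDim};
}

// Mirrors createSpatialMatrix: a 1 x spatialDim row of per-position values
// (e.g. norms) starting at data + iter * spatialDim.
MatrixView spatialView(float* data, size_t iter, size_t spatialDim) {
  return {data + iter * spatialDim, 1, spatialDim};
}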
@@ -19,6 +19,23 @@ limitations under the License. */
 
 namespace paddle {
 
+MatrixPtr CrossChannelNormLayer::createSampleMatrix(MatrixPtr data,
+                                                    size_t iter,
+                                                    size_t spatialDim) {
+  return Matrix::create(data->getData() + iter * channels_ * spatialDim,
+                        channels_,
+                        spatialDim,
+                        false,
+                        useGpu_);
+}
+
+MatrixPtr CrossChannelNormLayer::createSpatialMatrix(MatrixPtr data,
+                                                     size_t iter,
+                                                     size_t spatialDim) {
+  return Matrix::create(
+      data->getData() + iter * spatialDim, 1, spatialDim, false, useGpu_);
+}
+
 void CrossChannelNormLayer::forward(PassType passType) {
   Layer::forward(passType);
   MatrixPtr inV = getInputValue(0);
@@ -40,25 +57,19 @@ void CrossChannelNormLayer::forward(PassType passType) {
   normBuffer_->addScalar(*normBuffer_, 1e-6);
   inV->square2(*dataBuffer_);
   for (size_t i = 0; i < batchSize; i++) {
-    MatrixPtr inTmp = Matrix::create(
-        inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
-    MatrixPtr dataTmp = Matrix::create(dataBuffer_->getData() + i * dataDim,
-                                       channels_,
-                                       spatialDim,
-                                       false,
-                                       useGpu_);
-    MatrixPtr outTmp = Matrix::create(
-        outV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
-    MatrixPtr normTmp = Matrix::create(
-        normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_);
+    const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim);
+    const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim);
+    MatrixPtr outVTmp = createSampleMatrix(outV, i, spatialDim);
+    MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim);
+
     // compute norm.
-    spatialBuffer_->sumCols(*dataTmp, 1, 1);
+    spatialBuffer_->sumCols(*dataTmp, 1, 0);
     spatialBuffer_->sqrt2(*spatialBuffer_);
     normTmp->copyFrom(*spatialBuffer_);
-    outTmp->copyFrom(*inTmp);
-    outTmp->divRowVector(*spatialBuffer_);
+    outVTmp->copyFrom(*inVTmp);
+    outVTmp->divRowVector(*spatialBuffer_);
     // scale the layer.
-    outTmp->mulColVector(*scale_->getW());
+    outVTmp->mulColVector(*scale_->getW());
   }
 }
@@ -78,40 +89,31 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
   Matrix::resizeOrCreate(sampleBuffer_, channels_, spatialDim, false, useGpu_);
   scaleDiff_->zeroMem();
   for (size_t i = 0; i < batchSize; i++) {
-    // propagate to param.
-    MatrixPtr dataBufferTmp =
-        Matrix::create(dataBuffer_->getData() + i * dataDim,
-                       channels_,
-                       spatialDim,
-                       false,
-                       useGpu_);
-    const MatrixPtr inValueTmp = Matrix::create(
-        inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
-    const MatrixPtr outGradTmp = Matrix::create(
-        outG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
-    MatrixPtr inGradTmp = Matrix::create(
-        inG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
-    const MatrixPtr normTmp = Matrix::create(
-        normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_);
-    channelBuffer_->sumRows(*dataBufferTmp, 1, 1);
+    MatrixPtr outGTmp = createSampleMatrix(outG, i, spatialDim);
+    const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim);
+    const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim);
+    const MatrixPtr inGTmp = createSampleMatrix(inG, i, spatialDim);
+    const MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim);
+
+    channelBuffer_->sumRows(*dataTmp, 1, 0);
     channelBuffer_->dotDiv(*channelBuffer_, *(scale_->getW()));
     // store a / scale[i] in scaleDiff_ temporary
     scaleDiff_->add(*channelBuffer_, 1.);
-    sampleBuffer_->dotMul(*inValueTmp, *outGradTmp);
+    sampleBuffer_->dotMul(*inVTmp, *outGTmp);
     spatialBuffer_->sumCols(*sampleBuffer_, 1., 1.);
     // scale the grad
-    inGradTmp->copyFrom(*inValueTmp);
-    inGradTmp->mulRowVector(*spatialBuffer_);
+    inGTmp->copyFrom(*inVTmp);
+    inGTmp->mulRowVector(*spatialBuffer_);
     // divide by square of norm
     spatialBuffer_->dotMul(*normTmp, *normTmp);
-    inGradTmp->divRowVector(*spatialBuffer_);
+    inGTmp->divRowVector(*spatialBuffer_);
     // subtract
-    inGradTmp->add(*outGradTmp, -1, 1);
+    inGTmp->add(*outGTmp, -1, 1);
     // divide by norm
-    inGradTmp->divRowVector(*normTmp);
+    inGTmp->divRowVector(*normTmp);
     // scale the diff
-    inGradTmp->mulColVector(*scale_->getW());
+    inGTmp->mulColVector(*scale_->getW());
   }
   // updata scale
   if (scale_->getWGrad()) scale_->getWGrad()->copyFrom(*scaleDiff_);
...
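For reference, the op sequence in the loop above (dotMul, sumCols, mulRowVector, divRowVector, add, divRowVector, mulColVector) implements the chain rule for y[c][j] = scale[c] * x[c][j] / n[j], with n[j] = sqrt(sum_c x[c][j]^2) read back from normBuffer_. A per-sample sketch of the identity it computes, under my reading of the diff (illustrative C++, not the Paddle API):

#include <cstddef>
#include <vector>

// inG[c][j] = scale[c] * (outG[c][j] - x[c][j] * S[j] / n[j]^2) / n[j],
// where S[j] = sum_c x[c][j] * outG[c][j]. Buffers are channels x spatialDim,
// row-major; `norm` holds n[j] as stashed in normBuffer_ during forward.
void crossChannelNormGradSample(const std::vector<float>& x,
                                const std::vector<float>& outG,
                                const std::vector<float>& norm,
                                const std::vector<float>& scale,
                                std::vector<float>& inG,
                                size_t channels,
                                size_t spatialDim) {
  for (size_t j = 0; j < spatialDim; ++j) {
    float S = 0.f;  // what dotMul + sumCols accumulate per spatial position
    for (size_t c = 0; c < channels; ++c) {
      S += x[c * spatialDim + j] * outG[c * spatialDim + j];
    }
    const float n = norm[j];
    for (size_t c = 0; c < channels; ++c) {
      const size_t k = c * spatialDim + j;
      inG[k] = scale[c] * (outG[k] - x[k] * S / (n * n)) / n;
    }
  }
}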
@@ -80,9 +80,10 @@ public:
   explicit CrossChannelNormLayer(const LayerConfig& config)
       : NormLayer(config) {}
   bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
   void forward(PassType passType);
   void backward(const UpdateCallback& callback);
+  MatrixPtr createSampleMatrix(MatrixPtr data, size_t iter, size_t spatialDim);
+  MatrixPtr createSpatialMatrix(MatrixPtr data, size_t iter, size_t spatialDim);
 
 protected:
   size_t channels_;
...
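One behavioral change rides along with the renaming: the third argument of sumCols (and sumRows) drops from 1 to 0. Assuming the legacy Paddle semantics sumCols(b, scaleSum, scaleDest), i.e. dest = scaleDest * dest + scaleSum * columnSums(b), this makes each iteration overwrite the reduction buffer instead of accumulating stale values from the previous sample. An illustrative per-sample reference of the forward pass, matching the loop in the first hunk (not the Paddle API; the layer seeds normBuffer_ with 1e-6, folded here into the denominator):

#include <cmath>
#include <cstddef>
#include <vector>

// out[c][j] = scale[c] * in[c][j] / n[j], n[j] = sqrt(sum_c in[c][j]^2) + eps.
// Matches the forward loop: square, per-column sum (starting from zero),
// sqrt, divide by the per-position norm, then scale each channel.
void crossChannelNormSample(const std::vector<float>& in,
                            const std::vector<float>& scale,
                            std::vector<float>& out,
                            size_t channels,
                            size_t spatialDim,
                            float eps = 1e-6f) {
  for (size_t j = 0; j < spatialDim; ++j) {
    float sq = 0.f;  // fresh per position: the scaleDest = 0 fix
    for (size_t c = 0; c < channels; ++c) {
      const float v = in[c * spatialDim + j];
      sq += v * v;
    }
    const float n = std::sqrt(sq) + eps;
    for (size_t c = 0; c < channels; ++c) {
      out[c * spatialDim + j] = scale[c] * in[c * spatialDim + j] / n;
    }
  }
}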