“c339946e6db2ad8ba4c2a91edfa2c63d12908135”上不存在“develop/doc/api/v2/fluid/nets.html”
提交 687b3749 编写于 作者: G guosheng

fix bug on GPU test

上级 1c007677
......@@ -43,15 +43,46 @@ void ROIPoolLayer::forward(PassType passType) {
size_t batchSize = getInput(0).getBatchSize();
size_t numROIs = getInput(1).getBatchSize();
real* bottomData = getInputValue(0)->getData();
size_t batchOffset = getInputValue(0)->getWidth();
MatrixPtr dataValue = getInputValue(0);
MatrixPtr roiValue = getInputValue(1);
resetOutput(numROIs, channels_ * pooledHeight_ * pooledWidth_);
MatrixPtr outputValue = getOutputValue();
if (useGpu_) {
MatrixPtr dataCpuBuffer;
Matrix::resizeOrCreate(dataCpuBuffer,
dataValue->getHeight(),
dataValue->getWidth(),
false,
false);
MatrixPtr roiCpuBuffer;
Matrix::resizeOrCreate(roiCpuBuffer,
roiValue->getHeight(),
roiValue->getWidth(),
false,
false);
dataCpuBuffer->copyFrom(*dataValue);
roiCpuBuffer->copyFrom(*roiValue);
dataValue = dataCpuBuffer;
roiValue = roiCpuBuffer;
MatrixPtr outputCpuBuffer;
Matrix::resizeOrCreate(outputCpuBuffer,
outputValue->getHeight(),
outputValue->getWidth(),
false,
false);
outputCpuBuffer->copyFrom(*outputValue);
outputValue = outputCpuBuffer;
}
real* bottomData = dataValue->getData();
size_t batchOffset = dataValue->getWidth();
size_t channelOffset = height_ * width_;
real* bottomROIs = getInputValue(1)->getData();
size_t roiOffset = getInputValue(1)->getWidth();
real* bottomROIs = roiValue->getData();
size_t roiOffset = roiValue->getWidth();
size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
resetOutput(numROIs, channels_ * pooledHeight_ * pooledWidth_);
real* outputData = getOutputValue()->getData();
real* outputData = outputValue->getData();
Matrix::resizeOrCreate(maxIdxs_,
numROIs,
channels_ * pooledHeight_ * pooledWidth_,
......@@ -113,20 +144,52 @@ void ROIPoolLayer::forward(PassType passType) {
}
bottomROIs += roiOffset;
}
if (useGpu_) {
getOutputValue()->copyFrom(*outputValue);
}
}
void ROIPoolLayer::backward(const UpdateCallback& callback) {
real* bottomROIs = getInputValue(1)->getData();
MatrixPtr inGradValue = getInputGrad(0);
MatrixPtr outGradValue = getOutputGrad();
MatrixPtr roiValue = getInputValue(1);
if (useGpu_) {
MatrixPtr inGradCpuBuffer;
Matrix::resizeOrCreate(inGradCpuBuffer,
inGradValue->getHeight(),
inGradValue->getWidth(),
false,
false);
MatrixPtr outGradCpuBuffer;
Matrix::resizeOrCreate(outGradCpuBuffer,
outGradValue->getHeight(),
outGradValue->getWidth(),
false,
false);
MatrixPtr roiCpuBuffer;
Matrix::resizeOrCreate(roiCpuBuffer,
roiValue->getHeight(),
roiValue->getWidth(),
false,
false);
inGradCpuBuffer->copyFrom(*inGradValue);
outGradCpuBuffer->copyFrom(*outGradValue);
roiCpuBuffer->copyFrom(*roiValue);
inGradValue = inGradCpuBuffer;
outGradValue = outGradCpuBuffer;
roiValue = roiCpuBuffer;
}
real* bottomROIs = roiValue->getData();
size_t numROIs = getInput(1).getBatchSize();
size_t roiOffset = getInputValue(1)->getWidth();
MatrixPtr inGrad = getInputGrad(0);
real* inDiffData = inGrad->getData();
real* inDiffData = inGradValue->getData();
size_t batchOffset = getInputValue(0)->getWidth();
size_t channelOffset = height_ * width_;
MatrixPtr outGrad = getOutputGrad();
real* outDiffData = outGrad->getData();
real* outDiffData = outGradValue->getData();
size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
real* argmaxData = maxIdxs_->getData();
......@@ -149,6 +212,10 @@ void ROIPoolLayer::backward(const UpdateCallback& callback) {
}
bottomROIs += roiOffset;
}
if (useGpu_) {
getInputGrad(0)->copyFrom(*inGradValue);
}
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册