提交 687b3749 编写于 作者: G guosheng

fix bug on GPU test

上级 1c007677
...@@ -43,15 +43,46 @@ void ROIPoolLayer::forward(PassType passType) { ...@@ -43,15 +43,46 @@ void ROIPoolLayer::forward(PassType passType) {
size_t batchSize = getInput(0).getBatchSize(); size_t batchSize = getInput(0).getBatchSize();
size_t numROIs = getInput(1).getBatchSize(); size_t numROIs = getInput(1).getBatchSize();
real* bottomData = getInputValue(0)->getData(); MatrixPtr dataValue = getInputValue(0);
size_t batchOffset = getInputValue(0)->getWidth(); MatrixPtr roiValue = getInputValue(1);
resetOutput(numROIs, channels_ * pooledHeight_ * pooledWidth_);
MatrixPtr outputValue = getOutputValue();
if (useGpu_) {
MatrixPtr dataCpuBuffer;
Matrix::resizeOrCreate(dataCpuBuffer,
dataValue->getHeight(),
dataValue->getWidth(),
false,
false);
MatrixPtr roiCpuBuffer;
Matrix::resizeOrCreate(roiCpuBuffer,
roiValue->getHeight(),
roiValue->getWidth(),
false,
false);
dataCpuBuffer->copyFrom(*dataValue);
roiCpuBuffer->copyFrom(*roiValue);
dataValue = dataCpuBuffer;
roiValue = roiCpuBuffer;
MatrixPtr outputCpuBuffer;
Matrix::resizeOrCreate(outputCpuBuffer,
outputValue->getHeight(),
outputValue->getWidth(),
false,
false);
outputCpuBuffer->copyFrom(*outputValue);
outputValue = outputCpuBuffer;
}
real* bottomData = dataValue->getData();
size_t batchOffset = dataValue->getWidth();
size_t channelOffset = height_ * width_; size_t channelOffset = height_ * width_;
real* bottomROIs = getInputValue(1)->getData(); real* bottomROIs = roiValue->getData();
size_t roiOffset = getInputValue(1)->getWidth(); size_t roiOffset = roiValue->getWidth();
size_t poolChannelOffset = pooledHeight_ * pooledWidth_; size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
resetOutput(numROIs, channels_ * pooledHeight_ * pooledWidth_); real* outputData = outputValue->getData();
real* outputData = getOutputValue()->getData();
Matrix::resizeOrCreate(maxIdxs_, Matrix::resizeOrCreate(maxIdxs_,
numROIs, numROIs,
channels_ * pooledHeight_ * pooledWidth_, channels_ * pooledHeight_ * pooledWidth_,
...@@ -113,20 +144,52 @@ void ROIPoolLayer::forward(PassType passType) { ...@@ -113,20 +144,52 @@ void ROIPoolLayer::forward(PassType passType) {
} }
bottomROIs += roiOffset; bottomROIs += roiOffset;
} }
if (useGpu_) {
getOutputValue()->copyFrom(*outputValue);
}
} }
void ROIPoolLayer::backward(const UpdateCallback& callback) { void ROIPoolLayer::backward(const UpdateCallback& callback) {
real* bottomROIs = getInputValue(1)->getData(); MatrixPtr inGradValue = getInputGrad(0);
MatrixPtr outGradValue = getOutputGrad();
MatrixPtr roiValue = getInputValue(1);
if (useGpu_) {
MatrixPtr inGradCpuBuffer;
Matrix::resizeOrCreate(inGradCpuBuffer,
inGradValue->getHeight(),
inGradValue->getWidth(),
false,
false);
MatrixPtr outGradCpuBuffer;
Matrix::resizeOrCreate(outGradCpuBuffer,
outGradValue->getHeight(),
outGradValue->getWidth(),
false,
false);
MatrixPtr roiCpuBuffer;
Matrix::resizeOrCreate(roiCpuBuffer,
roiValue->getHeight(),
roiValue->getWidth(),
false,
false);
inGradCpuBuffer->copyFrom(*inGradValue);
outGradCpuBuffer->copyFrom(*outGradValue);
roiCpuBuffer->copyFrom(*roiValue);
inGradValue = inGradCpuBuffer;
outGradValue = outGradCpuBuffer;
roiValue = roiCpuBuffer;
}
real* bottomROIs = roiValue->getData();
size_t numROIs = getInput(1).getBatchSize(); size_t numROIs = getInput(1).getBatchSize();
size_t roiOffset = getInputValue(1)->getWidth(); size_t roiOffset = getInputValue(1)->getWidth();
MatrixPtr inGrad = getInputGrad(0); real* inDiffData = inGradValue->getData();
real* inDiffData = inGrad->getData();
size_t batchOffset = getInputValue(0)->getWidth(); size_t batchOffset = getInputValue(0)->getWidth();
size_t channelOffset = height_ * width_; size_t channelOffset = height_ * width_;
MatrixPtr outGrad = getOutputGrad(); real* outDiffData = outGradValue->getData();
real* outDiffData = outGrad->getData();
size_t poolChannelOffset = pooledHeight_ * pooledWidth_; size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
real* argmaxData = maxIdxs_->getData(); real* argmaxData = maxIdxs_->getData();
...@@ -149,6 +212,10 @@ void ROIPoolLayer::backward(const UpdateCallback& callback) { ...@@ -149,6 +212,10 @@ void ROIPoolLayer::backward(const UpdateCallback& callback) {
} }
bottomROIs += roiOffset; bottomROIs += roiOffset;
} }
if (useGpu_) {
getInputGrad(0)->copyFrom(*inGradValue);
}
} }
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册