提交 860bf192 编写于 作者: C chengduoZH

Add maxPoolIdx

上级 790379f1
...@@ -192,11 +192,10 @@ extern void hl_maxpool3D_forward(const int frameCnt, ...@@ -192,11 +192,10 @@ extern void hl_maxpool3D_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
real* maxPoolIdxData,
const int tgtStride); const int tgtStride);
extern void hl_maxpool3D_backward(const int frameCnt, extern void hl_maxpool3D_backward(const int frameCnt,
const real* inputData,
const real* outData,
const real* outGrad, const real* outGrad,
const int channels, const int channels,
const int depth, const int depth,
...@@ -217,6 +216,7 @@ extern void hl_maxpool3D_backward(const int frameCnt, ...@@ -217,6 +216,7 @@ extern void hl_maxpool3D_backward(const int frameCnt,
real scaleA, real scaleA,
real scaleB, real scaleB,
real* targetGrad, real* targetGrad,
real* maxPoolIdxData,
const int outStride); const int outStride);
extern void hl_avgpool3D_forward(const int frameCnt, extern void hl_avgpool3D_forward(const int frameCnt,
......
...@@ -106,11 +106,10 @@ inline void hl_maxpool3D_forward(const int frameCnt, ...@@ -106,11 +106,10 @@ inline void hl_maxpool3D_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
real* maxPoolIdxData,
const int tgtStride) {} const int tgtStride) {}
inline void hl_maxpool3D_backward(const int frameCnt, inline void hl_maxpool3D_backward(const int frameCnt,
const real* inputData,
const real* outData,
const real* outGrad, const real* outGrad,
const int channels, const int channels,
const int depth, const int depth,
...@@ -131,6 +130,7 @@ inline void hl_maxpool3D_backward(const int frameCnt, ...@@ -131,6 +130,7 @@ inline void hl_maxpool3D_backward(const int frameCnt,
real scaleA, real scaleA,
real scaleB, real scaleB,
real* targetGrad, real* targetGrad,
real* maxPoolIdxData,
const int outStride) {} const int outStride) {}
inline void hl_avgpool3D_forward(const int frameCnt, inline void hl_avgpool3D_forward(const int frameCnt,
......
...@@ -366,10 +366,11 @@ __global__ void KeMaxPool3DForward(const int nthreads, ...@@ -366,10 +366,11 @@ __global__ void KeMaxPool3DForward(const int nthreads,
const int strideD, const int strideD,
const int strideH, const int strideH,
const int strideW, const int strideW,
const int offsetD, const int padD,
const int offsetH, const int padH,
const int offsetW, const int padW,
real* tgtData, real* tgtData,
real* maxPoolIdxData,
const int tgtStride) { const int tgtStride) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
...@@ -378,9 +379,9 @@ __global__ void KeMaxPool3DForward(const int nthreads, ...@@ -378,9 +379,9 @@ __global__ void KeMaxPool3DForward(const int nthreads,
int pd = (index / pooledW / pooledH) % pooledD; int pd = (index / pooledW / pooledH) % pooledD;
int c = (index / pooledW / pooledH / pooledD) % channels; int c = (index / pooledW / pooledH / pooledD) % channels;
int frameNum = index / pooledW / pooledH / pooledD / channels; int frameNum = index / pooledW / pooledH / pooledD / channels;
int dstart = pd * strideD - offsetD; int dstart = pd * strideD - padD;
int hstart = ph * strideH - offsetH; int hstart = ph * strideH - padH;
int wstart = pw * strideW - offsetW; int wstart = pw * strideW - padW;
int dend = min(dstart + ksizeD, depth); int dend = min(dstart + ksizeD, depth);
int hend = min(hstart + ksizeH, height); int hend = min(hstart + ksizeH, height);
int wend = min(wstart + ksizeW, width); int wend = min(wstart + ksizeW, width);
...@@ -388,18 +389,22 @@ __global__ void KeMaxPool3DForward(const int nthreads, ...@@ -388,18 +389,22 @@ __global__ void KeMaxPool3DForward(const int nthreads,
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
real maxval = -FLT_MAX; real maxval = -FLT_MAX;
int maxIdx = -1;
inputData += (frameNum * channels + c) * depth * height * width; inputData += (frameNum * channels + c) * depth * height * width;
for (int d = dstart; d < dend; ++d) { for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[(d * height + h) * width + w]) if (maxval < inputData[(d * height + h) * width + w]) {
maxval = inputData[(d * height + h) * width + w]; maxval = inputData[(d * height + h) * width + w];
maxIdx = (d * height + h) * width + w;
}
} }
} }
} }
int tgtIndex = int tgtIndex =
index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride; index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval; tgtData[tgtIndex] = maxval;
maxPoolIdxData[tgtIndex] = maxIdx;
} }
} }
...@@ -418,10 +423,11 @@ void hl_maxpool3D_forward(const int frameCnt, ...@@ -418,10 +423,11 @@ void hl_maxpool3D_forward(const int frameCnt,
const int strideD, const int strideD,
const int strideH, const int strideH,
const int strideW, const int strideW,
const int paddingD, const int padD,
const int paddingH, const int padH,
const int paddingW, const int padW,
real* tgtData, real* tgtData,
real* maxPoolIdxData,
const int tgtStride) { const int tgtStride) {
int num_kernels = pooledD * pooledH * pooledW * channels * frameCnt; int num_kernels = pooledD * pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024; int blocks = (num_kernels + 1024 - 1) / 1024;
...@@ -443,17 +449,16 @@ void hl_maxpool3D_forward(const int frameCnt, ...@@ -443,17 +449,16 @@ void hl_maxpool3D_forward(const int frameCnt,
strideD, strideD,
strideH, strideH,
strideW, strideW,
paddingD, padD,
paddingH, padH,
paddingW, padW,
tgtData, tgtData,
maxPoolIdxData,
tgtStride); tgtStride);
CHECK_SYNC("hl_maxpool3D_forward failed"); CHECK_SYNC("hl_maxpool3D_forward failed");
} }
__global__ void KeMaxPool3DBackward(const int nthreads, __global__ void KeMaxPool3DBackward(const int nthreads,
const real* inputData,
const real* outData,
const real* outGrad, const real* outGrad,
const int channels, const int channels,
const int depth, const int depth,
...@@ -474,33 +479,35 @@ __global__ void KeMaxPool3DBackward(const int nthreads, ...@@ -474,33 +479,35 @@ __global__ void KeMaxPool3DBackward(const int nthreads,
real scaleA, real scaleA,
real scaleB, real scaleB,
real* targetGrad, real* targetGrad,
real* maxPoolIdxData,
const int outStride) { const int outStride) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
// find out the local index int offsetW = index % width;
// find out the local offset int offsetH = (index / width) % height;
int offsetW = index % width + padW; int offsetD = (index / width / height) % depth;
int offsetH = (index / width) % height + padH;
int offsetD = (index / width / height) % depth + padD;
int offsetC = (index / width / height / depth) % channels; int offsetC = (index / width / height / depth) % channels;
int frameNum = index / width / height / depth / channels; int frameNum = index / width / height / depth / channels;
int pdstart = (offsetD < sizeZ) ? 0 : (offsetD - sizeZ) / strideD + 1; int pdstart =
int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1; (offsetD + padD < sizeZ) ? 0 : (offsetD + padD - sizeZ) / strideD + 1;
int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1; int phstart =
int pdend = min(offsetD / strideD + 1, pooledD); (offsetH + padH < sizeY) ? 0 : (offsetH + padH - sizeY) / strideH + 1;
int phend = min(offsetH / strideH + 1, pooledH); int pwstart =
int pwend = min(offsetW / strideW + 1, pooledW); (offsetW + padW < sizeX) ? 0 : (offsetW + padW - sizeX) / strideW + 1;
int pdend = min((offsetD + padD) / strideD + 1, pooledD);
int phend = min((offsetH + padH) / strideH + 1, pooledH);
int pwend = min((offsetW + padW) / strideW + 1, pooledW);
real gradient = 0; real gradient = 0;
real input = inputData[index];
outData += ((frameNum * channels + offsetC) * pooledD * pooledH * pooledW);
outGrad += ((frameNum * channels + offsetC) * pooledD * pooledH * pooledW); outGrad += ((frameNum * channels + offsetC) * pooledD * pooledH * pooledW);
maxPoolIdxData +=
((frameNum * channels + offsetC) * pooledD * pooledH * pooledW);
for (int pd = pdstart; pd < pdend; ++pd) { for (int pd = pdstart; pd < pdend; ++pd) {
for (int ph = phstart; ph < phend; ++ph) { for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) { for (int pw = pwstart; pw < pwend; ++pw) {
if (input == outData[(pd * pooledH + ph) * pooledW + pw]) if (((offsetD * height + offsetH) * width + offsetW) ==
maxPoolIdxData[(pd * pooledH + ph) * pooledW + pw])
gradient += outGrad[(pd * pooledH + ph) * pooledW + pw]; gradient += outGrad[(pd * pooledH + ph) * pooledW + pw];
} }
} }
...@@ -510,8 +517,6 @@ __global__ void KeMaxPool3DBackward(const int nthreads, ...@@ -510,8 +517,6 @@ __global__ void KeMaxPool3DBackward(const int nthreads,
} }
void hl_maxpool3D_backward(const int frameCnt, void hl_maxpool3D_backward(const int frameCnt,
const real* inputData,
const real* outData,
const real* outGrad, const real* outGrad,
const int channels, const int channels,
const int depth, const int depth,
...@@ -532,13 +537,12 @@ void hl_maxpool3D_backward(const int frameCnt, ...@@ -532,13 +537,12 @@ void hl_maxpool3D_backward(const int frameCnt,
real scaleA, real scaleA,
real scaleB, real scaleB,
real* targetGrad, real* targetGrad,
real* maxPoolIdxData,
const int outStride) { const int outStride) {
int num_kernels = depth * height * width * channels * frameCnt; int num_kernels = depth * height * width * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024; int blocks = (num_kernels + 1024 - 1) / 1024;
KeMaxPool3DBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels, KeMaxPool3DBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
inputData,
outData,
outGrad, outGrad,
channels, channels,
depth, depth,
...@@ -559,6 +563,7 @@ void hl_maxpool3D_backward(const int frameCnt, ...@@ -559,6 +563,7 @@ void hl_maxpool3D_backward(const int frameCnt,
scaleA, scaleA,
scaleB, scaleB,
targetGrad, targetGrad,
maxPoolIdxData,
outStride); outStride);
CHECK_SYNC("hl_maxpool3D_backward"); CHECK_SYNC("hl_maxpool3D_backward");
} }
......
...@@ -72,9 +72,10 @@ size_t Pool3DLayer::getSize() { ...@@ -72,9 +72,10 @@ size_t Pool3DLayer::getSize() {
void Pool3DLayer::forward(PassType passType) { void Pool3DLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
const MatrixPtr& inMat = inputLayers_[0]->getOutputValue(); const MatrixPtr& inMat = inputLayers_[0]->getOutputValue();
int batchSize = inMat->getHeight(); size_t batchSize = inMat->getHeight();
int outWidth = getSize(); size_t outWidth = getSize();
resetOutput(batchSize, outWidth); resetOutput(batchSize, outWidth);
Matrix::resizeOrCreate(maxPoolIdx_, batchSize, outWidth, false, useGpu_);
const MatrixPtr outMat = getOutputValue(); const MatrixPtr outMat = getOutputValue();
if (poolType_ == "avg") { if (poolType_ == "avg") {
...@@ -97,6 +98,7 @@ void Pool3DLayer::forward(PassType passType) { ...@@ -97,6 +98,7 @@ void Pool3DLayer::forward(PassType passType) {
paddingW_); paddingW_);
} else if (poolType_ == "max") { } else if (poolType_ == "max") {
outMat->maxPool3DForward(*inMat, outMat->maxPool3DForward(*inMat,
*maxPoolIdx_,
channels_, channels_,
imgSizeD_, imgSizeD_,
imgSizeH_, imgSizeH_,
...@@ -149,9 +151,8 @@ void Pool3DLayer::backward(const UpdateCallback& callback) { ...@@ -149,9 +151,8 @@ void Pool3DLayer::backward(const UpdateCallback& callback) {
1.0, 1.0,
1.0); 1.0);
} else if (poolType_ == "max") { } else if (poolType_ == "max") {
inGradMat->maxPool3DBackward(*inMat, inGradMat->maxPool3DBackward(*outGradMat,
*outGradMat, *maxPoolIdx_,
*outMat,
imgSizeD_, imgSizeD_,
imgSizeH_, imgSizeH_,
imgSizeW_, imgSizeW_,
......
...@@ -1191,6 +1191,7 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, ...@@ -1191,6 +1191,7 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
} }
void GpuMatrix::maxPool3DForward(Matrix& inputMat, void GpuMatrix::maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels, size_t channels,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
...@@ -1210,6 +1211,7 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -1210,6 +1211,7 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat,
CHECK(inputMat.useGpu_) << "Matrix type are not correct"; CHECK(inputMat.useGpu_) << "Matrix type are not correct";
real* inputData = inputMat.getData(); real* inputData = inputMat.getData();
real* maxPoolIdxData = maxPoolIdx.getData();
size_t num = inputMat.getHeight(); size_t num = inputMat.getHeight();
size_t width = imgSizeW; size_t width = imgSizeW;
size_t height = imgSizeH; size_t height = imgSizeH;
...@@ -1237,12 +1239,12 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -1237,12 +1239,12 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat,
paddingH, paddingH,
paddingW, paddingW,
getData(), getData(),
maxPoolIdxData,
getStride()); getStride());
} }
void GpuMatrix::maxPool3DBackward(Matrix& inputMat, void GpuMatrix::maxPool3DBackward(Matrix& outGrad,
Matrix& outGrad, Matrix& maxPoolIdx,
Matrix& outV,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
...@@ -1260,26 +1262,21 @@ void GpuMatrix::maxPool3DBackward(Matrix& inputMat, ...@@ -1260,26 +1262,21 @@ void GpuMatrix::maxPool3DBackward(Matrix& inputMat,
size_t paddingW, size_t paddingW,
real scaleTargets, real scaleTargets,
real scaleOutput) { real scaleOutput) {
CHECK(inputMat.useGpu_ && outGrad.useGpu_ && outV.useGpu_) CHECK(outGrad.useGpu_ && maxPoolIdx.useGpu_) << "Matrix type are not equal";
<< "Matrix type are not equal";
real* inputData = inputMat.getData();
real* outData = outV.getData();
real* outDiff = outGrad.getData(); real* outDiff = outGrad.getData();
size_t frameNum = inputMat.getHeight(); real* maxPoolIdxData = maxPoolIdx.getData();
size_t channels = outV.getWidth() / outputD / outputH / outputW; size_t frameNum = getHeight();
size_t channels = outGrad.getWidth() / outputD / outputH / outputW;
size_t width = imgSizeW; size_t width = imgSizeW;
size_t height = imgSizeH; size_t height = imgSizeH;
size_t depth = imgSizeD; size_t depth = imgSizeD;
CHECK(depth * height * width * channels == inputMat.getWidth()); CHECK(depth * height * width * channels == getWidth());
CHECK(height_ == inputMat.getHeight());
CHECK(width_ == depth * width * height * channels); CHECK(width_ == depth * width * height * channels);
CHECK(outGrad.getHeight() == outV.getHeight() && CHECK(outGrad.getHeight() == maxPoolIdx.getHeight() &&
outGrad.getWidth() == outV.getWidth()); outGrad.getWidth() == maxPoolIdx.getWidth());
hl_maxpool3D_backward(frameNum, hl_maxpool3D_backward(frameNum,
inputData,
outData,
outDiff, outDiff,
channels, channels,
depth, depth,
...@@ -1300,6 +1297,7 @@ void GpuMatrix::maxPool3DBackward(Matrix& inputMat, ...@@ -1300,6 +1297,7 @@ void GpuMatrix::maxPool3DBackward(Matrix& inputMat,
scaleTargets, scaleTargets,
scaleOutput, scaleOutput,
getData(), getData(),
maxPoolIdxData,
outGrad.getStride()); outGrad.getStride());
} }
...@@ -2148,6 +2146,7 @@ void CpuMatrix::avgPoolBackward(Matrix& input, ...@@ -2148,6 +2146,7 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
} }
void CpuMatrix::maxPool3DForward(Matrix& inputMat, void CpuMatrix::maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels, size_t channels,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
...@@ -2166,6 +2165,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -2166,6 +2165,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
size_t paddingW) { size_t paddingW) {
real* inputData = inputMat.getData(); real* inputData = inputMat.getData();
real* outData = getData(); real* outData = getData();
real* maxPoolIdxData = maxPoolIdx.getData();
size_t num = inputMat.getHeight(); size_t num = inputMat.getHeight();
size_t inWidth = imgSizeW; size_t inWidth = imgSizeW;
size_t inHeight = imgSizeH; size_t inHeight = imgSizeH;
...@@ -2179,6 +2179,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -2179,6 +2179,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
for (size_t i = 0; i < height_; i++) { for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) { for (size_t j = 0; j < width_; j++) {
outData[(i)*outStride + j] = -(real)FLT_MAX; outData[(i)*outStride + j] = -(real)FLT_MAX;
maxPoolIdxData[(i)*outStride + j] = -1;
} }
} }
...@@ -2186,6 +2187,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -2186,6 +2187,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
for (size_t n = 0; n < num; ++n) { // frame by frame for (size_t n = 0; n < num; ++n) { // frame by frame
if (!isContiguous()) { if (!isContiguous()) {
outData = getData() + n * outStride; outData = getData() + n * outStride;
maxPoolIdxData = maxPoolIdx.getData() + n * outStride;
} }
for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t c = 0; c < channels; ++c) { // channel by channel
for (size_t pd = 0; pd < outputD; ++pd) { for (size_t pd = 0; pd < outputD; ++pd) {
...@@ -2200,6 +2202,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -2200,6 +2202,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
dstart = std::max(dstart, 0); dstart = std::max(dstart, 0);
hstart = std::max(hstart, 0); hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
int maxIdx = -1;
real maxOutData = outData[(pd * outputH + ph) * outputW + pw]; real maxOutData = outData[(pd * outputH + ph) * outputW + pw];
for (int d = dstart; d < dend; ++d) { for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
...@@ -2207,24 +2210,26 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, ...@@ -2207,24 +2210,26 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
if (maxOutData < if (maxOutData <
inputData[(d * inHeight + h) * inWidth + w]) { inputData[(d * inHeight + h) * inWidth + w]) {
maxOutData = inputData[(d * inHeight + h) * inWidth + w]; maxOutData = inputData[(d * inHeight + h) * inWidth + w];
maxIdx = (d * inHeight + h) * inWidth + w;
} }
} }
} }
} }
outData[(pd * outputH + ph) * outputW + pw] = maxOutData; outData[(pd * outputH + ph) * outputW + pw] = maxOutData;
maxPoolIdxData[(pd * outputH + ph) * outputW + pw] = maxIdx;
} }
} }
} }
// compute offset // compute offset
inputData += inDepth * inHeight * inWidth; inputData += inDepth * inHeight * inWidth;
outData += outputD * outputH * outputW; outData += outputD * outputH * outputW;
maxPoolIdxData += outputD * outputH * outputW;
} }
} }
} }
void CpuMatrix::maxPool3DBackward(Matrix& image, void CpuMatrix::maxPool3DBackward(Matrix& outGrad,
Matrix& outGrad, Matrix& maxPoolIdx,
Matrix& outV,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
...@@ -2242,59 +2247,38 @@ void CpuMatrix::maxPool3DBackward(Matrix& image, ...@@ -2242,59 +2247,38 @@ void CpuMatrix::maxPool3DBackward(Matrix& image,
size_t paddingW, size_t paddingW,
real scaleTargets, real scaleTargets,
real scaleOutput) { real scaleOutput) {
size_t num = image.getHeight(); size_t num = getHeight();
size_t channels = size_t(width_ / imgSizeD / imgSizeH / imgSizeW); size_t channels = size_t(width_ / imgSizeD / imgSizeH / imgSizeW);
CHECK(image.getWidth() == imgSizeD * imgSizeH * imgSizeW * channels); CHECK(maxPoolIdx.getHeight() == outGrad.getHeight() &&
CHECK(image.getHeight() == height_ && image.getWidth() == width_); maxPoolIdx.getWidth() == outGrad.getWidth());
CHECK(outV.getHeight() == outGrad.getHeight() &&
outV.getWidth() == outGrad.getWidth());
real* tgtGrad = getData(); real* tgtGrad = getData();
real* inData = image.getData();
real* otData = outV.getData();
real* otGrad = outGrad.getData(); real* otGrad = outGrad.getData();
real* maxPoolIdxData = maxPoolIdx.getData();
size_t outStride = outV.getStride(); size_t outStride = outGrad.getStride();
; ;
for (size_t n = 0; n < num; ++n) { for (size_t n = 0; n < num; ++n) {
if (!outV.isContiguous()) { if (!outGrad.isContiguous()) {
otData = outV.getData() + n * outStride;
otGrad = outGrad.getData() + n * outStride; otGrad = outGrad.getData() + n * outStride;
maxPoolIdxData = maxPoolIdx.getData() + n * outStride;
} }
for (size_t c = 0; c < channels; ++c) { for (size_t c = 0; c < channels; ++c) {
for (size_t pd = 0; pd < outputD; ++pd) { for (size_t pd = 0; pd < outputD; ++pd) {
for (size_t ph = 0; ph < outputH; ++ph) { for (size_t ph = 0; ph < outputH; ++ph) {
for (size_t pw = 0; pw < outputW; ++pw) { for (size_t pw = 0; pw < outputW; ++pw) {
int dstart = pd * strideD - paddingD; const size_t index = (pd * outputH + ph) * outputW + pw;
int hstart = ph * strideH - paddingH; const size_t tgtIdx = static_cast<size_t>(maxPoolIdxData[index]);
int wstart = pw * strideW - paddingW; tgtGrad[tgtIdx] =
int dend = std::min(dstart + sizeZ, imgSizeD); scaleTargets * tgtGrad[tgtIdx] + scaleOutput * otGrad[index];
int hend = std::min(hstart + sizeY, imgSizeH);
int wend = std::min(wstart + sizeX, imgSizeW);
dstart = std::max(dstart, 0);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
tgtGrad[(d * imgSizeH + h) * imgSizeW + w] =
scaleTargets *
tgtGrad[(d * imgSizeH + h) * imgSizeW + w] +
scaleOutput * otGrad[(pd * outputH + ph) * outputW + pw] *
(inData[(d * imgSizeH + h) * imgSizeW + w] ==
otData[(pd * outputH + ph) * outputW + pw]);
}
}
}
} }
} }
} }
// offset // offset
inData += imgSizeD * imgSizeH * imgSizeW;
tgtGrad += imgSizeD * imgSizeH * imgSizeW; tgtGrad += imgSizeD * imgSizeH * imgSizeW;
otData += outputD * outputH * outputW;
otGrad += outputD * outputH * outputW; otGrad += outputD * outputH * outputW;
maxPoolIdxData += outputD * outputH * outputW;
} }
} }
} }
......
...@@ -933,6 +933,7 @@ public: ...@@ -933,6 +933,7 @@ public:
* in the sizeX of value * in the sizeX of value
*/ */
virtual void maxPool3DForward(Matrix& inputMat, virtual void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels, size_t channels,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
...@@ -952,9 +953,8 @@ public: ...@@ -952,9 +953,8 @@ public:
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void maxPool3DBackward(Matrix& image, virtual void maxPool3DBackward(Matrix& outGrad,
Matrix& outGrad, Matrix& maxPoolIdx,
Matrix& outV,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
...@@ -1436,6 +1436,7 @@ public: ...@@ -1436,6 +1436,7 @@ public:
size_t paddingW); size_t paddingW);
void maxPool3DForward(Matrix& inputMat, void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels, size_t channels,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
...@@ -1453,9 +1454,8 @@ public: ...@@ -1453,9 +1454,8 @@ public:
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW);
void maxPool3DBackward(Matrix& image, void maxPool3DBackward(Matrix& outGrad,
Matrix& outGrad, Matrix& maxPoolIdx,
Matrix& outV,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
...@@ -1671,6 +1671,7 @@ public: ...@@ -1671,6 +1671,7 @@ public:
size_t paddingW); size_t paddingW);
void maxPool3DForward(Matrix& inputMat, void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels, size_t channels,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
...@@ -1688,9 +1689,8 @@ public: ...@@ -1688,9 +1689,8 @@ public:
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW);
void maxPool3DBackward(Matrix& image, void maxPool3DBackward(Matrix& outGrad,
Matrix& outGrad, Matrix& maxPoolIdx,
Matrix& outV,
size_t imgSizeD, size_t imgSizeD,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册