提交 afa69024 编写于 作者: X xzl

add cuda and cpu pool_forward_with_mask impl

上级 96212132
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h" #include "hl_base.h"
/** /**
* @brief Maximum pool forward. * @brief Maximum pool forward with Mask output.
* *
* @param[in] frameCnt batch size of input image. * @param[in] frameCnt batch size of input image.
* @param[in] inputData input data. * @param[in] inputData input data.
...@@ -35,7 +35,47 @@ limitations under the License. */ ...@@ -35,7 +35,47 @@ limitations under the License. */
* @param[in] paddingW padding width. * @param[in] paddingW padding width.
* @param[out] tgtData output data. * @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples. * @param[in] tgtStride stride between output data samples.
* @param[out] maskData the location indices of select max data
* @param[in] withMask set true if output maskData
*/
extern void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
const int height,
const int width,
const int pooledH,
const int pooledW,
const int sizeX,
const int sizeY,
const int strideH,
const int strideW,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride,
real* maskData,
bool withMask);
/**
* @brief Maximum pool forward.
* *
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeX width of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
* @param[out] maskData the location indices of select max data
* @param[in] withMask set true if output maskData
*/ */
extern void hl_maxpool_forward(const int frameCnt, extern void hl_maxpool_forward(const int frameCnt,
const real* inputData, const real* inputData,
......
...@@ -33,6 +33,24 @@ inline void hl_maxpool_forward(const int frameCnt, ...@@ -33,6 +33,24 @@ inline void hl_maxpool_forward(const int frameCnt,
real* tgtData, real* tgtData,
const int tgtStride) {} const int tgtStride) {}
inline void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
const int height,
const int width,
const int pooledH,
const int pooledW,
const int sizeX,
const int sizeY,
const int strideH,
const int strideW,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride,
real* MaskData,
bool withMask) {}
inline void hl_maxpool_backward(const int frameCnt, inline void hl_maxpool_backward(const int frameCnt,
const real* inputData, const real* inputData,
const real* outData, const real* outData,
......
...@@ -31,7 +31,9 @@ __global__ void KeMaxPoolForward(const int nthreads, ...@@ -31,7 +31,9 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH, const int offsetH,
const int offsetW, const int offsetW,
real* tgtData, real* tgtData,
const int tgtStride) { const int tgtStride,
real* maskData,
bool withMask) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) { if (index < nthreads) {
int pw = index % pooledW; int pw = index % pooledW;
...@@ -45,16 +47,22 @@ __global__ void KeMaxPoolForward(const int nthreads, ...@@ -45,16 +47,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
real maxval = -FLT_MAX; real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width; inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w]) if (maxval < inputData[h * width + w]) {
maxval = inputData[h * width + w]; maxval = inputData[h * width + w];
max_index = h * width + w;
}
} }
} }
int tgtIndex = int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride; index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval; tgtData[tgtIndex] = maxval;
if (withMask) {
maskData[tgtIndex] = max_index;
}
} }
} }
...@@ -92,7 +100,51 @@ void hl_maxpool_forward(const int frameCnt, ...@@ -92,7 +100,51 @@ void hl_maxpool_forward(const int frameCnt,
paddingH, paddingH,
paddingW, paddingW,
tgtData, tgtData,
tgtStride); tgtStride,
NULL,
false);
CHECK_SYNC("hl_maxpool_forward failed");
}
void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
const int height,
const int width,
const int pooledH,
const int pooledW,
const int sizeX,
const int sizeY,
const int strideH,
const int strideW,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride,
real* maskData,
bool withMask) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KeMaxPoolForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
inputData,
channels,
height,
width,
pooledH,
pooledW,
sizeX,
sizeY,
strideH,
strideW,
paddingH,
paddingW,
tgtData,
tgtStride,
maskData,
withMask);
CHECK_SYNC("hl_maxpool_forward failed"); CHECK_SYNC("hl_maxpool_forward failed");
} }
......
...@@ -1029,14 +1029,51 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -1029,14 +1029,51 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW) { size_t paddingW) {
maxPoolForward(inputMat,
imgSizeH,
imgSizeW,
channels,
sizeX,
sizeY,
strideH,
strideW,
outputH,
outputW,
paddingH,
paddingW,
NULL,
false);
}
void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask) {
CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";
real* inputData = inputMat.getData(); real* inputData = inputMat.getData();
real* maskData = NULL;
size_t frameNum = inputMat.getHeight(); size_t frameNum = inputMat.getHeight();
CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth());
CHECK(height_ == inputMat.getHeight()); CHECK(height_ == inputMat.getHeight());
CHECK(width_ == outputH * outputW * channels); CHECK(width_ == outputH * outputW * channels);
if (withMask) {
CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal";
CHECK(outputH * outputW * channels == maskMatP->getWidth());
maskData = maskMatP->getData();
}
hl_maxpool_forward(frameNum, hl_maxpool_forward(frameNum,
inputData, inputData,
channels, channels,
...@@ -1051,7 +1088,9 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -1051,7 +1088,9 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
paddingH, paddingH,
paddingW, paddingW,
data_, data_,
getStride()); getStride(),
maskData,
withMask);
} }
void GpuMatrix::maxPoolBackward(Matrix& inputMat, void GpuMatrix::maxPoolBackward(Matrix& inputMat,
...@@ -1974,8 +2013,39 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -1974,8 +2013,39 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW) { size_t paddingW) {
maxPoolForward(inputMat,
imgSizeH,
imgSizeW,
channels,
sizeX,
sizeY,
strideH,
strideW,
outputH,
outputW,
paddingH,
paddingW,
NULL,
false);
}
void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask) {
real* inputData = inputMat.getData(); real* inputData = inputMat.getData();
real* outData = data_; real* outData = data_;
real* maskData = NULL;
size_t num = inputMat.getHeight(); size_t num = inputMat.getHeight();
size_t inLength = imgSizeH * imgSizeW; size_t inLength = imgSizeH * imgSizeW;
size_t outLength = outputH * outputW; size_t outLength = outputH * outputW;
...@@ -1984,6 +2054,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -1984,6 +2054,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
CHECK_EQ(channels * outLength, this->getWidth()); CHECK_EQ(channels * outLength, this->getWidth());
size_t outStride = getStride(); size_t outStride = getStride();
if (withMask) {
maskData = maskMatP->getData();
CHECK_EQ(channels * outLength, maskMatP->getWidth());
}
/* initialize the data_ */ /* initialize the data_ */
for (size_t i = 0; i < height_; i++) { for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) { for (size_t j = 0; j < width_; j++) {
...@@ -2005,10 +2080,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -2005,10 +2080,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
int wstart = pw * strideW - paddingW; int wstart = pw * strideW - paddingW;
int wend = std::min(wstart + sizeX, imgSizeW); int wend = std::min(wstart + sizeX, imgSizeW);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
for (int h = hstart; h < hend; ++h) { if (!withMask) {
for (int w = wstart; w < wend; ++w) { for (int h = hstart; h < hend; ++h) {
outData[ph * outputW + pw] = std::max( for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw], inputData[h * imgSizeW + w]); outData[ph * outputW + pw] = std::max(
outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
}
}
} else {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
maskData[ph * outputW + pw] = h * imgSizeW + w;
}
}
} }
} }
} }
...@@ -2016,6 +2102,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -2016,6 +2102,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
// compute offset // compute offset
inputData += inLength; inputData += inLength;
outData += outLength; outData += outLength;
if (withMask) maskData += outLength;
} }
} }
} }
......
...@@ -861,7 +861,7 @@ public: ...@@ -861,7 +861,7 @@ public:
/** /**
* Pooling forward operation, pick out the largest element * Pooling forward operation, pick out the largest element
* in the sizeX of value * in the sizeX of value.
*/ */
virtual void maxPoolForward(Matrix& inputMat, virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH, size_t imgSizeH,
...@@ -878,6 +878,28 @@ public: ...@@ -878,6 +878,28 @@ public:
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value, if set withMask true, it will
* also caculate the location indices.
*/
virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask) {
LOG(FATAL) << "Not implemeted";
}
/// Pooling backward operation. /// Pooling backward operation.
virtual void maxPoolBackward(Matrix& image, virtual void maxPoolBackward(Matrix& image,
size_t imgSizeH, size_t imgSizeH,
...@@ -1428,6 +1450,21 @@ public: ...@@ -1428,6 +1450,21 @@ public:
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask);
void maxPoolBackward(Matrix& image, void maxPoolBackward(Matrix& image,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
...@@ -1699,6 +1736,21 @@ public: ...@@ -1699,6 +1736,21 @@ public:
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask);
void maxPoolBackward(Matrix& image, void maxPoolBackward(Matrix& image,
size_t imgSizeH, size_t imgSizeH,
size_t imgSizeW, size_t imgSizeW,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册