提交 afa69024 编写于 作者: X xzl

add cuda and cpu pool_forward_with_mask impl

上级 96212132
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h"
/**
* @brief Maximum pool forward.
* @brief Maximum pool forward with Mask output.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
......@@ -35,7 +35,47 @@ limitations under the License. */
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
* @param[out] maskData the location indices of select max data
* @param[in] withMask set true if output maskData
*/
extern void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
const int height,
const int width,
const int pooledH,
const int pooledW,
const int sizeX,
const int sizeY,
const int strideH,
const int strideW,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride,
real* maskData,
bool withMask);
/**
* @brief Maximum pool forward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeX width of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
* @param[out] maskData the location indices of select max data
* @param[in] withMask set true if output maskData
*/
extern void hl_maxpool_forward(const int frameCnt,
const real* inputData,
......
......@@ -33,6 +33,24 @@ inline void hl_maxpool_forward(const int frameCnt,
real* tgtData,
const int tgtStride) {}
inline void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
const int height,
const int width,
const int pooledH,
const int pooledW,
const int sizeX,
const int sizeY,
const int strideH,
const int strideW,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride,
real* MaskData,
bool withMask) {}
inline void hl_maxpool_backward(const int frameCnt,
const real* inputData,
const real* outData,
......
......@@ -31,7 +31,9 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH,
const int offsetW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData,
bool withMask) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
......@@ -45,16 +47,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0);
wstart = max(wstart, 0);
real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w])
if (maxval < inputData[h * width + w]) {
maxval = inputData[h * width + w];
max_index = h * width + w;
}
}
}
int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval;
if (withMask) {
maskData[tgtIndex] = max_index;
}
}
}
......@@ -92,7 +100,51 @@ void hl_maxpool_forward(const int frameCnt,
paddingH,
paddingW,
tgtData,
tgtStride);
tgtStride,
NULL,
false);
CHECK_SYNC("hl_maxpool_forward failed");
}
void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
const int height,
const int width,
const int pooledH,
const int pooledW,
const int sizeX,
const int sizeY,
const int strideH,
const int strideW,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride,
real* maskData,
bool withMask) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KeMaxPoolForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
inputData,
channels,
height,
width,
pooledH,
pooledW,
sizeX,
sizeY,
strideH,
strideW,
paddingH,
paddingW,
tgtData,
tgtStride,
maskData,
withMask);
CHECK_SYNC("hl_maxpool_forward failed");
}
......
......@@ -1029,14 +1029,51 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputW,
size_t paddingH,
size_t paddingW) {
maxPoolForward(inputMat,
imgSizeH,
imgSizeW,
channels,
sizeX,
sizeY,
strideH,
strideW,
outputH,
outputW,
paddingH,
paddingW,
NULL,
false);
}
void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask) {
CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";
real* inputData = inputMat.getData();
real* maskData = NULL;
size_t frameNum = inputMat.getHeight();
CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth());
CHECK(height_ == inputMat.getHeight());
CHECK(width_ == outputH * outputW * channels);
if (withMask) {
CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal";
CHECK(outputH * outputW * channels == maskMatP->getWidth());
maskData = maskMatP->getData();
}
hl_maxpool_forward(frameNum,
inputData,
channels,
......@@ -1051,7 +1088,9 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
paddingH,
paddingW,
data_,
getStride());
getStride(),
maskData,
withMask);
}
void GpuMatrix::maxPoolBackward(Matrix& inputMat,
......@@ -1974,8 +2013,39 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputW,
size_t paddingH,
size_t paddingW) {
maxPoolForward(inputMat,
imgSizeH,
imgSizeW,
channels,
sizeX,
sizeY,
strideH,
strideW,
outputH,
outputW,
paddingH,
paddingW,
NULL,
false);
}
void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask) {
real* inputData = inputMat.getData();
real* outData = data_;
real* maskData = NULL;
size_t num = inputMat.getHeight();
size_t inLength = imgSizeH * imgSizeW;
size_t outLength = outputH * outputW;
......@@ -1984,6 +2054,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
CHECK_EQ(channels * outLength, this->getWidth());
size_t outStride = getStride();
if (withMask) {
maskData = maskMatP->getData();
CHECK_EQ(channels * outLength, maskMatP->getWidth());
}
/* initialize the data_ */
for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) {
......@@ -2005,10 +2080,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
int wstart = pw * strideW - paddingW;
int wend = std::min(wstart + sizeX, imgSizeW);
wstart = std::max(wstart, 0);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw] = std::max(
outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
if (!withMask) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw] = std::max(
outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
}
}
} else {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
maskData[ph * outputW + pw] = h * imgSizeW + w;
}
}
}
}
}
......@@ -2016,6 +2102,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
// compute offset
inputData += inLength;
outData += outLength;
if (withMask) maskData += outLength;
}
}
}
......
......@@ -861,7 +861,7 @@ public:
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value
* in the sizeX of value.
*/
virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
......@@ -878,6 +878,28 @@ public:
LOG(FATAL) << "Not implemeted";
}
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value, if set withMask true, it will
* also caculate the location indices.
*/
virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask) {
LOG(FATAL) << "Not implemeted";
}
/// Pooling backward operation.
virtual void maxPoolBackward(Matrix& image,
size_t imgSizeH,
......@@ -1428,6 +1450,21 @@ public:
size_t paddingH,
size_t paddingW);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask);
void maxPoolBackward(Matrix& image,
size_t imgSizeH,
size_t imgSizeW,
......@@ -1699,6 +1736,21 @@ public:
size_t paddingH,
size_t paddingW);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW,
MatrixPtr maskMatP,
bool withMask);
void maxPoolBackward(Matrix& image,
size_t imgSizeH,
size_t imgSizeW,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册