提交 784e5940 编写于 作者: H hedaoyuan

Bug fix of Im2ColMobileFunctor.

上级 ed0a564c
...@@ -98,58 +98,6 @@ public: ...@@ -98,58 +98,6 @@ public:
int dilationWidth = 1); int dilationWidth = 1);
}; };
#if 0
template <class T>
class Im2ColMobileFunctor {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int dilationHeight,
int dilationWidth,
int colHeightStart,
int colHeightSize,
int colWidthStart,
int colWidthSize) {
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputWidth = colShape[4];
for (int colh = 0; colh < colHeightSize; colh++) {
int wOffset = (colHeightStart + colh) % filterWidth;
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight;
int c_im = (colHeightStart + colh) / filterWidth / filterHeight;
for (int colw = 0; colw < colWidthSize; colw++) {
int h = (colWidthStart + colw) / outputWidth;
int w = (colWidthStart + colw) % outputWidth;
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[colh * colWidthSize + colw] = static_cast<T>(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[colh * colWidthSize + colw] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
};
#endif
template <class T> template <class T>
class Im2ColMobileFunctor { class Im2ColMobileFunctor {
public: public:
...@@ -178,12 +126,14 @@ public: ...@@ -178,12 +126,14 @@ public:
T* dstData = colData + oh * outputWidth; T* dstData = colData + oh * outputWidth;
for (int fh = 0; fh < filterHeight; fh++) { for (int fh = 0; fh < filterHeight; fh++) {
for (int fw = 0; fw < filterWidth; fw++) { for (int fw = 0; fw < filterWidth; fw++) {
int imRowIdx = (oh + colOffset) * strideHeight + fh - paddingHeight; int imRowIdx = (oh + colOffset) * strideHeight +
fh * dilationHeight - paddingHeight;
if (imRowIdx < 0 || imRowIdx >= inputHeight) { if (imRowIdx < 0 || imRowIdx >= inputHeight) {
memset(dstData, 0, outputWidth * sizeof(T)); memset(dstData, 0, outputWidth * sizeof(T));
} else { } else {
for (int ow = 0; ow < outputWidth; ow++) { for (int ow = 0; ow < outputWidth; ow++) {
int imColIdx = ow * strideWidth + fw - paddingWidth; int imColIdx =
ow * strideWidth + fw * dilationWidth - paddingWidth;
if (imColIdx < 0 || imColIdx >= inputWidth) { if (imColIdx < 0 || imColIdx >= inputWidth) {
dstData[ow] = T(0); dstData[ow] = T(0);
} else { } else {
......
...@@ -202,10 +202,10 @@ void TestIm2ColMobileFunctor() { ...@@ -202,10 +202,10 @@ void TestIm2ColMobileFunctor() {
padding, padding,
dilation, dilation,
dilation, dilation,
channels,
0, 0,
height, outputHeight,
0, outputHeight * outputWidth);
width);
autotest::TensorCheckEqual(*output1, *output2); autotest::TensorCheckEqual(*output1, *output2);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册