提交 b7c4b58d 编写于 作者: H hedaoyuan

Follow comments.

上级 f453b713
...@@ -189,8 +189,8 @@ public: ...@@ -189,8 +189,8 @@ public:
size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth;
size_t colWidth = outputHeight * outputWidth; size_t colWidth = outputHeight * outputWidth;
// Max col matrix height 256, Max col matrix width 1024 // Max col matrix height 256, Max col matrix width 1024
size_t stepColHeight = std::min(colHeight, (size_t)256); size_t stepColHeight = std::min(colHeight, static_cast<size_t>(256));
size_t stepColWidth = std::min(colWidth, (size_t)2048); size_t stepColWidth = std::min(colWidth, static_cast<size_t>(2048));
if (needIm2col) { if (needIm2col) {
colShape = TensorShape({inputChannels / groups_, colShape = TensorShape({inputChannels / groups_,
...@@ -278,6 +278,8 @@ public: ...@@ -278,6 +278,8 @@ public:
inputData += inputChannels * inputHeight * inputWidth; inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
} }
memory_.reset();
} }
}; };
......
...@@ -136,7 +136,7 @@ public: ...@@ -136,7 +136,7 @@ public:
(imRowIdx - paddingHeight) >= inputHeight || (imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 || (imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) { (imColIdx - paddingWidth) >= inputWidth) {
colData[colh * colWidthSize + colw] = T(0); colData[colh * colWidthSize + colw] = static_cast<T>(0);
} else { } else {
imRowIdx += c_im * inputHeight - paddingHeight; imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth; imColIdx -= paddingWidth;
......
...@@ -140,13 +140,13 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); } ...@@ -140,13 +140,13 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
template <class T> template <class T>
void TestIm2ColMobileFunctor() { void TestIm2ColMobileFunctor() {
for (size_t channels : {1, 5, 32}) { for (size_t channels : {32}) {
for (size_t inputHeight : {5, 33, 100}) { for (size_t inputHeight : {33, 100}) {
for (size_t inputWidth : {5, 32, 96}) { for (size_t inputWidth : {32, 96}) {
for (size_t filterHeight : {1, 5}) { for (size_t filterHeight : {5}) {
for (size_t filterWidth : {3, 7}) { for (size_t filterWidth : {7}) {
for (size_t stride : {1, 2}) { for (size_t stride : {2}) {
for (size_t padding : {0, 1}) { for (size_t padding : {1}) {
for (size_t dilation : {1, 3}) { for (size_t dilation : {1, 3}) {
size_t filterSizeH = (filterHeight - 1) * dilation + 1; size_t filterSizeH = (filterHeight - 1) * dilation + 1;
size_t filterSizeW = (filterWidth - 1) * dilation + 1; size_t filterSizeW = (filterWidth - 1) * dilation + 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册