提交 a9c578e8 编写于 作者: E eclipsess

fix template class double

上级 ece4316e
......@@ -14,7 +14,9 @@ limitations under the License. */
#include "operators/math/im2col.h"
#include <vector>
#ifdef __ARM_NEON
#include "arm_neon.h"
#endif
#include "common/types.h"
namespace paddle_mobile {
namespace operators {
......@@ -67,7 +69,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
int channels_col = im_channels * filter_height * filter_width;
const T *im_data = im.data<T>();
T *col_data = col->data<T>();
#ifdef __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
......@@ -408,6 +410,26 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
}
}
}
#else
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
: im_data[im_idx];
}
}
}
#endif
}
};
......@@ -480,7 +502,7 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
};
template class Im2ColFunctor<ColFormat::kCFO, CPU, float>;
template class Im2ColFunctor<ColFormat::kCFO, CPU, double>;
// template class Im2ColFunctor<ColFormat::kCFO, CPU, double>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, float>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, double>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册