未验证 提交 73209b72 编写于 作者: H Houjiang Chen 提交者: GitHub

Merge pull request #1505 from hjchen2/backup

Optimize general col2im to speed up transpose conv
......@@ -56,10 +56,9 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 0
&& param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 */ /* refered from ncnn */
#if 1
&& (param->Input()->dims()[1] >= 4 ||
param->Output()->dims()[1] >= 16)
#endif
) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
......
此差异已折叠。
......@@ -25,6 +25,25 @@ namespace math {
* Col2ImFunctor. */
enum class ColFormat { kCFO = 0, kOCF = 1 };
template <class T>
void ExtractToImg(const T *im_data, T *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw);
template <class T>
void ExtendToImg(const T *col_data, T *im_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw);
template <class T>
void ExtendToImgV2(const T *col_data, T *im_data, const int im_height,
const int im_width, const int col_height,
const int col_width, const int padding_h,
const int padding_w, const int stride_h, const int stride_w,
const int kh, const int kernel_w);
/*
* \brief Converts the image data of three dimensions(CHW) into a
* colData of
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册