提交 dc6f89f2 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor winograd strategy helper

GitOrigin-RevId: ecc2b15df995a526688d3a5593d8db6767c0c717
上级 e1b2d31d
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
...@@ -28,8 +29,8 @@ using BiasMode = ConvBiasForward::BiasMode; ...@@ -28,8 +29,8 @@ using BiasMode = ConvBiasForward::BiasMode;
*/ */
template <typename ctype, typename dst_type, typename input_filter_compute_type, template <typename ctype, typename dst_type, typename input_filter_compute_type,
typename output_compute_type, typename output_compute_type,
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, param::ConvBias::Format layout = param::ConvBias::Format::NCHW,
typename enable = void> param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
class StrategyHelper { class StrategyHelper {
public: public:
static void filter(const ctype* filter, static void filter(const ctype* filter,
...@@ -61,47 +62,6 @@ public: ...@@ -61,47 +62,6 @@ public:
float rescale = 1.0f); float rescale = 1.0f);
}; };
/**
* \brief Strategy helper that declares the three static Winograd helper
* functions (filter, input and output) used by the debug kernel
* implementation.
*
* Template parameters:
* - ctype: element type of the raw filter and input buffers.
* - dst_type: element type of the final output buffer.
* - input_filter_compute_type: element type of the filter/input transform
*   buffers and of their intermediate buffer.
* - output_compute_type: element type of the output transform buffer, the
*   bias and the output intermediate buffer.
* - format: matrix-mul format, defaults to MK8.
* - enable: unused in this declaration; presumably a specialization hook --
*   TODO confirm.
*
* Only declarations are given here; the definitions live elsewhere.
*
* \warning The layout should be NCHW88
*/
template <typename ctype, typename dst_type, typename input_filter_compute_type,
typename output_compute_type,
param::MatrixMul::Format format = param::MatrixMul::Format::MK8,
typename enable = void>
class StrategyHelperNchwxx {
public:
//! Filter helper: reads \p filter and writes into \p filter_transform_buf,
//! with \p transform_mid_buf available as an intermediate buffer. Takes the
//! channel counts OC and IC and the range oc_start..oc_end (presumably
//! half-open -- TODO confirm). m and r are presumably the Winograd output
//! block size and the filter size -- TODO confirm. \p rescale defaults
//! to 1.0f.
static void filter(const ctype* filter,
input_filter_compute_type* filter_transform_buf,
input_filter_compute_type* transform_mid_buf, size_t OC,
size_t IC, size_t oc_start, size_t oc_end, size_t m,
size_t r, const std::vector<float>& interp_points,
DType dtype, float rescale = 1.0f);
//! Input helper: reads \p input and writes into \p input_transform_buf,
//! with \p transform_mid_buf available as an intermediate buffer.
//! ih_start and iw_start are signed, unlike the other size parameters
//! (presumably so they can be negative in padded regions -- TODO confirm).
//! unit_idx and nr_units_in_tile identify the unit within the tile being
//! handled. \p rescale defaults to 1.0f.
static void input(const ctype* input,
input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f);
//! Output helper: reads \p output_transform_buf and \p bias and writes into
//! \p output, with \p transform_mid_buf available as an intermediate buffer.
//! \p bmode and \p nonline_mode select the bias mode and the nonlinearity.
//! The three trailing scale arguments all default to 1.0f.
static void
output(const output_compute_type* output_transform_buf,
const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale
float input_filter_rescale = 1.0f, // input_rescale * filter_rescale
float rescale = 1.0f);
};
} // namespace winograd } // namespace winograd
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -6,13 +6,14 @@ ...@@ -6,13 +6,14 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/fallback/conv_bias/winograd/strategy.h" #include "src/fallback/conv_bias/winograd/strategy.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
namespace megdnn { namespace megdnn {
namespace fallback { namespace fallback {
...@@ -60,7 +61,7 @@ void winograd_2x3_4x4_f::filter(const float* filter, ...@@ -60,7 +61,7 @@ void winograd_2x3_4x4_f::filter(const float* filter,
float* transform_mid_buf, size_t OC, size_t IC, float* transform_mid_buf, size_t OC, size_t IC,
size_t oc_start, size_t oc_end) { size_t oc_start, size_t oc_end) {
::megdnn::winograd::StrategyHelper< ::megdnn::winograd::StrategyHelper<
float, float, float, float, float, float, float, float, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK4>::filter(filter, filter_transform_buf, param::MatrixMul::Format::MK4>::filter(filter, filter_transform_buf,
transform_mid_buf, OC, IC, transform_mid_buf, OC, IC,
oc_start, oc_end, oc_start, oc_end,
...@@ -73,11 +74,15 @@ void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, ...@@ -73,11 +74,15 @@ void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, int ih_start, float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW, size_t IC, int iw_start, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) { size_t unit_idx, size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<float, float, float, float, ::megdnn::winograd::StrategyHelper<
param::MatrixMul::Format::MK4>:: float, float, float, float, param::ConvBias::Format::NCHW,
input(input, input_transform_buf, transform_mid_buf, ih_start, param::MatrixMul::Format::MK4>::input(input, input_transform_buf,
iw_start, IH, IW, IC, unit_idx, nr_units_in_tile, transform_mid_buf, ih_start,
OUTPUT_BLOCK_SIZE, KERNEL_SIZE, {0, 1, -1}, src_dtype); iw_start, IH, IW, IC,
unit_idx, nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
src_dtype);
} }
void winograd_2x3_4x4_f::output(const float* output_transform_buf, void winograd_2x3_4x4_f::output(const float* output_transform_buf,
...@@ -87,16 +92,19 @@ void winograd_2x3_4x4_f::output(const float* output_transform_buf, ...@@ -87,16 +92,19 @@ void winograd_2x3_4x4_f::output(const float* output_transform_buf,
size_t ow_start, size_t OH, size_t OW, size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t unit_idx, size_t oc_start, size_t oc_end, size_t unit_idx,
size_t nr_units_in_tile) { size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<float, float, float, float, ::megdnn::winograd::StrategyHelper<
param::MatrixMul::Format::MK4>:: float, float, float, float, param::ConvBias::Format::NCHW,
output(output_transform_buf, bias, output, transform_mid_buf, bmode, param::MatrixMul::Format::MK4>::output(output_transform_buf, bias,
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, output, transform_mid_buf,
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, bmode, nonline_mode,
{0, 1, -1}, dst_dtype); oh_start, ow_start, OH, OW,
oc_start, oc_end, unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
dst_dtype);
} }
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8)
void winograd_2x3_1x1_qs8::filter(const int8_t* filter, void winograd_2x3_1x1_qs8::filter(const int8_t* filter,
...@@ -136,7 +144,6 @@ void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, ...@@ -136,7 +144,6 @@ void winograd_2x3_1x1_qs8::output(const int* output_transform_buf,
{0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f); {0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f);
} }
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8)
void winograd_2x3_8x8_qs8::filter(const int8_t* filter, void winograd_2x3_8x8_qs8::filter(const int8_t* filter,
...@@ -144,7 +151,7 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter, ...@@ -144,7 +151,7 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter,
int16_t* transform_mid_buf, size_t OC, int16_t* transform_mid_buf, size_t OC,
size_t IC, size_t oc_start, size_t oc_end) { size_t IC, size_t oc_start, size_t oc_end) {
::megdnn::winograd::StrategyHelper< ::megdnn::winograd::StrategyHelper<
int8_t, int8_t, int16_t, int, int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK8>::filter(filter, filter_transform_buf, param::MatrixMul::Format::MK8>::filter(filter, filter_transform_buf,
transform_mid_buf, OC, IC, transform_mid_buf, OC, IC,
oc_start, oc_end, oc_start, oc_end,
...@@ -158,11 +165,15 @@ void winograd_2x3_8x8_qs8::input(const int8_t* input, ...@@ -158,11 +165,15 @@ void winograd_2x3_8x8_qs8::input(const int8_t* input,
int16_t* transform_mid_buf, int ih_start, int16_t* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW, size_t IC, int iw_start, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) { size_t unit_idx, size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int, ::megdnn::winograd::StrategyHelper<
param::MatrixMul::Format::MK8>:: int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
input(input, input_transform_buf, transform_mid_buf, ih_start, param::MatrixMul::Format::MK8>::input(input, input_transform_buf,
iw_start, IH, IW, IC, unit_idx, nr_units_in_tile, transform_mid_buf, ih_start,
OUTPUT_BLOCK_SIZE, KERNEL_SIZE, {0, 1, -1}, src_dtype, 1.0f); iw_start, IH, IW, IC,
unit_idx, nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
src_dtype, 1.0f);
} }
void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, void winograd_2x3_8x8_qs8::output(const int* output_transform_buf,
...@@ -180,13 +191,19 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, ...@@ -180,13 +191,19 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf,
megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16);
scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale;
} }
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int, ::megdnn::winograd::StrategyHelper<
param::MatrixMul::Format::MK8>:: int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
output(output_transform_buf, bias, output, transform_mid_buf, bmode, param::MatrixMul::Format::MK8>::output(output_transform_buf, bias,
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, output, transform_mid_buf,
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, bmode, nonline_mode,
{0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, oh_start, ow_start, OH, OW,
1.0f); oc_start, oc_end, unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
dst_dtype,
scale_input * scale_filter,
2.0f, 1.0f);
} }
} // namespace winograd } // namespace winograd
......
...@@ -28,6 +28,7 @@ MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 1, 1, ...@@ -28,6 +28,7 @@ MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 1, 1,
MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 8, 8, MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 8, 8,
winograd_2x3_8x8_qs8) winograd_2x3_8x8_qs8)
} }
} // namespace fallback } // namespace fallback
} // namespace megdnn } // namespace megdnn
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/naive/winograd_filter_preprocess/opr_impl.h" #include "src/naive/winograd_filter_preprocess/opr_impl.h"
...@@ -49,17 +50,16 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, ...@@ -49,17 +50,16 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
size_t m = param().output_block_size; size_t m = param().output_block_size;
bool execed = false; bool execed = false;
#define cb(_ctype, _dst_type, _input_filter_compute_type, \
_output_compute_type, _format, rescale) \ #define cb(_ctype, _dst_type, _input_filter_compute_type, \
if (param().format == _format) { \ _output_compute_type, _format, rescale) \
return winograd::StrategyHelper< \ if (param().format == _format) { \
_ctype, _dst_type, _input_filter_compute_type, \ return winograd::StrategyHelper< \
_output_compute_type, _format>::filter(src_ptr, dst_ptr, \ _ctype, _dst_type, _input_filter_compute_type, \
workspace_ptr, OC, IC, \ _output_compute_type, param::ConvBias::Format::NCHW, \
0, OC, m, FW, \ _format>::filter(src_ptr, dst_ptr, workspace_ptr, OC, IC, 0, \
interp_points, \ OC, m, FW, interp_points, src.layout.dtype, \
src.layout.dtype, \ rescale); \
rescale); \
} }
#define DISPATCH_FORMAT_MK4(_ctype, _dst_type, _input_filter_compute_type, \ #define DISPATCH_FORMAT_MK4(_ctype, _dst_type, _input_filter_compute_type, \
...@@ -110,8 +110,9 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, ...@@ -110,8 +110,9 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
DISPATCH_KERNEL(dt_float16, dt_float16, dt_float16, dt_float16, \ DISPATCH_KERNEL(dt_float16, dt_float16, dt_float16, dt_float16, \
DISPATCH_FORMAT_MK8, 1.0f, _midout_tag, 2); \ DISPATCH_FORMAT_MK8, 1.0f, _midout_tag, 2); \
}) })
//! normal nchw mode
if (src.layout.ndim <= 5) { if (src.layout.ndim <= 5) {
//! dispatch_dtype takes both layout and format into consideration.
if (FW == 3) { if (FW == 3) {
if (m == 2) { if (m == 2) {
std::vector<float> interp_points = {0, 1, -1}; std::vector<float> interp_points = {0, 1, -1};
...@@ -131,22 +132,20 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, ...@@ -131,22 +132,20 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
DISPATCH_DTYPE(3); DISPATCH_DTYPE(3);
} }
} }
}
#undef cb #undef cb
#undef DISPATCH_FORMAT_MK4 #undef DISPATCH_FORMAT_MK4
#undef DISPATCH_FORMAT_MK8 #undef DISPATCH_FORMAT_MK8
#undef DISPATCH_DTYPE #undef DISPATCH_DTYPE
#define cb(_ctype, _dst_type, _input_filter_compute_type, \ } else {
_output_compute_type, _format, rescale) \ #define cb(_ctype, _dst_type, _input_filter_compute_type, \
if (param().format == _format) { \ _output_compute_type, _format, rescale) \
return winograd::StrategyHelperNchwxx< \ if (param().format == _format) { \
_ctype, _dst_type, _input_filter_compute_type, \ return winograd::StrategyHelper< \
_output_compute_type, _format>::filter(src_ptr, dst_ptr, \ _ctype, _dst_type, _input_filter_compute_type, \
workspace_ptr, OC, IC, \ _output_compute_type, param::ConvBias::Format::NCHW88, \
0, OC, m, FW, \ _format>::filter(src_ptr, dst_ptr, workspace_ptr, OC, IC, 0, \
interp_points, \ OC, m, FW, interp_points, src.layout.dtype, \
src.layout.dtype, \ rescale); \
rescale); \
} }
#define DISPATCH_FORMAT_MK8(_ctype, _dst_type, _input_filter_compute_type, \ #define DISPATCH_FORMAT_MK8(_ctype, _dst_type, _input_filter_compute_type, \
...@@ -159,8 +158,6 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, ...@@ -159,8 +158,6 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
DISPATCH_KERNEL(dt_float32, dt_float32, dt_float32, dt_float32, \ DISPATCH_KERNEL(dt_float32, dt_float32, dt_float32, dt_float32, \
DISPATCH_FORMAT_MK8, 1.0f, _midout_tag, 0); \ DISPATCH_FORMAT_MK8, 1.0f, _midout_tag, 0); \
} }
//! nchwxx mode
else {
megdnn_assert(src.layout.ndim == 6 || src.layout.ndim == 7); megdnn_assert(src.layout.ndim == 6 || src.layout.ndim == 7);
if (FW == 3) { if (FW == 3) {
if (m == 2) { if (m == 2) {
...@@ -171,11 +168,11 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, ...@@ -171,11 +168,11 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
DISPATCH_DTYPE(5); DISPATCH_DTYPE(5);
} }
} }
}
#undef cb #undef cb
#undef DISPATCH_FORMAT_MK8 #undef DISPATCH_FORMAT_MK8
#undef DISPATCH_KERNEL #undef DISPATCH_KERNEL
#undef DISPATCH_DTYPE #undef DISPATCH_DTYPE
}
megdnn_assert(execed, megdnn_assert(execed,
"Unsupport winograd filter preprocess. m: %zu src: %s", m, "Unsupport winograd filter preprocess. m: %zu src: %s", m,
src.layout.to_string().c_str()); src.layout.to_string().c_str());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册