提交 5e8aa333 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor winograd output transpose

GitOrigin-RevId: 6d4b225ea54a14c6c5479788b1d2b42a5b9d3cf5
上级 c6eb2e8d
...@@ -235,7 +235,7 @@ void StrategyHelper< ...@@ -235,7 +235,7 @@ void StrategyHelper<
input_filter_compute_type* input_transform_buf, input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf, input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW, int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile, size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float rescale) { float rescale) {
...@@ -284,7 +284,7 @@ void StrategyHelper< ...@@ -284,7 +284,7 @@ void StrategyHelper<
const output_compute_type* bias, dst_type* output, const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode, output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
...@@ -296,7 +296,7 @@ void StrategyHelper< ...@@ -296,7 +296,7 @@ void StrategyHelper<
output_compute_type* mid_buf1 = transform_mid_buf; output_compute_type* mid_buf1 = transform_mid_buf;
output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
OutputGetter<output_compute_type, dst_type> getter(dtype); OutputGetter<output_compute_type, dst_type> getter(dtype);
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); OutputVisitor<layout, format> output_visitor(OC);
size_t oc = oc_start + oc_index; size_t oc = oc_start + oc_index;
......
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* implied.
*/ */
#pragma once #pragma once
...@@ -44,8 +43,8 @@ public: ...@@ -44,8 +43,8 @@ public:
input_filter_compute_type* input_transform_buf, input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf, input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW, int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, size_t IC, size_t ic, size_t unit_idx,
size_t m, size_t r, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f); float rescale = 1.0f);
...@@ -54,7 +53,7 @@ public: ...@@ -54,7 +53,7 @@ public:
const output_compute_type* bias, dst_type* output, const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode, output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_index, size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale float input_filter_scale = 1.0f, // input_scale * filter_scale
......
...@@ -45,7 +45,6 @@ public: ...@@ -45,7 +45,6 @@ public:
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack();
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->algoset() == if (algo->algoset() ==
//! TODO: threre should filter MK matmul
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
continue; continue;
} }
......
...@@ -536,7 +536,6 @@ public: ...@@ -536,7 +536,6 @@ public:
NonlineMode nonline_mode, size_t OH, size_t OW, \ NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \ size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \ size_t nr_tiles_in_unit); \
}; };
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
......
...@@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 { ...@@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 {
float* output, float* transform_mid_buf, float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH, size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile, size_t oc_index, size_t unit_idx,
const DType& src_dtype, const DType& dst_dtype) { size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf); MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype); Op op(src_dtype, dst_dtype);
//! AT * m * A //! AT * m * A
size_t OCB = (oc_end - oc_start) / 8; size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { size_t oc = oc_start + oc_index;
size_t ocb = (oc - oc_start) / 8; size_t ocb = oc_index / 8;
#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \ auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \ output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ (m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8); ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb #undef cb
//! 1 1 1 0 v00 v01 v02 v03 1 0 //! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1 //! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1 //! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1 //! v30 v31 v32 v33 0 1
#define cb(m) \ #define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \ auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m; auto t1##m = v1##m - v2##m + v3##m;
UNROLL_CALL_NOWRAPPER(4, cb); UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb #undef cb
#define cb(m) \ #define cb(m) \
v##m##0 = t##m##0 + t##m##1 + t##m##2; \ v##m##0 = t##m##0 + t##m##1 + t##m##2; \
v##m##1 = t##m##1 - t##m##2 + t##m##3; v##m##1 = t##m##1 - t##m##2 + t##m##3;
UNROLL_CALL_NOWRAPPER(2, cb); UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb #undef cb
Vector<float, 8> vbias; Vector<float, 8> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 8>::load(bias + oc); vbias = Vector<float, 8>::load(bias + oc);
#define cb(m, n) v##m##n += vbias; #define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2(2, 2, cb); UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb #undef cb
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2(2, 2, cb); UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb #undef cb
} }
#define out_save(oho, owo) \ #define out_save(oho, owo) \
do { \ do { \
size_t oh = oh_start + oho; \ size_t oh = oh_start + oho; \
...@@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 { ...@@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 {
ow * 8); \ ow * 8); \
} \ } \
} while (0); } while (0);
UNROLL_CALL_RAW_D2(2, 2, out_save); UNROLL_CALL_RAW_D2(2, 2, out_save);
}
} }
}; };
#undef CONCAT #undef CONCAT
...@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, ...@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
} }
} }
void winograd_nchw88_2x3_8x8_f::output( void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf,
const float* output_transform_buf, const float* bias, float* output, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, float* transform_mid_buf, BiasMode bmode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, NonlineMode nonline_mode, size_t OH,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \ #define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__); __VA_ARGS__);
DISPATCH_CONV_WINOGRAD_BIAS( auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, size_t OC = oc_end - oc_start;
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); "Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2,
float, float, bmode, nonline_mode, output_transform_buf,
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype,
dst_dtype);
}
}
#undef cb #undef cb
} }
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
...@@ -19,10 +20,10 @@ ...@@ -19,10 +20,10 @@
#include <x86intrin.h> #include <x86intrin.h>
#ifdef WIN32CMAKE #ifdef WIN32CMAKE
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h> #include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h> #include <fmaintrin.h>
#include <smmintrin.h>
#endif #endif
#include "midout.h" #include "midout.h"
...@@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 { ...@@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 {
int ih_start, int iw_start, size_t IH, size_t IW, int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) { size_t ic, size_t IC) {
MEGDNN_MARK_USED_VAR(patch); MEGDNN_MARK_USED_VAR(patch);
size_t IW8 = IW * 8; //! For nchw88 mode size_t IW8 = IW * 8; //! For nchw88 mode
size_t iw8_start = iw_start * 8; //! For nchw88 mode size_t iw8_start = iw_start * 8; //! For nchw88 mode
size_t icb = ic / 8; size_t icb = ic / 8;
if (!(inner && ic + 8 < IC)) { if (!(inner && ic + 8 < IC)) {
...@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { ...@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) {
for (size_t icb = 0; icb < ICB; icb++) { for (size_t icb = 0; icb < ICB; icb++) {
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) {
const float* fptr = filter + const float* fptr = filter +
(ocb * ICB + icb) * 3 * 3 * 8 * 8 + (ocb * ICB + icb) * 3 * 3 * 8 * 8 +
ic_inner * 8; ic_inner * 8;
...@@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 { ...@@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 {
float* output, float* transform_mid_buf, float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH, size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile, size_t oc_index, size_t unit_idx,
const DType& src_dtype, const DType& dst_dtype) { size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf); MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype); Op op(src_dtype, dst_dtype);
//! AT * m * A //! AT * m * A
size_t OCB = (oc_end - oc_start) / 8; size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { size_t oc = oc_start + oc_index;
size_t ocb = (oc - oc_start) / 8; size_t ocb = oc_index / 8;
#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \ auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \ output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ (m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8); ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb #undef cb
/** /**
* A * A
* *
* 1 0 0 0 0 0 * 1 0 0 0 0 0
* 1 1 1 1 1 1 * 1 1 1 1 1 1
* 1 -1 1 -1 1 -1 * 1 -1 1 -1 1 -1
* 1 2 4 8 16 32 * 1 2 4 8 16 32
* 1 -2 4 -8 16 -32 * 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125 * 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 * 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1 * 0 0.0 0 0 0 1
*/ */
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
v5subv6;
#define cb(m) \ #define cb(m) \
v1addv2 = v1##m + v2##m; \ v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \ v1subv2 = v1##m - v2##m; \
...@@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 { ...@@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 {
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
UNROLL_CALL_NOWRAPPER(8, cb); UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb #undef cb
#define cb(m) \ #define cb(m) \
...@@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 { ...@@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 {
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
UNROLL_CALL_NOWRAPPER(6, cb); UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb #undef cb
Vector<float, 8> vbias; Vector<float, 8> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 8>::load(bias + oc); vbias = Vector<float, 8>::load(bias + oc);
#define cb(m, n) v##m##n += vbias; #define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb
} }
if (bmode != BiasMode::BIAS) { if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2(6, 6, cb); UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb #undef cb
} }
#define out_save(oho, owo) \ #define out_save(oho, owo) \
do { \ do { \
size_t oh = oh_start + oho; \ size_t oh = oh_start + oho; \
...@@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 { ...@@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 {
ow * 8); \ ow * 8); \
} \ } \
} while (0); } while (0);
UNROLL_CALL_RAW_D2(6, 6, out_save); UNROLL_CALL_RAW_D2(6, 6, out_save);
}
} }
}; };
#undef CONCAT #undef CONCAT
...@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, ...@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
megdnn_assert(IC % 8 == 0); megdnn_assert(IC % 8 == 0);
// OW = IW + 2 * PW - KERNEL_SIZE + 1 // OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf; float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha; float* patchT = transform_mid_buf + 8 * alpha * alpha;
...@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, ...@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
} }
} }
void winograd_nchw88_6x3_8x8_f::output( void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf,
const float* output_transform_buf, const float* bias, float* output, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, float* transform_mid_buf, BiasMode bmode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, NonlineMode nonline_mode, size_t OH,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \ #define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__); __VA_ARGS__);
DISPATCH_CONV_WINOGRAD_BIAS( auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, size_t OC = oc_end - oc_start;
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); "Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2,
float, float, bmode, nonline_mode, output_transform_buf,
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile,
src_dtype, dst_dtype);
}
}
#undef cb #undef cb
} }
} // namespace winograd } // namespace winograd
} // namespace arm_common } // namespace x86
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册