Commit c6eb2e8d authored by Megvii Engine Team

refactor(dnn): optimize winograd input transpose

GitOrigin-RevId: a43077550c0e729063b0214be7f71b39ec89d710
Parent: f077a529
......@@ -247,33 +247,31 @@ void StrategyHelper<
Getter<ctype, input_filter_compute_type> getter(dtype);
InputVisitor<layout, format> intput_visitor(IC);
rep(ic, IC) {
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type));
rep(i, alpha) rep(j, alpha) {
int ih = ih_start + i;
int iw = iw_start + j;
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) {
mid_buf1[i * alpha + j] = getter(
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]);
}
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type));
rep(i, alpha) rep(j, alpha) {
int ih = ih_start + i;
int iw = iw_start + j;
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) {
mid_buf1[i * alpha + j] = getter(
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]);
}
}
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, true,
false>(
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, false,
false>(
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);
rep(i, alpha) rep(j, alpha) {
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile,
unit_idx, i, j)] =
mid_buf1[i * alpha + j];
}
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, true,
false>(
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, false,
false>(
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);
rep(i, alpha) rep(j, alpha) {
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile,
unit_idx, i, j)] =
mid_buf1[i * alpha + j];
}
}
......@@ -287,7 +285,7 @@ void StrategyHelper<
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float input_filter_scale, float input_filter_rescale,
......@@ -300,49 +298,49 @@ void StrategyHelper<
OutputGetter<output_compute_type, dst_type> getter(dtype);
OutputVisitor<layout, format> output_visitor(oc_end - oc_start);
for (size_t oc = oc_start; oc < oc_end; oc++) {
/* gather */
rep(i, alpha) rep(j, alpha) {
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get(
alpha, oc - oc_start, oc, nr_units_in_tile, unit_idx, i,
j)];
}
/* A[alpha*m] M[alpha*alpha] */
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, true, false>(
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha,
alpha, m, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, false, false>(
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m,
alpha, alpha, m, m, dtype, dtype);
rep(i, m) rep(j, m) {
auto oh = oh_start + i;
auto ow = ow_start + j;
if (oh < OH && ow < OW) {
float val = mid_buf1[i * m + j];
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
val += bias[oc] * input_filter_rescale *
input_filter_rescale;
} else if (bmode == BiasMode::BIAS) {
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] *
input_filter_rescale * input_filter_rescale;
}
val = val * input_filter_scale /
(input_filter_rescale * input_filter_rescale * rescale *
rescale);
if (nonline_mode == NonlineMode::RELU) {
val = val > 0 ? val : 0;
} else if (nonline_mode == NonlineMode::SIGMOID) {
val = 1.f / (expf(-val) + 1.f);
} else if (nonline_mode == NonlineMode::H_SWISH) {
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f;
} else {
megdnn_assert(nonline_mode == NonlineMode::IDENTITY);
}
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val);
size_t oc = oc_start + oc_index;
/* gather */
rep(i, alpha) rep(j, alpha) {
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get(
alpha, oc_index, oc, nr_units_in_tile, unit_idx, i,
j)];
}
/* A[alpha*m] M[alpha*alpha] */
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, true, false>(
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha,
alpha, m, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, false, false>(
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m,
alpha, alpha, m, m, dtype, dtype);
rep(i, m) rep(j, m) {
auto oh = oh_start + i;
auto ow = ow_start + j;
if (oh < OH && ow < OW) {
float val = mid_buf1[i * m + j];
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
val += bias[oc] * input_filter_rescale *
input_filter_rescale;
} else if (bmode == BiasMode::BIAS) {
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] *
input_filter_rescale * input_filter_rescale;
}
val = val * input_filter_scale /
(input_filter_rescale * input_filter_rescale * rescale *
rescale);
if (nonline_mode == NonlineMode::RELU) {
val = val > 0 ? val : 0;
} else if (nonline_mode == NonlineMode::SIGMOID) {
val = 1.f / (expf(-val) + 1.f);
} else if (nonline_mode == NonlineMode::H_SWISH) {
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f;
} else {
megdnn_assert(nonline_mode == NonlineMode::IDENTITY);
}
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val);
}
}
};
......
......@@ -44,7 +44,7 @@ public:
input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f);
......@@ -54,7 +54,7 @@ public:
const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t OH, size_t OW, size_t oc_start, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale
......
......@@ -55,7 +55,7 @@ public:
ohw_tile_size));
all_algos.emplace_back(refhold.back().get());
}
#if 0
#if 1
//! As these algos may be very slow, they would make fastrun search slow;
//! we normally disable them, but keep them enabled here for the
//! StrategyHelper tests.
//! FIXME: I do not know a better way to do it.
......
......@@ -321,17 +321,10 @@ public:
"nr_tiles_in_unit: %zu TILE_SIZE:%zu",
nr_tiles_in_unit, unit_tile_size);
}
rep(unit_idx, nr_tiles_in_unit) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * Strategy::OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * Strategy::OUTPUT_BLOCK_SIZE - PW;
strategy.input(src_ptr, input_transform_buf, transform_mid_buf,
ih_start, iw_start, IH, IW, IC, unit_idx,
nr_tiles_in_unit);
}
//! BTdB
strategy.input(src_ptr, input_transform_buf, transform_mid_buf,
IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit);
rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) {
if (format == param::MatrixMul::Format::DEFAULT) {
matmul_param.A_ptr =
......@@ -368,22 +361,14 @@ public:
}
matmul_kern(matmul_param);
}
/* Y = ATmA */
rep(unit_idx, nr_tiles_in_unit) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * Strategy::OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * Strategy::OUTPUT_BLOCK_SIZE;
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit;
strategy.output(
output_transform_buf, bias_ptr, dst_ptr,
reinterpret_cast<output_compute_type*>(transform_mid_buf),
ncb_param.bias_mode, ncb_param.nonlineMode, oh_start,
ow_start, OH, OW, oc_start_idx, oc_end_idx, unit_idx,
nr_tiles_in_unit);
}
//! Y = ATmA
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit;
strategy.output(
output_transform_buf, bias_ptr, dst_ptr,
reinterpret_cast<output_compute_type*>(transform_mid_buf),
ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW,
oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit);
};
SmallVector<NCBKern> get_kerns(
......@@ -542,15 +527,16 @@ public:
size_t IC, size_t oc_start, size_t oc_end); \
void input(const stype* input, \
input_filter_compute_type* input_transform_buf, \
input_filter_compute_type* transform_mid_buf, int ih_start, \
int iw_start, size_t IH, size_t IW, size_t IC, \
size_t unit_idx, size_t nr_tiles_in_unit); \
input_filter_compute_type* transform_mid_buf, \
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \
size_t unit_start_idx, size_t nr_tiles_in_unit); \
void output(const output_compute_type* output_transform_buf, \
const output_compute_type* bias, dst_type* output, \
output_compute_type* transform_mid_buf, BiasMode bmode, \
NonlineMode nonline_mode, size_t oh_start, \
size_t ow_start, size_t OH, size_t OW, size_t oc_start, \
size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
};
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
......
......@@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter,
transform_mid_buf, OC, IC, oc_start,
oc_end);
}
void winograd_nchw88_2x3_8x8_f::input(const float* input,
float* input_transform_buf,
float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx,
float* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH,
size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) {
megdnn_assert(IC % 8 == 0);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform2X3_NCHW88::prepare<true>(
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
}
} else {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, ih_start,
iw_start, IH, IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
for (size_t ic = 0; ic < IC; ic += 8) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
InputTransform2X3_NCHW88::prepare<true>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
} else {
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
}
}
}
}
......
......@@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter,
transform_mid_buf, OC, IC, oc_start,
oc_end);
}
void winograd_nchw88_6x3_8x8_f::input(const float* input,
float* input_transform_buf,
float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx,
float* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH,
size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) {
megdnn_assert(IC % 8 == 0);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform6X3_NCHW88::prepare<true>(
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
}
} else {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, ih_start,
iw_start, IH, IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
for (size_t ic = 0; ic < IC; ic += 8) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
InputTransform6X3_NCHW88::prepare<true>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
} else {
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
}
}
}
}
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.