提交 3bd699fd 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(opr): fix int8 winograd preprocess output dtype mismatch

GitOrigin-RevId: ede80d5a459ce5f82cf7e7648873e6428791568f
上级 2ae9fdef
...@@ -1745,6 +1745,62 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { ...@@ -1745,6 +1745,62 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
} }
#endif #endif
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
auto&& args = get_winograd_benchmark_args(3, 8);
using namespace conv_bias;
constexpr size_t RUN = 10;
Benchmarker<ConvBias> benchmark_im2col(handle());
benchmark_im2col.set_display(false);
benchmark_im2col.set_times(RUN);
benchmark_im2col.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, dtype::QuantizedS8(60.25f));
Benchmarker<ConvBias> benchmark_winograd(handle());
benchmark_winograd.set_display(false);
benchmark_winograd.set_times(RUN);
benchmark_winograd.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, dtype::QuantizedS8(60.25f));
for (auto&& arg : args) {
TensorLayout dst_layout;
auto opr = handle()->create_operator<ConvBias>();
opr->param() = arg.param;
opr->deduce_layout({arg.src, dtype::Float32()},
{arg.filter, dtype::Float32()},
{arg.bias, dtype::Float32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * 2.0 /
(1024 * 1024 * 1024) * 1e3;
benchmark_im2col.set_param(arg.param);
auto im2col_used =
algo_benchmark<ConvBias>(
benchmark_im2col, {arg.src, arg.filter, {}, {}, {}},
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16") /
RUN;
benchmark_winograd.set_param(arg.param);
auto winograd_used =
algo_benchmark<ConvBias>(
benchmark_winograd, {arg.src, arg.filter, {}, {}, {}},
"WINOGRAD:AARCH64_INT16X16X32_MK8_8X8:8:2") /
RUN;
printf("%s %s: im2col: %f ms %f Gflops winograd: %f ms %f GFlops "
"speedup: "
"%f\n",
arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
im2col_used, computations / im2col_used, winograd_used,
computations / winograd_used, im2col_used / winograd_used);
}
}
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -736,6 +736,12 @@ std::vector<conv_bias::TestArg> get_winograd_benchmark_args(size_t kernel, ...@@ -736,6 +736,12 @@ std::vector<conv_bias::TestArg> get_winograd_benchmark_args(size_t kernel,
pack(64, 64, 123, 123, kernel, kernel / 2); pack(64, 64, 123, 123, kernel, kernel / 2);
pack(64, 24, 123, 123, kernel, kernel / 2); pack(64, 24, 123, 123, kernel, kernel / 2);
pack(24, 24, 224, 224, kernel, kernel / 2); pack(24, 24, 224, 224, kernel, kernel / 2);
//! conv in resnet18
pack(64, 64, 56, 56, kernel, kernel / 2);
pack(128, 128, 28, 28, kernel, kernel / 2);
pack(256, 256, 14, 14, kernel, kernel / 2);
pack(512, 512, 7, 7, kernel, kernel / 2);
return args; return args;
} }
......
...@@ -309,6 +309,7 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( ...@@ -309,6 +309,7 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
return _dt(1.0f) return _dt(1.0f)
cb(dtype::QuantizedS8); cb(dtype::QuantizedS8);
cb(dtype::QuantizedS16);
cb(dtype::QuantizedS32); cb(dtype::QuantizedS32);
default: default:
return DType::from_enum(enumv); return DType::from_enum(enumv);
......
...@@ -1549,6 +1549,12 @@ void RelayoutFormat::init_output_format() { ...@@ -1549,6 +1549,12 @@ void RelayoutFormat::init_output_format() {
/* f{{{ ===================== WinogradFilterPreprocess ===================== */ /* f{{{ ===================== WinogradFilterPreprocess ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WinogradFilterPreprocess); MGB_DYN_TYPE_OBJ_FINAL_IMPL(WinogradFilterPreprocess);
MEGDNN_OPR_INIT1(WinogradFilterPreprocess, "winograd_filter_preprocess") MEGDNN_OPR_INIT1(WinogradFilterPreprocess, "winograd_filter_preprocess")
void WinogradFilterPreprocess::init_output_dtype() {
TensorLayout dst;
TensorLayout src{input(0)->shape(), input(0)->dtype(), input(0)->format()};
megdnn_opr()->deduce_layout(src, dst);
output(0)->dtype(dst.dtype);
}
// f}}} // f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -637,7 +637,16 @@ MGB_DEFINE_OPR_CLASS(RelayoutFormat, ...@@ -637,7 +637,16 @@ MGB_DEFINE_OPR_CLASS(RelayoutFormat,
* *
* See docs of megdnn params for more details * See docs of megdnn params for more details
*/ */
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(WinogradFilterPreprocess); MGB_DEFINE_OPR_CLASS(WinogradFilterPreprocess,
intl::MegDNNOprWrapperFwd<megdnn::WinogradFilterPreprocess>)
public:
WinogradFilterPreprocess(VarNode* p0, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar p0, const Param& param = {},
const OperatorNodeConfig& config = {});
void init_output_dtype() override final;
};
} // opr } // opr
} // mgb } // mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册