提交 bba04f02 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mgb/gopt): add fusion support for conv, astype(s4) and reformat

GitOrigin-RevId: 6329ca2c5fe85dc4f7c94610a293ab506a3c2945
上级 66f70578
......@@ -3550,6 +3550,35 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
return y2.node();
};
auto nchw42nhwc = [](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp);
return y1.node();
};
auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
&nchw42nchw](
OperatorNodeBase* opr) {
......@@ -3721,6 +3750,106 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
return true;
};
auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc, &nhwc2nchw64,
&readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape1 =
try_cast_as_op<opr::Reshape>(opr);
if (reshape1 == nullptr)
return false;
opr_set.insert(opr);
// check dimshuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(
reshape1->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 6)
return false;
bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 3 && param.pattern[3] == 4 &&
param.pattern[4] == 2 && param.pattern[5] == 5 &&
shuffle->output(0)->shape()[5] == 4 &&
shuffle->output(0)->shape()[4] == 16;
if (!is_nchw42nchw64)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check reshape
auto reshape2 =
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
if (reshape2 == nullptr)
return false;
opr_set.insert(reshape2);
for (auto&& i : readers[reshape2]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
auto typecvt =
try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr());
if (typecvt == nullptr)
return false;
auto in_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
printf("%s, %s\n", in_dtype.name(), out_dtype.name());
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
(out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
if (!is_s82s4)
return false;
opr_set.insert(typecvt);
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = nchw42nhwc(bias);
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype});
auto new_var = nhwc2nchw64(conv_bias_shuffle.node());
rewriter.replace_var(
opr->output(0), new_var,
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NCHW64)"));
return true;
};
auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4](
OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
......@@ -3805,12 +3934,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
&try_conv_reformat_nchw42nchw32,
&try_conv_reformat_nchw42nchw64,
#if CUDA_VERSION >= 10020
&try_conv_reformat_nchw322nchw4,
#endif
&rewriter](OperatorNodeBase* opr) {
if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
!try_conv_reformat_nchw42nchw32(opr)
!try_conv_reformat_nchw42nchw32(opr) &&
!try_conv_reformat_nchw42nchw64(opr)
#if CUDA_VERSION >= 10020
&& !try_conv_reformat_nchw322nchw4(opr)
#endif
......
......@@ -4400,6 +4400,93 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) {
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
}
TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NHWC) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
cn.activate();
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
auto sm_ver = prop.major * 10 + prop.minor;
if (sm_ver < 75) {
printf("This testcast ignored due to insufficient cuda cap(got: %d, "
"expected: %d)\n",
sm_ver, 75);
return;
}
HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto x = mkvar("x", {32, 4, 23, 40}, dtype::QuantizedS8(2.5f)),
w = mkcvar("w", {64, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b = mkcvar("b", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)),
w1 = mkcvar("w1", {64, 64, 3, 3}, dtype::QuantizedS4(1.234f)),
b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32(12.34567f*1.234f));
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 1;
auto y = opr::ConvBias::make(
x, w, b, param, {},
OperatorNodeConfig{dtype::QuantizedS8(12.34567f)});
y = opr::TypeCvt::make(y, dtype::QuantizedS4(12.34567f));
y = opr::ConvBias::make(y, w1, b1, param, {},
OperatorNodeConfig{dtype::QuantizedS4(56.71234f)});
y = opr::TypeCvt::make(y, dtype::Float32());
SymbolVar y_fuse, y_non_fuse;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw64();
unpack_vector(gopt::optimize_for_inference({y}, options), y_fuse);
}
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::PROFILE;
gopt::modify_opr_algo_strategy_inplace({y_fuse}, strategy);
HostTensorND host_y_fuse;
auto func1 = graph->compile({make_callback_copy(y_fuse, host_y_fuse)});
func1->execute();
graph->compile({{y_fuse, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.FoldingConvDimshuffleNCHW4NHWC.json"));
size_t nr_dimshuffle = find_opr_num<opr::TypeCvt>(y_fuse);
printf("%zu \n", nr_dimshuffle);
ASSERT_EQ(3u, find_opr_num<opr::Dimshuffle>(y_fuse));
bool found = false;
cg::DepOprIter{[&found](cg::OperatorNodeBase* opr) {
if (!found && opr->same_type<opr::ConvBias>()) {
opr::ConvBias* cb = &opr->cast_final_safe<opr::ConvBias>();
if (cb->param().format == opr::ConvBias::Param::Format::NCHW4_NHWC)
found = true;
}
}}
.add(y_fuse.node()->owner_opr());
EXPECT_TRUE(found);
unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(),
y_non_fuse);
gopt::modify_opr_algo_strategy_inplace({y_non_fuse}, strategy);
HostTensorND host_y_non_fuse;
auto func2 =
graph->compile({make_callback_copy(y_non_fuse, host_y_non_fuse)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
}
#endif
TEST(TestGoptInference, PaddingChannels) {
......
......@@ -864,7 +864,13 @@ void ConvBiasForward::init_output_static_infer_desc() {
void ConvBiasForward::init_output_format() {
mgb_assert(output().size() == 2);
output(0)->format(input(0)->format());
auto format = input(0)->format();
if (!format.is_default() && !format.is_lowbit_aligned()) { // propagate
output(0)->format(input(0)->format());
} else {
mgb_assert(output(0)->dtype().valid());
output(0)->format(TensorFormat(output(0)->dtype()));
}
}
void ConvBiasForward::check_winograd_param_valid(
......
......@@ -147,9 +147,11 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
packed_size = 32;
} else {
mgb_assert(param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NCHW32,
"format should be NCHW4/NCHW4_NCHW/NCHW4_NCHW32");
"format should be "
"NCHW4/NCHW4_NCHW/NCHW4_NHWC/NCHW4_NCHW32");
packed_size = 4;
}
return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * packed_size / group *
......@@ -174,6 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
};
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册