Commit 0fb9cc41 authored by Megvii Engine Team

fix(gopt): fix nchw64 opt pass

GitOrigin-RevId: dec18d1ab1b7bd0723395e490c215356f178e44a
Parent e661ae90
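
The core change generalizes SubGraphMatcher::Node: pre_node becomes a std::vector<std::vector<Node>>, i.e. a list of alternative input patterns rather than a single one, and match() now succeeds when the op-type and callback checks pass and any one alternative matches all of the corresponding inputs. The sketch below is a minimal stand-alone restatement of that matching rule using simplified stand-in types (Opr, PatternNode and matches are illustrative names, not the real OperatorNodeBase / SGM::Node classes); it is an interpretation of the new logic, not code from the commit.

#include <cstddef>
#include <functional>
#include <vector>

// Simplified stand-ins for OperatorNodeBase and SGM::Node (illustration only).
struct Opr {
    std::vector<const Opr*> inputs;
};

struct PatternNode {
    std::function<bool(const Opr*)> accept;          // plays the role of Node::cbk
    std::vector<std::vector<PatternNode>> pre_node;  // alternative input patterns
};

// A sub-pattern matches only if every position matches (AND within a pattern);
// the node matches if any one alternative sub-pattern matches (OR across patterns).
bool matches(const PatternNode& root, const Opr* opr) {
    if (root.accept && !root.accept(opr))
        return false;
    if (root.pre_node.empty())
        return true;  // no input constraints, matches unconditionally
    for (const auto& alternative : root.pre_node) {
        if (alternative.size() > opr->inputs.size())
            continue;  // this alternative needs more inputs than opr has
        bool all_ok = true;
        for (std::size_t i = 0; i < alternative.size() && all_ok; ++i)
            all_ok = matches(alternative[i], opr->inputs[i]);
        if (all_ok)
            return true;
    }
    return false;
}

With this in place, the preprocess fusion below can declare that the final TypeCvt takes its input either directly from a Quantized8Asymm tensor or from the uint8 -> float32 -> (x + (-128)) chain matched by gen_u8_cvt2_q8.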
@@ -36,15 +36,23 @@ struct SubGraphMatcher {
Node(Typeinfo* in_op_type) : op_type(in_op_type){};
Node(Typeinfo* in_op_type, CallBack func)
: op_type(in_op_type), cbk(func){};
Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node)
Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node)
: op_type(in_op_type), pre_node(in_pre_node){};
Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node, CallBack func)
Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
CallBack func)
: op_type(in_op_type), pre_node(in_pre_node), cbk(func){};
Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
CallBack func, std::string in_msg)
: op_type(in_op_type),
pre_node(in_pre_node),
cbk(func),
msg(in_msg){};
Typeinfo* op_type{nullptr};
std::vector<Node> pre_node;
std::vector<std::vector<Node>> pre_node;
//! cbk used to check param and gather args for creating fusion op
CallBack cbk;
std::string msg{""};
};
bool match(Node& root, OperatorNodeBase* opr) {
@@ -53,20 +61,34 @@ struct SubGraphMatcher {
}
//! match nullptr node always
if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) {
bool match_ok = true;
bool current_match = true;
if (root.cbk)
match_ok &= root.cbk(opr);
RETURN_IF_FALSE(match_ok);
current_match &= root.cbk(opr);
RETURN_IF_FALSE(current_match);
auto& inp = opr->input();
for (size_t node_idx = 0; node_idx < root.pre_node.size();
++node_idx) {
bool valid_node_idx = node_idx < inp.size();
RETURN_IF_FALSE(valid_node_idx);
match_ok &= match(root.pre_node[node_idx],
inp[node_idx]->owner_opr());
RETURN_IF_FALSE(match_ok);
bool any_sub_patten_match =
root.pre_node.size() == 0 ? true : false;
for (auto& sub_patten : root.pre_node) {
bool patten_ok = true;
for (size_t node_idx = 0; node_idx < sub_patten.size();
++node_idx) {
bool valid_node_idx = node_idx < inp.size();
if (!valid_node_idx) {
patten_ok = false;
break;
}
patten_ok = patten_ok && match(sub_patten[node_idx],
inp[node_idx]->owner_opr());
if (!patten_ok) {
break;
}
}
any_sub_patten_match = any_sub_patten_match || patten_ok;
if (any_sub_patten_match) {
break;
}
}
return match_ok;
return current_match && any_sub_patten_match;
} else {
return false;
}
@@ -237,24 +259,26 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
return false;
}
};
SGM::Node broadcast_or_immutable{nullptr, check_pad};
SGM::Node broadcast_or_immutable{
nullptr, {}, check_pad, "broadcast_or_immutable"};
SGM::Node broadcast_concat{
opr::Concat::typeinfo(),
{in_node, broadcast_or_immutable},
{{in_node, broadcast_or_immutable}},
[](OperatorNodeBase* opr) {
auto concat_pad = opr->try_cast_final<opr::Concat>();
return concat_pad->axis() == 1;
}};
},
"broadcast_concat"};
SGM::Node nchwx_reshape{opr::Reshape::typeinfo(),
{broadcast_concat, SGM::Node(nullptr)},
{{broadcast_concat, SGM::Node(nullptr)}},
[](OperatorNodeBase* opr) {
auto inp0 = opr->input()[0];
return is_shape_nchw(inp0->shape());
}};
SGM::Node shuffle_root{
opr::Dimshuffle::typeinfo(),
{nchwx_reshape},
{{nchwx_reshape}},
[](OperatorNodeBase* opr) {
auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
auto& input_vec = shuffle_opr.input();
@@ -263,13 +287,55 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
}};
return shuffle_root;
};
auto gen_u8_cvt2_q8 = [](OperatorNodeBase*& src_node,
OperatorNodeBase*& neg_128_immu_node) {
SGM::Node input_data_u8{nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Uint8) {
src_node = opr;
return true;
} else {
return false;
}
}};
SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
{{input_data_u8}},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
bool is_fp32 = cvt_op->param().enumv() ==
DTypeEnum::Float32;
return is_fp32;
}};
SGM::Node sub_128{
opr::Elemwise::typeinfo(),
{{cvt_fp32, nullptr}, {nullptr, cvt_fp32}},
[&](OperatorNodeBase* opr) {
auto elem_op = opr->try_cast_final<opr::Elemwise>();
bool is_add_op = elem_op->param().mode ==
opr::Elemwise::Param::Mode::ADD;
auto neg_128_op = elem_op->input()[1]->owner_opr();
bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_op = elem_op->input()[0]->owner_opr();
is_neg_128 = is_neg_128 ||
is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
return is_add_op && is_neg_128;
},
"sub_128"};
return sub_128;
};
auto replace_shuffle_opr = [&](OperatorNodeBase* opr,
const VarNodeArray& new_inp,
SubGraph::Rewriter& rewriter,
ReaderType& reader) {
SGM matcher;
OperatorNodeBase* src_node = nullptr;
SGM::Node input_data_cp{
OperatorNodeBase* neg_128_immu_node = nullptr;
auto u8_q8_input = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
SGM::Node input_data_qu8{
nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
@@ -279,7 +345,18 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
return false;
}
}};
SGM::Node type_cvt{opr::TypeCvt::typeinfo(), {input_data_cp}};
SGM::Node type_cvt{opr::TypeCvt::typeinfo(),
{{input_data_qu8}, {u8_q8_input}},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
if (cvt_op) {
return cvt_op->param().enumv() ==
DTypeEnum::QuantizedS8;
} else {
return false;
}
}};
SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
@@ -321,37 +398,7 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
OperatorNodeBase* neg_128_immu_node = nullptr;
OperatorNodeBase* pad0_immu_node = nullptr;
OperatorNodeBase* const_reshape_last_dim_node = nullptr;
SGM::Node input_data_cp{nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Uint8) {
src_node = opr;
return true;
} else {
return false;
}
}};
SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
{input_data_cp},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
bool is_fp32 = cvt_op->param().enumv() ==
DTypeEnum::Float32;
return is_fp32;
}};
SGM::Node sub_128{
opr::Elemwise::typeinfo(),
{cvt_fp32},
[&](OperatorNodeBase* opr) {
auto elem_op = opr->try_cast_final<opr::Elemwise>();
bool is_add_op = elem_op->param().mode ==
opr::Elemwise::Param::Mode::ADD;
auto neg_128_op = elem_op->input()[1]->owner_opr();
bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
return is_add_op && is_neg_128;
}};
auto sub_128 = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
pad0_immu_node = opr;
bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
@@ -364,8 +411,16 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
};
auto&& shuffle_root = gen_pad_dimshuffle_graph(sub_128, const_pad_cbk,
const_reshape_cbk);
SGM::Node astype_root{opr::TypeCvt::typeinfo(), {shuffle_root}};
SGM::Node::CallBack cvt_q8_cbk = [](OperatorNodeBase* opr) {
auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
if (cvt_op) {
return cvt_op->param().enumv() == DTypeEnum::QuantizedS8;
} else {
return false;
}
};
SGM::Node astype_root{
opr::TypeCvt::typeinfo(), {{shuffle_root}}, cvt_q8_cbk};
bool match = matcher.match(astype_root, opr);
bool check_ok = false;
if (match) {
......
This diff is collapsed.
@@ -3815,7 +3815,7 @@ TEST(TestGoptInference, PreProcessCase1) {
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
@@ -3882,6 +3882,68 @@ TEST(TestGoptInference, WarpAndPreProcessCase0) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
TEST(TestGoptInference, PreProcessCaseAutopadNCHW64) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("gpu0");
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
auto sm_ver = prop.major * 10 + prop.minor;
if (sm_ver < 75) {
printf("This testcast ignored due to insufficient cuda cap(got: %d, "
"expected: %d)\n",
sm_ver, 75);
return;
}
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
size_t n = 2;
size_t c = 3;
size_t h = 32;
size_t w = 32;
auto host_x1 = gen({n, c, h, w}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_u8_fp32 = opr::TypeCvt::make(x, dtype::Float32(), cn);
auto x_s8_fp32 = x_u8_fp32 - 128;
auto x_s8 = opr::TypeCvt::make(x_s8_fp32, dtype::QuantizedS8(2.5f), cn);
auto weight = mkcvar("weight", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
bias = mkcvar("bias", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param.stride_h = param.stride_w = 2;
param.pad_h = param.pad_w = 1;
auto result =
opr::ConvBias::make(x_s8, weight, bias, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw64();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.PreProcessCaseAutopadNCHW64.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
ASSERT_TRUE(find_opr<opr::RelayoutFormat>(y_opt).param().mode ==
opr::RelayoutFormat::Param::Mode::NCHW_NCHW4);
}
TEST(TestGoptInference, WarpAndPreProcessCase1) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
......