diff --git a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
index 6cf9b75d5ae8bb09230b2aa925f7500fc3459eb2..4e45cc5084ba7bb4c93a269fc7cc7e5f89d3b934 100644
--- a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
+++ b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
@@ -36,15 +36,23 @@ struct SubGraphMatcher {
     Node(Typeinfo* in_op_type) : op_type(in_op_type){};
     Node(Typeinfo* in_op_type, CallBack func)
             : op_type(in_op_type), cbk(func){};
-    Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node)
+    Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node)
             : op_type(in_op_type), pre_node(in_pre_node){};
-    Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node, CallBack func)
+    Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
+         CallBack func)
             : op_type(in_op_type), pre_node(in_pre_node), cbk(func){};
+    Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
+         CallBack func, std::string in_msg)
+            : op_type(in_op_type),
+              pre_node(in_pre_node),
+              cbk(func),
+              msg(in_msg){};
     Typeinfo* op_type{nullptr};
-    std::vector<Node> pre_node;
+    std::vector<std::vector<Node>> pre_node;
     //! cbk used to check param and gather args for creating fusion op
     CallBack cbk;
+    std::string msg{""};
 };
 
 bool match(Node& root, OperatorNodeBase* opr) {
@@ -53,20 +61,34 @@ struct SubGraphMatcher {
         }
         //! match nullptr node always
         if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) {
-            bool match_ok = true;
+            bool current_match = true;
             if (root.cbk)
-                match_ok &= root.cbk(opr);
-            RETURN_IF_FALSE(match_ok);
+                current_match &= root.cbk(opr);
+            RETURN_IF_FALSE(current_match);
             auto& inp = opr->input();
-            for (size_t node_idx = 0; node_idx < root.pre_node.size();
-                 ++node_idx) {
-                bool valid_node_idx = node_idx < inp.size();
-                RETURN_IF_FALSE(valid_node_idx);
-                match_ok &= match(root.pre_node[node_idx],
-                                  inp[node_idx]->owner_opr());
-                RETURN_IF_FALSE(match_ok);
+            bool any_sub_pattern_match =
+                    root.pre_node.size() == 0 ? true : false;
+            for (auto& sub_pattern : root.pre_node) {
+                bool pattern_ok = true;
+                for (size_t node_idx = 0; node_idx < sub_pattern.size();
+                     ++node_idx) {
+                    bool valid_node_idx = node_idx < inp.size();
+                    if (!valid_node_idx) {
+                        pattern_ok = false;
+                        break;
+                    }
+                    pattern_ok = pattern_ok &&
+                                 match(sub_pattern[node_idx],
+                                       inp[node_idx]->owner_opr());
+                    if (!pattern_ok) {
+                        break;
+                    }
+                }
+                any_sub_pattern_match = any_sub_pattern_match || pattern_ok;
+                if (any_sub_pattern_match) {
+                    break;
+                }
             }
-            return match_ok;
+            return current_match && any_sub_pattern_match;
         } else {
            return false;
         }
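The change above gives pre_node OR-semantics: each inner vector is one complete, positional input pattern, and the node matches if any single alternative matches (an empty outer vector keeps the old "no constraint on inputs" behavior). A self-contained sketch of just that matching rule, using simplified stand-in types rather than the real SubGraphMatcher (Pat/Op and the integer op codes are inventions for illustration; C++17 for std::vector of an incomplete element type):

    #include <cstdio>
    #include <vector>

    // op_type 0 plays the role of the nullptr wildcard above.
    struct Pat {
        int op_type;
        std::vector<std::vector<Pat>> pre_node;  // alternatives of input patterns
    };
    struct Op {
        int op_type;
        std::vector<const Op*> input;
    };

    bool match(const Pat& root, const Op& opr) {
        if (root.op_type != 0 && root.op_type != opr.op_type)
            return false;
        if (root.pre_node.empty())  // no input constraint at all
            return true;
        for (auto& alt : root.pre_node) {  // OR over alternatives
            bool ok = alt.size() <= opr.input.size();
            for (size_t i = 0; ok && i < alt.size(); ++i)  // AND within one
                ok = match(alt[i], *opr.input[i]);
            if (ok)
                return true;
        }
        return false;
    }

    int main() {
        Op imm{1, {}}, cvt{2, {}};
        Op add{3, {&cvt, &imm}};
        Pat imm_pat{1, {}}, cvt_pat{2, {}};
        // Accept ADD with either operand order, mirroring the
        // {{cvt_fp32, nullptr}, {nullptr, cvt_fp32}} pattern used below:
        Pat pat{3, {{cvt_pat, imm_pat}, {imm_pat, cvt_pat}}};
        std::printf("%d\n", match(pat, add));  // prints 1
    }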
@@ -237,24 +259,26 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
             return false;
         }
     };
-    SGM::Node broadcast_or_immutable{nullptr, check_pad};
+    SGM::Node broadcast_or_immutable{
+            nullptr, {}, check_pad, "broadcast_or_immutable"};
     SGM::Node broadcast_concat{
             opr::Concat::typeinfo(),
-            {in_node, broadcast_or_immutable},
+            {{in_node, broadcast_or_immutable}},
             [](OperatorNodeBase* opr) {
                 auto concat_pad = opr->try_cast_final<opr::Concat>();
                 return concat_pad->axis() == 1;
-            }};
+            },
+            "broadcast_concat"};
     SGM::Node nchwx_reshape{opr::Reshape::typeinfo(),
-                            {broadcast_concat, SGM::Node(nullptr)},
+                            {{broadcast_concat, SGM::Node(nullptr)}},
                             [](OperatorNodeBase* opr) {
                                 auto inp0 = opr->input()[0];
                                 return is_shape_nchw(inp0->shape());
                             }};
     SGM::Node shuffle_root{
             opr::Dimshuffle::typeinfo(),
-            {nchwx_reshape},
+            {{nchwx_reshape}},
             [](OperatorNodeBase* opr) {
                 auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
                 auto& input_vec = shuffle_opr.input();
@@ -263,13 +287,55 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
             }};
         return shuffle_root;
     };
+    auto gen_u8_cvt2_q8 = [](OperatorNodeBase*& src_node,
+                             OperatorNodeBase*& neg_128_immu_node) {
+        SGM::Node input_data_u8{nullptr, [&](OperatorNodeBase* opr) {
+                                    auto src_dtype = opr->output()[0]->dtype();
+                                    if (src_dtype.enumv() == DTypeEnum::Uint8) {
+                                        src_node = opr;
+                                        return true;
+                                    } else {
+                                        return false;
+                                    }
+                                }};
+        SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
+                           {{input_data_u8}},
+                           [](OperatorNodeBase* opr) {
+                               auto cvt_op =
+                                       opr->try_cast_final<opr::TypeCvt>();
+                               bool is_fp32 = cvt_op->param().enumv() ==
+                                              DTypeEnum::Float32;
+                               return is_fp32;
+                           }};
+        SGM::Node sub_128{
+                opr::Elemwise::typeinfo(),
+                {{cvt_fp32, nullptr}, {nullptr, cvt_fp32}},
+                [&](OperatorNodeBase* opr) {
+                    auto elem_op = opr->try_cast_final<opr::Elemwise>();
+                    bool is_add_op = elem_op->param().mode ==
+                                     opr::Elemwise::Param::Mode::ADD;
+                    auto neg_128_op = elem_op->input()[1]->owner_opr();
+                    bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
+                                                         DTypeEnum::Float32);
+                    neg_128_op = elem_op->input()[0]->owner_opr();
+                    is_neg_128 = is_neg_128 ||
+                                 is_immutable_equal(neg_128_op, -128.f,
+                                                    DTypeEnum::Float32);
+                    neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
+                    return is_add_op && is_neg_128;
+                },
+                "sub_128"};
+        return sub_128;
+    };
     auto replace_shuffle_opr = [&](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp,
                                    SubGraph::Rewriter& rewriter,
                                    ReaderType& reader) {
         SGM matcher;
         OperatorNodeBase* src_node = nullptr;
-        SGM::Node input_data_cp{
+        OperatorNodeBase* neg_128_immu_node = nullptr;
+        auto u8_q8_input = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
+        SGM::Node input_data_qu8{
                 nullptr, [&](OperatorNodeBase* opr) {
                     auto src_dtype = opr->output()[0]->dtype();
                     if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
@@ -279,7 +345,18 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
                         return false;
                     }
                 }};
-        SGM::Node type_cvt{opr::TypeCvt::typeinfo(), {input_data_cp}};
+        SGM::Node type_cvt{opr::TypeCvt::typeinfo(),
+                           {{input_data_qu8}, {u8_q8_input}},
+                           [](OperatorNodeBase* opr) {
+                               auto cvt_op =
+                                       opr->try_cast_final<opr::TypeCvt>();
+                               if (cvt_op) {
+                                   return cvt_op->param().enumv() ==
+                                          DTypeEnum::QuantizedS8;
+                               } else {
+                                   return false;
+                               }
+                           }};
         SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
             bool is_fp32_pad = is_immutable_all_equal<float>(opr, 0);
             bool is_i32_pad = is_immutable_all_equal<int>(opr, 0);
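The `{{input_data_qu8}, {u8_q8_input}}` alternatives let the final TypeCvt accept either a Quantized8Asymm tensor directly or the hand-written uint8 chain recognized by gen_u8_cvt2_q8. Numerically that chain is plain uint8-to-int8 requantization; a standalone one-pixel check, assuming TypeCvt quantizes by round-to-nearest of value/scale (the 2.5f scale is borrowed from the test added in inference.cpp below):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        uint8_t u = 200;             // uint8 source pixel
        float f = float(u) - 128.f;  // TypeCvt(f32) + ADD(-128): 72.f
        // TypeCvt to QuantizedS8(2.5f): round-to-nearest of value/scale
        int8_t q = int8_t(std::lround(f / 2.5f));
        std::printf("%d\n", q);      // prints 29
    }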
@@ -321,37 +398,7 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
         OperatorNodeBase* neg_128_immu_node = nullptr;
         OperatorNodeBase* pad0_immu_node = nullptr;
         OperatorNodeBase* const_reshape_last_dim_node = nullptr;
-        SGM::Node input_data_cp{nullptr, [&](OperatorNodeBase* opr) {
-                                    auto src_dtype = opr->output()[0]->dtype();
-                                    if (src_dtype.enumv() == DTypeEnum::Uint8) {
-                                        src_node = opr;
-                                        return true;
-                                    } else {
-                                        return false;
-                                    }
-                                }};
-        SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
-                           {input_data_cp},
-                           [](OperatorNodeBase* opr) {
-                               auto cvt_op =
-                                       opr->try_cast_final<opr::TypeCvt>();
-                               bool is_fp32 = cvt_op->param().enumv() ==
-                                              DTypeEnum::Float32;
-                               return is_fp32;
-                           }};
-        SGM::Node sub_128{
-                opr::Elemwise::typeinfo(),
-                {cvt_fp32},
-                [&](OperatorNodeBase* opr) {
-                    auto elem_op = opr->try_cast_final<opr::Elemwise>();
-                    bool is_add_op = elem_op->param().mode ==
-                                     opr::Elemwise::Param::Mode::ADD;
-                    auto neg_128_op = elem_op->input()[1]->owner_opr();
-                    bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
-                                                         DTypeEnum::Float32);
-                    neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
-                    return is_add_op && is_neg_128;
-                }};
+        auto sub_128 = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
         SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
             pad0_immu_node = opr;
             bool is_fp32_pad = is_immutable_all_equal<float>(opr, 0);
@@ -364,8 +411,16 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
         };
         auto&& shuffle_root = gen_pad_dimshuffle_graph(sub_128, const_pad_cbk,
                                                        const_reshape_cbk);
-
-        SGM::Node astype_root{opr::TypeCvt::typeinfo(), {shuffle_root}};
+        SGM::Node::CallBack cvt_q8_cbk = [](OperatorNodeBase* opr) {
+            auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
+            if (cvt_op) {
+                return cvt_op->param().enumv() == DTypeEnum::QuantizedS8;
+            } else {
+                return false;
+            }
+        };
+        SGM::Node astype_root{
+                opr::TypeCvt::typeinfo(), {{shuffle_root}}, cvt_q8_cbk};
         bool match = matcher.match(astype_root, opr);
         bool check_ok = false;
         if (match) {
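Net effect on this file: the pass previously required the preprocessed source to already be Quantized8Asymm; with gen_u8_cvt2_q8 shared by both rewrite rules, it now equally recognizes the hand-rolled signed-shift chain

    u8 --TypeCvt--> f32 --ADD(-128)--> f32 --TypeCvt--> QuantizedS8

in front of the Concat(pad)/Reshape/Dimshuffle tail, and the root TypeCvt must additionally produce QuantizedS8 before the subgraph is folded into a single NCHW4 preprocessing operator.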
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 1b8a064e6f5ad8fe54b368f2268d7f1d4e3d4439..9eaa973337c1a18747ef4e363f4f6212cced9f92 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -206,7 +206,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
                        NCHW_TO_NCHW4_IC_SMALL_CONV) {
             if (layout_type() ==
                 RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
-                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0,
+                           "src shape %s", inp_shape.to_string().c_str());
             } else {
                 mgb_assert(layout_type() ==
                            RelayoutPlaceholder::LayoutType::
@@ -411,7 +412,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[4] = 32;
         } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) {
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0);
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", inp_shape.to_string().c_str());
             dst.ndim = 5;
             dst[0] = inp_shape[0];
             dst[1] = inp_shape[1] / 64;
@@ -4191,12 +4192,12 @@ void PaddingChannelPass::apply(OptState& opt) const {
 VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
                                                  VarNode* orig_var) const {
     if (!orig_var->shape().eq_shape(new_var->shape())) {
-        auto iter = m_opr_format_map.find(orig_var->owner_opr());
+        auto iter = m_opr_format_map.find(new_var->owner_opr());
         mgb_assert(iter != m_opr_format_map.end(),
                    "cannot find opr(type:%s,name:%s) information, related "
                    "output var node(name:%s)",
-                   orig_var->owner_opr()->dyn_typeinfo()->name,
-                   orig_var->owner_opr()->cname(), orig_var->cname());
+                   new_var->owner_opr()->dyn_typeinfo()->name,
+                   new_var->owner_opr()->cname(), new_var->cname());
         const auto& fmt = iter->second;
         using LayoutType = RelayoutPlaceholder::LayoutType;
         LayoutType type;
@@ -4253,20 +4254,70 @@ EnableNCHW64Pass::make_nchw64_converter() {
             return new_conv.node();
         }
     };
 
+    auto try_transform_to_nchw =
+            [&format_map](
+                    OperatorNodeBase* opr,
+                    const VarNodeArray& new_inp) -> VarNode* {
+        mgb_assert(opr->input().size()==new_inp.size());
+        bool check_dtype =
+                new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
+                new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
+        if (opr->input().size() >= 3)
+            check_dtype &=
+                    new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
+        if (opr->input().size() >= 4)
+            check_dtype &=
+                    new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
+        if (!check_dtype)
+            return nullptr;
+        auto inps = new_inp;
+        auto process = [&](size_t i) -> VarNode* {
+            auto iter = format_map.find(new_inp[i]->owner_opr());
+            if (iter == format_map.end()) {
+                return inps[i];
+            } else {
+                const auto& fmt = iter->second;
+                if (fmt == Format::NCHW32) {
+                    auto ovar = RelayoutPlaceholder::make(
+                            inps[i],
+                            RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW);
+                    return ovar.node();
+                } else if (fmt == Format::NCHW4) {
+                    auto ovar = RelayoutPlaceholder::make(
+                            inps[i],
+                            RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW);
+                    return ovar.node();
+                } else {
+                    mgb_assert(fmt == Format::NCHW64);
+                    auto ovar = RelayoutPlaceholder::make(
+                            inps[i],
+                            RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW);
+                    return ovar.node();
+                }
+            }
+        };
+        for (size_t i = 0; i < inps.size(); ++i) {
+            inps[i] = process(i);
+        }
+        auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
+        return ret->output()[0];
+    };
+
     auto try_transform_to_nchw4 =
             [make_new_conv, &format_map](
                     OperatorNodeBase* opr,
                     const VarNodeArray& new_inp) -> VarNode* {
+        mgb_assert(opr->input().size()==new_inp.size());
         bool check_dtype =
-                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 &&
-                opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8;
+                new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
+                new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
         if (opr->input().size() >= 3)
             check_dtype &=
-                    opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32;
+                    new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
         if (opr->input().size() >= 4)
             check_dtype &=
-                    opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8;
+                    new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
         if (!check_dtype)
             return nullptr;
         size_t out_channels = opr->input(1)->shape()[0];
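Two bookkeeping changes run through the rest of this file. Every replace rule now reads input formats from new_inp[i] (the already-rewritten producers) rather than opr->input(i), and records the layout of the operator it creates in format_map instead of the original operator, which is what lets on_graph_endpoint_var above resolve new_var->owner_opr(). In addition, the new try_transform_to_nchw fallback relayouts NCHW4/NCHW32/NCHW64 producers back to NCHW so that an all-Float32 conv_bias can still be rebuilt with copy_opr_shallow.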
@@ -4277,7 +4328,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
                    opr->cname(), out_channels, in_channels);
         auto inps = new_inp;
         auto process = [&](size_t i) -> VarNode* {
-            auto iter = format_map.find(opr->input(i)->owner_opr());
+            auto iter = format_map.find(new_inp[i]->owner_opr());
             if (iter == format_map.end()) {
                 auto ovar = RelayoutPlaceholder::make(
                         inps[i],
@@ -4304,24 +4355,26 @@ EnableNCHW64Pass::make_nchw64_converter() {
         for (size_t i = 0; i < inps.size(); ++i) {
             inps[i] = process(i);
         }
-        format_map.insert(std::make_pair(opr, Format::NCHW4));
         auto& conv_bias = opr->cast_final_safe<opr::ConvBias>();
-        return make_new_conv(inps, &conv_bias, Format::NCHW4);
+        auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4);
+        format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
+        return ret;
     };
 
     auto try_transform_to_nchw32 =
             [make_new_conv, &format_map](
                     OperatorNodeBase* opr,
                     const VarNodeArray& new_inp) -> VarNode* {
+        mgb_assert(opr->input().size()==new_inp.size());
         bool check_dtype =
-                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 &&
-                opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8;
+                new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
+                new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
         if (opr->input().size() >= 3)
             check_dtype &=
-                    opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32;
+                    new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
         if (opr->input().size() >= 4)
             check_dtype &=
-                    opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8;
+                    new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
         if (!check_dtype)
             return nullptr;
         size_t out_channels = opr->input(1)->shape()[0];
@@ -4331,7 +4384,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
             return nullptr;
         auto inps = new_inp;
         auto process = [&](size_t i) -> VarNode* {
-            auto iter = format_map.find(opr->input(i)->owner_opr());
+            auto iter = format_map.find(new_inp[i]->owner_opr());
             if (iter == format_map.end()) {
                 auto ovar = RelayoutPlaceholder::make(
                         inps[i],
@@ -4358,9 +4411,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
         for (size_t i = 0; i < inps.size(); ++i) {
             inps[i] = process(i);
         }
-        format_map.insert(std::make_pair(opr, Format::NCHW32));
         auto& conv_bias = opr->cast_final_safe<opr::ConvBias>();
-        return make_new_conv(inps, &conv_bias, Format::NCHW32);
+        auto ret = make_new_conv(inps, &conv_bias, Format::NCHW32);
+        format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW32));
+        return ret;
     };
 
     auto try_transform_to_nchw64 =
@@ -4368,17 +4422,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
                     OperatorNodeBase* opr,
                     const VarNodeArray& new_inp) -> VarNode* {
         // fint4XWint4 and fuint4XWint4
+        mgb_assert(opr->input().size()==new_inp.size());
         bool check_dtype =
-                (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
-                 opr->input(0)->dtype().enumv() ==
+                (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
+                 new_inp[0]->dtype().enumv() ==
                          DTypeEnum::Quantized4Asymm) &&
-                opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS4;
+                new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
         if (opr->input().size() >= 3)
             check_dtype &=
-                    opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32;
+                    new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
         if (opr->input().size() >= 4)
-            check_dtype &= opr->input(3)->dtype().enumv() ==
-                           opr->input(0)->dtype().enumv();
+            check_dtype &= new_inp[3]->dtype().enumv() ==
+                           new_inp[0]->dtype().enumv();
         if (!check_dtype)
             return nullptr;
         size_t out_channels = opr->input(1)->shape()[0];
@@ -4388,7 +4443,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
             return nullptr;
         auto inps = new_inp;
         auto process = [&](size_t i) -> VarNode* {
-            auto iter = format_map.find(opr->input(i)->owner_opr());
+            auto iter = format_map.find(new_inp[i]->owner_opr());
             if (iter == format_map.end()) {
                 auto ovar = RelayoutPlaceholder::make(
                         inps[i],
@@ -4415,15 +4470,16 @@ EnableNCHW64Pass::make_nchw64_converter() {
         for (size_t i = 0; i < inps.size(); ++i) {
             inps[i] = process(i);
        }
-        format_map.insert(std::make_pair(opr, Format::NCHW64));
         auto& conv_bias = opr->cast_final_safe<opr::ConvBias>();
-        return make_new_conv(inps, &conv_bias, Format::NCHW64);
+        auto ret = make_new_conv(inps, &conv_bias, Format::NCHW64);
+        format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW64));
+        return ret;
     };
 
     // replace rule for conv bias opr
     auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4,
                                   try_transform_to_nchw32,
-                                  try_transform_to_nchw64](
+                                  try_transform_to_nchw64, try_transform_to_nchw](
                                          OperatorNodeBase* opr,
                                          const VarNodeArray& new_inp) {
         using Param = megdnn::param::ConvBias;
@@ -4435,16 +4491,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
         VarNode* new_var = nullptr;
         if ((new_var = try_transform_to_nchw32(opr, new_inp)) ||
             (new_var = try_transform_to_nchw4(opr, new_inp)) ||
-            (new_var = try_transform_to_nchw64(opr, new_inp))) {
+            (new_var = try_transform_to_nchw64(opr, new_inp))||
+            (new_var = try_transform_to_nchw(opr, new_inp))) {
             return new_var->owner_opr();
         } else {
             mgb_assert(
-                    opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
-                    opr->input(0)->dtype().enumv() !=
+                    new_inp[0]->dtype().enumv() != DTypeEnum::QuantizedS8 &&
+                    new_inp[0]->dtype().enumv() !=
                             DTypeEnum::QuantizedS4 &&
-                    opr->input(0)->dtype().enumv() !=
-                            DTypeEnum::Quantized4Asymm,
-                    "invalid data type(%s)", opr->input(0)->dtype().name());
+                    new_inp[0]->dtype().enumv() !=
+                            DTypeEnum::Quantized4Asymm &&
+                    new_inp[0]->dtype().enumv() != DTypeEnum::Float32,
+                    "invalid data type(%s)", new_inp[0]->dtype().name());
             bool shape_changed = false;
             for (const auto& i : new_inp) {
                 if (format_map.count(i->owner_opr()) > 0) {
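Note the dispatch order in replace_conv_bias_opr: each quantized path (NCHW32, NCHW4, NCHW64) rejects non-matching dtypes by returning nullptr, so the all-Float32 try_transform_to_nchw path is only reached last, and the trailing assertion can now also exclude Float32 because that case no longer falls through.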
@@ -4471,9 +4529,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
                    "only have 2 input vars(got:%zu)", new_inp.size());
         auto& deconv = opr->cast_final_safe<opr::ConvolutionBackwardData>();
-        if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
+        if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
             Format cur;
-            auto iter = format_map.find(opr->input(1)->owner_opr());
+            auto iter = format_map.find(new_inp[1]->owner_opr());
             if (iter == format_map.end()) {
                 cur = Format::NCHW;
             } else {
@@ -4506,13 +4564,15 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 default:
                     mgb_assert(cur == Format::NCHW4);
             }
-            format_map.insert(std::make_pair(opr, Format::NCHW4));
+
             auto param = deconv.param();
             param.format = Format::NCHW4;
             auto new_deconv = opr::ConvolutionBackwardData::make(
                     inps[0], inps[1], param, deconv.execution_policy(),
                     deconv.config());
-            return new_deconv.node()->owner_opr();
+            auto ret = new_deconv.node()->owner_opr();
+            format_map.insert(std::make_pair(ret, Format::NCHW4));
+            return ret;
         } else {
             bool shape_changed = false;
             for (const auto& i : new_inp) {
@@ -4538,7 +4598,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
         bool same_format = true;
         bool first_touch = false;
         Format format(Format::NCHW);
-        for (const auto& i : opr->input()) {
+        for (const auto& i : new_inp) {
             Format cur;
             auto iter = format_map.find(i->owner_opr());
             if (iter == format_map.end()) {
@@ -4557,10 +4617,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
             }
         }
         if (same_format) {
+            auto ret = serialization::copy_opr_shallow(*opr, new_inp,
+                                                       opr->config());
             if (format != Format::NCHW)
-                format_map.insert(std::make_pair(opr, format));
-            return serialization::copy_opr_shallow(*opr, new_inp,
-                                                   opr->config());
+                format_map.insert(std::make_pair(ret, format));
+            return ret;
         }
 
         Format max_format(Format::NCHW);
@@ -4592,7 +4653,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
         };
         auto inps = new_inp;
         for (size_t i = 0; i < opr->input().size(); ++i) {
-            auto iter = format_map.find(opr->input(i)->owner_opr());
+            auto iter = format_map.find(new_inp[i]->owner_opr());
             Format cur;
             if (iter != format_map.end()) {
                 cur = iter->second;
@@ -4603,9 +4664,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]);
             }
         }
+        auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
         if (max_format != Format::NCHW)
-            format_map.insert(std::make_pair(opr, max_format));
-        return serialization::copy_opr_shallow(*opr, inps, opr->config());
+            format_map.insert(std::make_pair(ret, max_format));
+        return ret;
     };
     // elemwise like
     replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
@@ -4619,10 +4681,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
                                           const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
-        if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
-            opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
+        if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
+            new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
             Format cur;
-            auto iter = format_map.find(opr->input(0)->owner_opr());
+            auto iter = format_map.find(new_inp[0]->owner_opr());
             if (iter == format_map.end()) {
                 cur = Format::NCHW;
             } else {
@@ -4651,7 +4713,6 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 default:
                     mgb_assert(cur == Format::NCHW64);
             }
-            format_map.insert(std::make_pair(opr, Format::NCHW64));
             auto param = warp.param();
             param.format = Format::NCHW64;
             SymbolVar new_warp;
@@ -4665,10 +4726,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
                         inps[0], inps[1], inps[2], inps[3], param,
                         warp.config());
             }
-            return new_warp.node()->owner_opr();
-        } else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
+            auto ret = new_warp.node()->owner_opr();
+            format_map.insert(std::make_pair(ret, Format::NCHW64));
+            return ret;
+        } else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
             Format cur;
-            auto iter = format_map.find(opr->input(0)->owner_opr());
+            auto iter = format_map.find(new_inp[0]->owner_opr());
             if (iter == format_map.end()) {
                 cur = Format::NCHW;
             } else {
@@ -4697,7 +4760,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 default:
                     mgb_assert(cur == Format::NCHW4);
             }
-            format_map.insert(std::make_pair(opr, Format::NCHW4));
+
             auto param = warp.param();
             param.format = Format::NCHW4;
             SymbolVar new_warp;
@@ -4711,7 +4774,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
                         inps[0], inps[1], inps[2], inps[3], param,
                         warp.config());
             }
-            return new_warp.node()->owner_opr();
+            auto ret = new_warp.node()->owner_opr();
+            format_map.insert(std::make_pair(ret, Format::NCHW4));
+            return ret;
         } else {
             bool shape_changed = false;
             for (const auto& i : new_inp) {
@@ -4733,10 +4798,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
                                        const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
-        if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
-            opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
+        if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
+            new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
             Format cur;
-            auto iter = format_map.find(opr->input(0)->owner_opr());
+            auto iter = format_map.find(new_inp[0]->owner_opr());
             if (iter == format_map.end()) {
                 cur = Format::NCHW;
             } else {
@@ -4765,21 +4830,23 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 default:
                     mgb_assert(cur == Format::NCHW64);
             }
-            format_map.insert(std::make_pair(opr, Format::NCHW64));
+
            auto param = pooling.param();
             param.format = Format::NCHW64;
             auto new_pool =
                     opr::PoolingForward::make(inps[0], param, pooling.config());
-            return new_pool.node()->owner_opr();
-        } else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
+            auto ret = new_pool.node()->owner_opr();
+            format_map.insert(std::make_pair(ret, Format::NCHW64));
+            return ret;
+        } else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
             Format cur;
-            auto iter = format_map.find(opr->input(0)->owner_opr());
+            auto iter = format_map.find(new_inp[0]->owner_opr());
             if (iter == format_map.end()) {
                 cur = Format::NCHW;
             } else {
                 cur = iter->second;
             }
-            size_t in_channels = opr->input(0)->shape()[1];
+            size_t in_channels = new_inp[0]->shape()[1];
             bool use_nchw32 = false;
             auto inps = new_inp;
             using LayoutType = RelayoutPlaceholder::LayoutType;
@@ -4805,12 +4872,14 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 default:
                     mgb_assert(cur == Format::NCHW4);
             }
             Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4;
-            format_map.insert(std::make_pair(opr, out_format));
+
             auto param = pooling.param();
             param.format = out_format;
             auto new_pool =
                     opr::PoolingForward::make(inps[0], param, pooling.config());
-            return new_pool.node()->owner_opr();
+            auto ret = new_pool.node()->owner_opr();
+            format_map.insert(std::make_pair(ret, out_format));
+            return ret;
         } else {
             bool shape_changed = false;
             for (const auto& i : new_inp) {
@@ -4838,9 +4907,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
         mgb_assert(opr->input().size() == new_inp.size());
         auto inps = new_inp;
         for (size_t i = 0; i < opr->input().size(); ++i) {
-            auto iter = format_map.find(opr->input(i)->owner_opr());
+            auto iter = format_map.find(new_inp[i]->owner_opr());
+            auto fmt = iter != format_map.end()?iter->second:Format::NCHW;
             if (iter != format_map.end()) {
-                auto fmt = iter->second;
                 switch (fmt) {
                     case Format::NCHW4:
                         inps[i] = RelayoutPlaceholder::make(
@@ -4867,7 +4936,8 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 }
             }
         }
-        return serialization::copy_opr_shallow(*opr, inps, opr->config());
+        auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
+        return ret;
     };
 
     replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw;
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 447f0f53308b7fb66801f3a289b57581e6a007fa..d6cb2af8ed1cbe69f915f396d062b24784097e2b 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -3815,7 +3815,7 @@ TEST(TestGoptInference, PreProcessCase1) {
 
     HostTensorND host_y_opt, host_y;
     auto func = graph->compile({make_callback_copy(y, host_y),
-                                 make_callback_copy(y_opt, host_y_opt)});
+                                make_callback_copy(y_opt, host_y_opt)});
     func->execute();
 
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
@@ -3882,6 +3882,68 @@ TEST(TestGoptInference, WarpAndPreProcessCase0) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
 }
 
+TEST(TestGoptInference, PreProcessCaseAutopadNCHW64) {
+    REQUIRE_GPU(1);
+    HostTensorGenerator<dtype::Uint8> gen(0, 255);
+    auto cn = CompNode::load("gpu0");
+    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
+    auto sm_ver = prop.major * 10 + prop.minor;
+    if (sm_ver < 75) {
+        printf("This testcase is ignored due to insufficient cuda cap(got: "
+               "%d, expected: %d)\n",
+               sm_ver, 75);
+        return;
+    }
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };
+    size_t n = 2;
+    size_t c = 3;
+    size_t h = 32;
+    size_t w = 32;
+    auto host_x1 = gen({n, c, h, w}, cn);
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+    auto x_u8_fp32 = opr::TypeCvt::make(x, dtype::Float32(), cn);
+    auto x_s8_fp32 = x_u8_fp32 - 128;
+    auto x_s8 = opr::TypeCvt::make(x_s8_fp32, dtype::QuantizedS8(2.5f), cn);
+    auto weight = mkcvar("weight", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
+         bias = mkcvar("bias", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 2;
+    param.pad_h = param.pad_w = 1;
+    auto result =
+            opr::ConvBias::make(x_s8, weight, bias, param, {},
+                                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nchw64();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.PreProcessCaseAutopadNCHW64.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+    ASSERT_TRUE(find_opr<opr::RelayoutFormat>(y_opt).param().mode ==
+                opr::RelayoutFormat::Param::Mode::NCHW_NCHW4);
+}
+
 TEST(TestGoptInference, WarpAndPreProcessCase1) {
     REQUIRE_GPU(1);
     HostTensorGenerator<dtype::Uint8> gen(0, 255);
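For the shapes in this test (n=2, c=3, h=w=32), enable_nchw64 cannot put the conv directly into NCHW4 because the input channel count is not a multiple of 4; instead the pass emits a RelayoutFormat operator in NCHW_NCHW4 mode, which pads C from 3 to 4 and folds it into the vector dimension, turning {2, 3, 32, 32} into {2, 1, 32, 32, 4}. The final ASSERT_TRUE checks precisely that this auto-padding relayout was generated.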