From 4e9be159f715da4715d90a5a7baac2bb46816f23 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 24 Nov 2020 17:03:09 +0800 Subject: [PATCH] feat(mgb/gopt): add opt pass for fusing convolution and reformat GitOrigin-RevId: d0c5deace2e860cb62002a6cfedd4b32a8ca24df --- src/gopt/impl/framework.cpp | 1 + src/gopt/impl/tensor_reformat.cpp | 388 +++++++++++++++++++-- src/gopt/include/megbrain/gopt/inference.h | 6 + src/gopt/test/inference.cpp | 243 ++++++++++++- src/plugin/impl/opr_footprint.cpp | 8 +- third_party/cutlass | 2 +- 6 files changed, 624 insertions(+), 24 deletions(-) diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 1238cd27..e7cbe94d 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -759,6 +759,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( add_pass(); add_pass(FuseNCHW4Int8Preprocess::make()); add_pass(); + add_pass(); }); cb(chwn4, { add_pass(); diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index 17256bf9..0050c65a 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -2825,27 +2825,26 @@ public: MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, cg::SingleCNOperatorNodeBase) // { public: -AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); - -static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); - -TensorFormat inp_format() const { - return m_inp_format; -} - -TensorFormat out_format() const { - return m_out_format; -} - + AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, + TensorFormat out_format); + + static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, + TensorFormat out_format); + + TensorFormat inp_format() const { + return m_inp_format; + } + + TensorFormat out_format() const { + return m_out_format; + } + private: -void init_output_static_infer_desc() override; -void scn_do_execute() override; -const TensorFormat m_inp_format; -const TensorFormat m_out_format; -} -; + void init_output_static_infer_desc() override; + void scn_do_execute() override; + const TensorFormat m_inp_format; + const TensorFormat m_out_format; +}; MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); @@ -3228,4 +3227,353 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const { MIDOUT_E } +/* ==================== FoldingConvBiasDimshufflePass ================= */ +const char* FoldingConvBiasDimshufflePass::name() const { + return mgb_cstr_log("folding conv bias dimshuffle pass"); +} + +void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { + MIDOUT_B("FoldingConvBiasDimshufflePass::apply"); + using DepType = cg::OperatorNodeProp::DepType; + ThinHashMap>> + readers; + static const ThinHashSet opr_type_list = { + opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(), + opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()}; + opt.graph().iter([&readers](OperatorNodeBase* opr) { + for (auto&& i : opr->node_prop().dep_map()) { + if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) { + readers[i.first->owner_opr()].emplace_back(opr, i.second); + } + } + }); + + auto rewriter = opt.graph().make_rewriter(); + auto nchw42nchw = [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); + auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp); + auto y2 = opr::TypeCvt::make(y1, dtype::Float32()); + return y2.node(); + }; + + auto nchw42nchw32 = [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + + auto nchw322nchw4 = [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + + auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers, + &nchw42nchw]( + OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check typecvt + auto typecvt = try_cast_as_op(opr); + if (typecvt == nullptr) + return false; + auto inp_dtype = typecvt->input(0)->dtype(), + out_dtype = typecvt->output(0)->dtype(); + bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + out_dtype.enumv() == DTypeEnum::Float32; + if (!is_s82f32) + return false; + opr_set.insert(opr); + + // check reshape + auto reshape = + try_cast_as_op(typecvt->input(0)->owner_opr()); + if (reshape == nullptr) + return false; + opr_set.insert(reshape); + for (auto&& i : readers[reshape]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + + // check shuffle + auto shuffle = + try_cast_as_op(reshape->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 5) + return false; + bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 4 && param.pattern[3] == 2 && + param.pattern[4] == 3 && + shuffle->input(0)->shape()[4] == 4; + if (!is_nchw42nchw) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + + // check conv bias + auto conv_bias = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW4; + if (!is_s8nchw4) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = nchw42nchw(bias); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + OperatorNodeConfig{dtype::Float32()}); + rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + typecvt + " + "dimshuffle + " + "reshape to conv_bias(NCHW4_NCHW)")); + return true; + }; + + auto try_conv_reformat_nchw42nchw32 = [&rewriter, &nchw42nchw32, + &readers](OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check reshape + auto reshape1 = try_cast_as_op(opr); + if (reshape1 == nullptr) + return false; + opr_set.insert(opr); + // check dimshuffle + auto shuffle = try_cast_as_op( + reshape1->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 6) + return false; + bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 3 && param.pattern[3] == 4 && + param.pattern[4] == 2 && param.pattern[5] == 5 && + shuffle->output(0)->shape()[5] == 4 && + shuffle->output(0)->shape()[4] == 8; + if (!is_nchw42nchw32) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check reshape + auto reshape2 = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (reshape2 == nullptr) + return false; + opr_set.insert(reshape2); + for (auto&& i : readers[reshape2]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check conv bias + auto conv_bias = + try_cast_as_op(reshape2->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + auto inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW4; + if (!is_s8nchw4) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = nchw42nchw32(bias); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + conv_bias->config()); + rewriter.replace_var( + opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + " + "reformat to conv_bias(NCHW4_NCHW32)")); + return true; + }; + + auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4]( + OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check reshape + auto reshape1 = try_cast_as_op(opr); + if (reshape1 == nullptr) + return false; + opr_set.insert(opr); + // check dimshuffle + auto shuffle = try_cast_as_op( + reshape1->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 6) + return false; + bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 4 && param.pattern[3] == 2 && + param.pattern[4] == 3 && param.pattern[5] == 5 && + shuffle->input(0)->shape()[5] == 4 && + shuffle->input(0)->shape()[4] == 8; + if (!is_nchw322nchw4) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check reshape + auto reshape2 = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (reshape2 == nullptr) + return false; + opr_set.insert(reshape2); + for (auto&& i : readers[reshape2]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check conv bias + auto conv_bias = + try_cast_as_op(reshape2->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + auto inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW32; + if (!is_s8nchw32) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = nchw322nchw4(bias); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + conv_bias->config()); + rewriter.replace_var( + opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + " + "reformat to conv_bias(NCHW32_NCHW4)")); + return true; + }; + MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4); + + auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, + &try_conv_reformat_nchw42nchw32, +#if CUDA_VERSION >= 10020 + &try_conv_reformat_nchw322nchw4, +#endif + &rewriter](OperatorNodeBase* opr) { + if (!try_conv_dimshuffle_reshape_typecvt(opr) && + !try_conv_reformat_nchw42nchw32(opr) +#if CUDA_VERSION >= 10020 + && !try_conv_reformat_nchw322nchw4(opr) +#endif + ) { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); + + MIDOUT_E +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index b2dd957d..81cf753b 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -402,6 +402,12 @@ namespace gopt { void apply(OptState& opt) const override; }; + class FoldingConvBiasDimshufflePass final : public Pass { + public: + const char* name() const override; + void apply(OptState& opt) const override; + }; + } // namespace gopt } // namespace mgb diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 85875c82..06a0cbdb 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -3632,7 +3632,6 @@ TEST(TestGoptInference, ConvertFormatCD4GroupOneConv) { } #if MGB_CUDA - TEST(TestGoptInference, PreProcessCase0) { REQUIRE_GPU(1); HostTensorGenerator @@ -3783,5 +3782,247 @@ TEST(TestGoptInference, WarpAndPreProcessCase) { func->execute(); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); } + +TEST(TestGoptInference, FoldingConvDimshuffle) { + REQUIRE_GPU(1); + auto cn = CompNode::load("gpu0"); + cn.activate(); + auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; + auto sm_ver = prop.major * 10 + prop.minor; + if (sm_ver < 61) { + printf("This testcast ignored due to insufficient cuda cap(got: %d, " + "expected: %d)\n", + sm_ver, 61); + return; + } + + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + auto nchw42nchw = [](SymbolVar x) { + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); + auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp0); + return y1; + }; + + auto x = mkvar("x", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW4; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + + auto y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y = opr::TypeCvt::make(y, dtype::Float32()); + y = nchw42nchw(y); + SymbolVar y_fuse, y_non_fuse; + unpack_vector(gopt::GraphOptimizer{} + .add_pass() + .add_pass() + .add_pass() + .apply({{y}}) + .endpoint_vars(), + y_fuse); + graph->compile({{y_fuse, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.FoldingConvDimshuffle.json")); + ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4_NCHW, + find_opr(y_fuse).param().format); + ASSERT_EQ(0u, find_opr_num(y_fuse)); + unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(), + y_non_fuse); + HostTensorND host_y_fuse, host_y_non_fuse; + auto func = + graph->compile({make_callback_copy(y_fuse, host_y_fuse), + make_callback_copy(y_non_fuse, host_y_non_fuse)}); + func->execute(); +} + +TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NCHW32) { + REQUIRE_GPU(1); + auto cn = CompNode::load("gpu0"); + cn.activate(); + auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; + auto sm_ver = prop.major * 10 + prop.minor; + if (sm_ver < 61) { + printf("This testcast ignored due to insufficient cuda cap(got: %d, " + "expected: %d)\n", + sm_ver, 61); + return; + } + + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + auto nchw42nchw32 = [](SymbolVar x) { + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2; + }; + + auto x = mkvar("x", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW4; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + + auto y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y = nchw42nchw32(y); + y = opr::TypeCvt::make(y, dtype::Float32()); + SymbolVar y_fuse, y_non_fuse; + unpack_vector(gopt::GraphOptimizer{} + .add_pass() + .add_pass() + .apply({{y}}) + .endpoint_vars(), + y_fuse); + graph->compile({{y_fuse, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.FoldingConvDimshuffleNCHW4NCHW32.json")); + ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4_NCHW32, + find_opr(y_fuse).param().format); + ASSERT_EQ(0u, find_opr_num(y_fuse)); + unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(), + y_non_fuse); + HostTensorND host_y_fuse, host_y_non_fuse; + auto func = + graph->compile({make_callback_copy(y_fuse, host_y_fuse), + make_callback_copy(y_non_fuse, host_y_non_fuse)}); + func->execute(); + MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse); +} + +#if CUDA_VERSION >= 10020 +TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) { + REQUIRE_GPU(1); + auto cn = CompNode::load("gpu0"); + cn.activate(); + auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; + auto sm_ver = prop.major * 10 + prop.minor; + if (sm_ver < 75) { + printf("This testcast ignored due to insufficient cuda cap(got: %d, " + "expected: %d)\n", + sm_ver, 75); + return; + } + + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + auto x = mkvar("x", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), + w1 = mkcvar("w1", {16, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), + b1 = mkcvar("b1", {1, 4, 1, 1, 4}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW4; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + param.stride_h = param.stride_w = 2; + param.pad_h = param.pad_w = 1; + + auto y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + param.stride_h = param.stride_w = 1; + y = opr::ConvBias::make(y, w1, b1, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y = opr::TypeCvt::make(y, dtype::Float32()); + SymbolVar y_fuse, y_non_fuse; + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nchw32().enable_fuse_conv_bias_nonlinearity(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_fuse); + } + graph->compile({{y_fuse, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.FoldingConvDimshuffleNCHW32NCHW4.json")); + ASSERT_EQ(1u, find_opr_num(y_fuse)); + bool found = false; + cg::DepOprIter{[&found](cg::OperatorNodeBase* opr) { + if (!found && opr->same_type()) { + opr::ConvBias* cb = &opr->cast_final_safe(); + if (cb->param().format == + opr::ConvBias::Param::Format::NCHW32_NCHW4) + found = true; + } + }} + .add(y_fuse.node()->owner_opr()); + EXPECT_TRUE(found); + unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(), + y_non_fuse); + HostTensorND host_y_fuse, host_y_non_fuse; + auto func = + graph->compile({make_callback_copy(y_fuse, host_y_fuse), + make_callback_copy(y_non_fuse, host_y_non_fuse)}); + func->execute(); + MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse); +} #endif +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 557ff467..730db81d 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -131,8 +131,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 / group * 2; } - mgb_assert(param.format == Param::Format::NCHW4, - "format should be NCHW4/NCHW32"); + mgb_assert(param.format == Param::Format::NCHW4 || + param.format == Param::Format::NCHW4_NCHW || + param.format == Param::Format::NCHW4_NCHW32, + "format should be NCHW4/NCHW4_NCHW/NCHW4_NCHW32"); return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 4 / group * 2; }; @@ -154,6 +156,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, 2; }; if (param.format == Param::Format::NCHW4 || + param.format == Param::Format::NCHW4_NCHW || + param.format == Param::Format::NCHW4_NCHW32 || param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW44 || param.format == Param::Format::NCHW44_DOT || diff --git a/third_party/cutlass b/third_party/cutlass index 41426ea4..9f743167 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit 41426ea4074dcfc448b1c9979ea7617407590c04 +Subproject commit 9f7431672c17d4a731f84ca9d8f3f4e741e267b1 -- GitLab