diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp
index 477c91909c62fe7162cbc93580583cd60d3e3277..96d8360127cb107234caa46851e4d72e1911cc75 100644
--- a/sdk/load-and-run/src/mgblar.cpp
+++ b/sdk/load-and-run/src/mgblar.cpp
@@ -194,6 +194,26 @@ R"__usage__(
     Execute operators with kernels implemented in MegDNN with CHWN4 tensor format. Can only be
     used on Nvidia GPUs, whose compute capability is above 6.1.
 )__usage__"
+R"__usage__(
+  --enable-nchw44
+    Execute operators with kernels implemented in MegDNN with NCHW44 tensor format. This can only
+    be used on armv7 and arm64, and supports the float32, qint8 and int8x8x16 data types.
+)__usage__"
+R"__usage__(
+  --enable-nchw88
+    Execute operators with kernels implemented in MegDNN with NCHW88 tensor format. This can only
+    be used on x86 with the float32 data type.
+)__usage__"
+R"__usage__(
+  --enable-nchw44-dot
+    Execute operators with kernels implemented in MegDNN with NCHW44-DOT tensor format. This can
+    only be used on arm32 and arm64 with dot-product support, and only supports qint8 models.
+)__usage__"
+R"__usage__(
+  --weight-preprocess
+    Execute operators with weight preprocess, which can speed up operator execution for
+    algorithms such as winograd and im2col, but may consume more memory.
+)__usage__"
 ;
@@ -1226,6 +1246,11 @@ Args Args::from_argv(int argc, char **argv) {
             graph_opt.graph_opt.weight_winograd_transform = true;
             continue;
         }
+        if (!strcmp(argv[i], "--weight-preprocess")) {
+            mgb_log_warn("enable weight-preprocess optimization");
+            graph_opt.graph_opt.enable_weight_preprocess();
+            continue;
+        }
         fprintf(stderr, "invalid arg: %s\n", argv[i]);
         ret.args_parse_ret = -1;
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index 9b83e45989919e8dc22c5fd4640c541980e6a282..27b33565e01664a3b03c57732d5b972eb8fee75c 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -97,6 +97,9 @@ struct GraphCommonOptimizeOptions {
     bool fuse_conv_bias_with_z = false;
     //! whether to enable fast-run profiled winograd opr replace
     bool weight_winograd_transform = false;
+    //! whether to enable weight preprocess; if enabled, it may use more
+    //! memory. Disabled by default.
+    bool weight_preprocess = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
         NCHW4,      ///< compute using NCHW4 tensor format
@@ -127,6 +130,7 @@ struct GraphCommonOptimizeOptions {
         SET(fuse_conv_bias_nonlinearity);
         SET(fuse_conv_bias_with_z);
         SET(weight_winograd_transform);
+        SET(weight_preprocess);
 #undef SET
 #define SET(_trans, _trans_capital)                        \
     GraphCommonOptimizeOptions& enable_##_trans() {        \
diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp
index eda63b4a6e710a221acba71bfb2002b2e4b77252..f4a2abfb7b7bc2adbac4c8473e64dabaa591afd3 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -963,6 +963,9 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
 
 bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
         const cg::OperatorNodeBase& opr) const {
+    if (!opr.owner_graph()->options().graph_opt.weight_preprocess) {
+        return false;
+    }
     if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
         return false;
     if (cg::is_const_var_value(opr.input(1)))
diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp
index aa79d0290d2dd215f0a0d75378141e2210d0b0c1..500efd925b5472c41d69cecf16065cadaa816576 100644
--- a/src/opr/test/dnn/convolution.cpp
+++ b/src/opr/test/dnn/convolution.cpp
@@ -2225,6 +2225,7 @@ protected:
         iw = ih;
         comp_node = CompNode::load("cpux");
         graph = ComputingGraph::make();
+        graph->options().graph_opt.weight_preprocess = is_weight_preprocess();
         TensorShape x_shape{1, ic, ih, iw}, w_shape{oc, ic, fh, fh};
         x_host = std::make_shared<HostTensorND>(comp_node, x_shape);
         auto x = opr::Host2DeviceCopy::make(*graph, x_host);
@@ -2247,6 +2248,8 @@ protected:
 
     void run() { func->execute().wait(); }
 
+    virtual bool is_weight_preprocess() { return true; }
+
     void TearDown() override {
         func.reset();
         // Triggers mock check
@@ -2346,6 +2349,33 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) {
     }
 }
 
+class TestNoWeightPreprocess : public TestWeightPreprocess {
+    bool is_weight_preprocess() override { return false; }
+};
+
+TEST_F(TestNoWeightPreprocess, NoPreprocess) {
+    using ::testing::_;
+    using ::testing::Return;
+    auto& mock = mock_conv();
+
+    MockAlgorithm algo;
+    EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
+            .WillRepeatedly(Return(&algo));
+    EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _))
+            .WillRepeatedly(Return(0));
+    EXPECT_CALL(mock, get_preprocess_workspace_in_bytes(_, _, _))
+            .WillRepeatedly(Return(0));
+
+    {
+        ::testing::InSequence seq;
+        // Preprocess entry points must never be called when the option is off
+        EXPECT_CALL(mock, deduce_preprocessed_filter_layout(_, _, _)).Times(0);
+        EXPECT_CALL(mock, exec_preprocess(_, _, _, _, _)).Times(0);
+        EXPECT_CALL(mock, exec(_, _, _, nullptr, _));
+        run();
+    }
+}
+
 } // anonymous namespace
 
 #endif
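
A minimal sketch of how the new option is switched on from client code, mirroring the unit-test change above (direct field assignment) and the enable_weight_preprocess() setter generated by the SET() macro in cg.h. The header path, the mgb::ComputingGraph alias, and the free function wrapper are illustrative assumptions based on the existing megbrain API, not something this diff adds:

#include "megbrain/graph.h"

// Sketch only: both knobs below are introduced by this patch; everything
// else is the pre-existing graph API as assumed here.
void enable_weight_preprocess_sketch() {
    auto graph = mgb::ComputingGraph::make();

    // Plain field added in cg.h; disabled by default.
    graph->options().graph_opt.weight_preprocess = true;

    // Equivalent chained setter generated by the SET() macro, as called by
    // the new --weight-preprocess branch in mgblar.cpp.
    graph->options().graph_opt.enable_weight_preprocess();

    // With the option on, mixin_allow_weight_preprocess() no longer returns
    // false up front, so a convolution whose weight var carries the
    // PERSISTENT_DEVICE_VALUE flag can have its filter preprocessed once at
    // compile time instead of on every execution.
}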