提交 add3a1bc 编写于 作者: M Megvii Engine Team

feat(mgb/opr): add weight preprocess option

GitOrigin-RevId: 9d83a174fad2c4d6a1bca86c8597c8b6b4544376
上级 ee2e2b3c
......@@ -194,6 +194,26 @@ R"__usage__(
Execute operators with kernels implemented in MegDNN with CHWN4 tensor format. Can only be used
on Nvidia GPUs, whose compute capability is above 6.1.
)__usage__"
R"__usage__(
--enable-nchw44
Execute operators with kernels implemented in MegDNN with NCHW44 tensor format. This can only
be used on arm of armv7 and arm64, support data tyep of float32, qint8 and int8x8x16.
)__usage__"
R"__usage__(
--enable-nhw88
Execute operators with kernels implemented in MegDNN with NCHW88 tensor format. This can only
be used on x86 with data type float.
)__usage__"
R"__usage__(
--enable-nhw44-dot
Execute operators with kernels implemented in MegDNN with NCHW44-DOT tensor format. This Can
only be used on arm32 and arm64 with dot-product supported, and only support qint8 model
)__usage__"
R"__usage__(
--weight-preprocess
Execute operators with weight preprocess, which can optimize the operator execution time with
algo of winograd, im2col ,etc., but it may consume more memory.
)__usage__"
;
......@@ -1226,6 +1246,11 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.graph_opt.weight_winograd_transform = true;
continue;
}
if (!strcmp(argv[i], "--weight-preprocess")) {
mgb_log_warn("enable weight-preprocess optimization");
graph_opt.graph_opt.enable_weight_preprocess();
continue;
}
fprintf(stderr, "invalid arg: %s\n", argv[i]);
ret.args_parse_ret = -1;
......
......@@ -97,6 +97,9 @@ struct GraphCommonOptimizeOptions {
bool fuse_conv_bias_with_z = false;
//! whether to enable fast-run profiled winograd opr replace
bool weight_winograd_transform = false;
//! whether to enable weight preprocess, if enabled it may use more
//! memory, default disable now
bool weight_preprocess = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
......@@ -127,6 +130,7 @@ struct GraphCommonOptimizeOptions {
SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z);
SET(weight_winograd_transform);
SET(weight_preprocess);
#undef SET
#define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \
......
......@@ -963,6 +963,9 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
const cg::OperatorNodeBase& opr) const {
if (!opr.owner_graph()->options().graph_opt.weight_preprocess) {
return false;
}
if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
return false;
if (cg::is_const_var_value(opr.input(1)))
......
......@@ -2225,6 +2225,7 @@ protected:
iw = ih;
comp_node = CompNode::load("cpux");
graph = ComputingGraph::make();
graph->options().graph_opt.weight_preprocess = is_weight_preprocess();
TensorShape x_shape{1, ic, ih, iw}, w_shape{oc, ic, fh, fh};
x_host = std::make_shared<HostTensorND>(comp_node, x_shape);
auto x = opr::Host2DeviceCopy::make(*graph, x_host);
......@@ -2247,6 +2248,8 @@ protected:
void run() { func->execute().wait(); }
virtual bool is_weight_preprocess() { return true; }
void TearDown() override {
func.reset();
// Triggers mock check
......@@ -2346,6 +2349,33 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) {
}
}
class TestNoWeightPreprocess : public TestWeightPreprocess {
bool is_weight_preprocess() override { return false; }
};
TEST_F(TestNoWeightPreprocess, NoPreprocess) {
using ::testing::_;
using ::testing::Return;
auto& mock = mock_conv();
MockAlgorithm algo;
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
.WillRepeatedly(Return(&algo));
EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _))
.WillRepeatedly(Return(0));
EXPECT_CALL(mock, get_preprocess_workspace_in_bytes(_, _, _))
.WillRepeatedly(Return(0));
{
::testing::InSequence seq;
// Return empty preprocess filters, indicating no need to preprocess
EXPECT_CALL(mock, deduce_preprocessed_filter_layout(_, _, _)).Times(0);
EXPECT_CALL(mock, exec_preprocess(_, _, _, _, _)).Times(0);
EXPECT_CALL(mock, exec(_, _, _, nullptr, _));
run();
}
}
} // anonymous namespace
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册