提交 f470df4f 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix convbias with no bias when weight preprocess

GitOrigin-RevId: 8837ad539cf4bd9f99a546996ee91cad31a03766
上级 6c4841e8
...@@ -1976,7 +1976,6 @@ typename DnnOp::Algorithm* try_find_any_bias_preprocess_algo( ...@@ -1976,7 +1976,6 @@ typename DnnOp::Algorithm* try_find_any_bias_preprocess_algo(
bool valid = false; bool valid = false;
if (!layouts[1].is_empty()) { if (!layouts[1].is_empty()) {
valid = true; valid = true;
break;
} }
if (valid) { if (valid) {
found.emplace(true); found.emplace(true);
...@@ -2204,6 +2203,8 @@ TEST(TestGraph, FreeBias) { ...@@ -2204,6 +2203,8 @@ TEST(TestGraph, FreeBias) {
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1}); auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1});
auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias); auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
auto w2 = mkcvar("w2", {32, 32, 1, 1});
auto conv2 = opr::ConvBias::make(conv1, w2, param_conv_bias);
Maybe<bool> wp1; Maybe<bool> wp1;
conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>() conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>()
.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) { .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
...@@ -2216,7 +2217,7 @@ TEST(TestGraph, FreeBias) { ...@@ -2216,7 +2217,7 @@ TEST(TestGraph, FreeBias) {
}); });
HostTensorND host_y; HostTensorND host_y;
auto func =graph->compile({make_callback_copy(conv1, host_y)}); auto func =graph->compile({make_callback_copy(conv2, host_y)});
//!flag the no need memory of var //!flag the no need memory of var
func->execute(); func->execute();
//!free the no need memory of var //!free the no need memory of var
......
...@@ -960,11 +960,20 @@ void ConvBiasForward::scn_do_execute_preprocess() { ...@@ -960,11 +960,20 @@ void ConvBiasForward::scn_do_execute_preprocess() {
if (input().size() > 3) { if (input().size() > 3) {
z_layout = input(3)->layout(); z_layout = input(3)->layout();
} }
megdnn_opr()->exec_preprocess( if (input().size() > 2) {
input(0)->layout(), input(1)->dev_tensor().as_megdnn(), megdnn_opr()->exec_preprocess(
input(2)->dev_tensor().as_megdnn(), z_layout, output(0)->layout(), input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
preprocessed_filter(), input(2)->dev_tensor().as_megdnn(), z_layout,
intl::get_megdnn_workspace_from_var(output().back())); output(0)->layout(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
} else {
megdnn::TensorND bias_tensor{nullptr, bias_layout};
megdnn_opr()->exec_preprocess(
input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
bias_tensor, z_layout, output(0)->layout(),
preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
}
//! Flag the weight and bias no use later, which can be freed when no other //! Flag the weight and bias no use later, which can be freed when no other
//! var depend on its dev_value, host_value and shape. //! var depend on its dev_value, host_value and shape.
auto receiver_info_weight = auto receiver_info_weight =
...@@ -975,7 +984,7 @@ void ConvBiasForward::scn_do_execute_preprocess() { ...@@ -975,7 +984,7 @@ void ConvBiasForward::scn_do_execute_preprocess() {
input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED); input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
} }
//! if bias is preprocessd //! if bias is preprocessd
if (input().size() > 3) { if (input().size() > 2) {
auto preprocessed_layouts = auto preprocessed_layouts =
megdnn_opr()->deduce_preprocessed_filter_layout( megdnn_opr()->deduce_preprocessed_filter_layout(
input(0)->layout(), input(1)->layout(), bias_layout, input(0)->layout(), input(1)->layout(), bias_layout,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册