提交 f470df4f 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix convbias with no bias when weight preprocess

GitOrigin-RevId: 8837ad539cf4bd9f99a546996ee91cad31a03766
上级 6c4841e8
......@@ -1976,7 +1976,6 @@ typename DnnOp::Algorithm* try_find_any_bias_preprocess_algo(
bool valid = false;
if (!layouts[1].is_empty()) {
valid = true;
break;
}
if (valid) {
found.emplace(true);
......@@ -2204,6 +2203,8 @@ TEST(TestGraph, FreeBias) {
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1});
auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
auto w2 = mkcvar("w2", {32, 32, 1, 1});
auto conv2 = opr::ConvBias::make(conv1, w2, param_conv_bias);
Maybe<bool> wp1;
conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>()
.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
......@@ -2216,7 +2217,7 @@ TEST(TestGraph, FreeBias) {
});
HostTensorND host_y;
auto func =graph->compile({make_callback_copy(conv1, host_y)});
auto func =graph->compile({make_callback_copy(conv2, host_y)});
//!flag the no need memory of var
func->execute();
//!free the no need memory of var
......
......@@ -960,11 +960,20 @@ void ConvBiasForward::scn_do_execute_preprocess() {
if (input().size() > 3) {
z_layout = input(3)->layout();
}
megdnn_opr()->exec_preprocess(
input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), z_layout, output(0)->layout(),
preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
if (input().size() > 2) {
megdnn_opr()->exec_preprocess(
input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), z_layout,
output(0)->layout(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
} else {
megdnn::TensorND bias_tensor{nullptr, bias_layout};
megdnn_opr()->exec_preprocess(
input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
bias_tensor, z_layout, output(0)->layout(),
preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
}
//! Flag the weight and bias no use later, which can be freed when no other
//! var depend on its dev_value, host_value and shape.
auto receiver_info_weight =
......@@ -975,7 +984,7 @@ void ConvBiasForward::scn_do_execute_preprocess() {
input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
}
//! if bias is preprocessd
if (input().size() > 3) {
if (input().size() > 2) {
auto preprocessed_layouts =
megdnn_opr()->deduce_preprocessed_filter_layout(
input(0)->layout(), input(1)->layout(), bias_layout,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册