提交 76c79d79 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix ConvBias not passing on prep_filter

GitOrigin-RevId: 0dc9f9d1331d48f6e8198472332316ef273f80e7
上级 498f7b2e
...@@ -1039,10 +1039,7 @@ void ConvolutionForward::init_output_format() { ...@@ -1039,10 +1039,7 @@ void ConvolutionForward::init_output_format() {
} }
void ConvolutionForward::scn_do_execute() { void ConvolutionForward::scn_do_execute() {
if (input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) && update_preprocessed_filter();
cg::is_const_var_value(input(1))) {
update_preprocessed_filter();
}
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
input(1)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
...@@ -1606,8 +1603,7 @@ void ConvBiasForward::scn_do_execute() { ...@@ -1606,8 +1603,7 @@ void ConvBiasForward::scn_do_execute() {
megdnn::TensorND z_tensor{nullptr, z_layout}; megdnn::TensorND z_tensor{nullptr, z_layout};
mo->exec(inp[0]->dev_tensor().as_megdnn(), mo->exec(inp[0]->dev_tensor().as_megdnn(),
inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor, inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
nullptr,
intl::get_megdnn_workspace_from_var(output().back())); intl::get_megdnn_workspace_from_var(output().back()));
} else if (inp.size() == 3) { } else if (inp.size() == 3) {
...@@ -1619,8 +1615,7 @@ void ConvBiasForward::scn_do_execute() { ...@@ -1619,8 +1615,7 @@ void ConvBiasForward::scn_do_execute() {
mo->exec(inp[0]->dev_tensor().as_megdnn(), mo->exec(inp[0]->dev_tensor().as_megdnn(),
inp[1]->dev_tensor().as_megdnn(), inp[1]->dev_tensor().as_megdnn(),
inp[2]->dev_tensor().as_megdnn(), z_tensor, inp[2]->dev_tensor().as_megdnn(), z_tensor,
output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
nullptr,
intl::get_megdnn_workspace_from_var(output().back())); intl::get_megdnn_workspace_from_var(output().back()));
} else { } else {
mgb_assert(inp.size() == 4); mgb_assert(inp.size() == 4);
...@@ -1628,8 +1623,7 @@ void ConvBiasForward::scn_do_execute() { ...@@ -1628,8 +1623,7 @@ void ConvBiasForward::scn_do_execute() {
inp[1]->dev_tensor().as_megdnn(), inp[1]->dev_tensor().as_megdnn(),
inp[2]->dev_tensor().as_megdnn(), inp[2]->dev_tensor().as_megdnn(),
inp[3]->dev_tensor().as_megdnn(), inp[3]->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
nullptr,
intl::get_megdnn_workspace_from_var(output().back())); intl::get_megdnn_workspace_from_var(output().back()));
} }
} }
...@@ -2389,4 +2383,4 @@ void BatchConvBiasForward::init_output_format() { ...@@ -2389,4 +2383,4 @@ void BatchConvBiasForward::init_output_format() {
#undef IMPL_CONV #undef IMPL_CONV
#undef MGB_FOREACH_FASTRUN_OPR #undef MGB_FOREACH_FASTRUN_OPR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册