提交 8f44d6ea 编写于 作者: M Megvii Engine Team

fix(mgb): fix optpass failure when transforming NCHW to NCHW4 with a float input dtype

GitOrigin-RevId: 3c2c68b11cbf2dd28ae61745c26cc276e74c0761
上级 95eb6ae3
......@@ -1415,6 +1415,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
return new_var;
}
//! FIXME: Float oprs do not support NCHW4 yet. Please add support in the future.
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
MIDOUT_B("EnableNCHW4Pass::make")
auto ret = std::make_unique<EnableNCHW4Pass>();
......@@ -1467,6 +1468,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
auto replace_conv_opr = [trans_nchw4, conv_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
if (conv_opr.param().format !=
......@@ -1503,6 +1508,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
auto& batch_conv_bias_opr =
opr->cast_final_safe<opr::BatchConvBiasForward>();
......@@ -1580,6 +1589,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
if (conv_bias_opr.param().format !=
......@@ -1647,6 +1660,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
for (size_t i = 0; i < opr->input().size(); i++) {
......@@ -1691,6 +1708,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_pooling_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
using Param = opr::PoolingForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
......@@ -1716,6 +1737,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_resize_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
using Param = opr::ResizeForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
......@@ -1738,6 +1763,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
using Param = opr::WarpPerspective::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
......@@ -3127,4 +3156,4 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const {
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -2816,7 +2816,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) {
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册