提交 8f44d6ea 编写于 作者: M Megvii Engine Team

fix(mgb): fix optpass fail at transform NCHW to NCHW4 when input dtype is float

GitOrigin-RevId: 3c2c68b11cbf2dd28ae61745c26cc276e74c0761
上级 95eb6ae3
...@@ -1415,6 +1415,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var, ...@@ -1415,6 +1415,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
return new_var; return new_var;
} }
//! FIXME: float oprs are not yet supported in NCHW4 format; add support for them in the future.
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
MIDOUT_B("EnableNCHW4Pass::make") MIDOUT_B("EnableNCHW4Pass::make")
auto ret = std::make_unique<EnableNCHW4Pass>(); auto ret = std::make_unique<EnableNCHW4Pass>();
...@@ -1467,6 +1468,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1467,6 +1468,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
auto replace_conv_opr = [trans_nchw4, conv_format]( auto replace_conv_opr = [trans_nchw4, conv_format](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
if (conv_opr.param().format != if (conv_opr.param().format !=
...@@ -1503,6 +1508,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1503,6 +1508,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
src_to_nchw4_mode]( src_to_nchw4_mode](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& batch_conv_bias_opr = auto& batch_conv_bias_opr =
opr->cast_final_safe<opr::BatchConvBiasForward>(); opr->cast_final_safe<opr::BatchConvBiasForward>();
...@@ -1580,6 +1589,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1580,6 +1589,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
src_to_nchw4_mode]( src_to_nchw4_mode](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
if (conv_bias_opr.param().format != if (conv_bias_opr.param().format !=
...@@ -1647,6 +1660,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1647,6 +1660,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}; };
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false; bool has_inp_changed = false;
for (size_t i = 0; i < opr->input().size(); i++) { for (size_t i = 0; i < opr->input().size(); i++) {
...@@ -1691,6 +1708,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1691,6 +1708,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}; };
auto replace_pooling_opr = [](OperatorNodeBase* opr, auto replace_pooling_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
using Param = opr::PoolingForward::Param; using Param = opr::PoolingForward::Param;
using Format = Param::Format; using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
...@@ -1716,6 +1737,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1716,6 +1737,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}; };
auto replace_resize_opr = [](OperatorNodeBase* opr, auto replace_resize_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
using Param = opr::ResizeForward::Param; using Param = opr::ResizeForward::Param;
using Format = Param::Format; using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
...@@ -1738,6 +1763,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1738,6 +1763,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}; };
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, auto replace_warp_perspective_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
using Param = opr::WarpPerspective::Param; using Param = opr::WarpPerspective::Param;
using Format = Param::Format; using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
......
...@@ -2816,7 +2816,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { ...@@ -2816,7 +2816,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) {
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
} }
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt).param().format); find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}}) graph->compile({{y_opt, {}}})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册