提交 af828ca9 编写于 作者: M Megvii Engine Team

feat(mgb/gopt): fix folding conv dimshuffle pass

GitOrigin-RevId: 756878c173787481761a48c8d0550aa409702b3b
上级 c67c4b7d
...@@ -240,11 +240,30 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { ...@@ -240,11 +240,30 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
&readers](OperatorNodeBase* opr) { &readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set; ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set; ThinHashSet<OperatorNodeBase*> reader_set;
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
if (typecvt == nullptr)
return false;
auto in_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
(out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
if (!is_s82s4)
return false;
opr_set.insert(typecvt);
// check reshape // check reshape
auto reshape = try_cast_as_op<opr::Reshape>(opr); auto reshape =
try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr());
if (reshape == nullptr) if (reshape == nullptr)
return false; return false;
opr_set.insert(opr); opr_set.insert(reshape);
for (auto&& i : readers[reshape]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check dimshuffle // check dimshuffle
auto shuffle = auto shuffle =
...@@ -267,27 +286,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { ...@@ -267,27 +286,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
} }
} }
auto typecvt =
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr());
if (typecvt == nullptr)
return false;
auto in_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
(out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
if (!is_s82s4)
return false;
opr_set.insert(typecvt);
for (auto&& i : readers[typecvt]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias // check conv bias
auto conv_bias = auto conv_bias =
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr()); try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr) if (conv_bias == nullptr)
return false; return false;
auto inp_dtype = conv_bias->input(0)->dtype(); auto inp_dtype = conv_bias->input(0)->dtype();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册