diff --git a/src/pass/broadcast_rewrite.cc b/src/pass/broadcast_rewrite.cc index 34aff5956d4dcffc07e46820e82e997e8f316a1d..8472641f965b425ab1f7b937df97c83e0264ed65 100644 --- a/src/pass/broadcast_rewrite.cc +++ b/src/pass/broadcast_rewrite.cc @@ -80,7 +80,7 @@ class BroadcastVecRewriter : public IRMutator { } } - if (secDimFit && dimFit && forLoopFit) { + if (secDimFit && dimFit && forLoopFit && remainDim > 1) { newExtent = GetInt32Const(GetItem(forInfo.ops_, idx).as()->extent) / remainDim; tmpBuffer = VarExpr("tmp_broadcast_" + std::to_string(broadBufferCount++) + "_local_UB", dtype); varList = dstInfo->var_;