diff --git a/imperative/src/include/megbrain/imperative/graph_cache.h b/imperative/src/include/megbrain/imperative/graph_cache.h index 073ae924dff26289bce70b234b6dfa0d6c4d66ad..a12293ea3c42ef0a9ee578c4cc8316b769fdcac1 100644 --- a/imperative/src/include/megbrain/imperative/graph_cache.h +++ b/imperative/src/include/megbrain/imperative/graph_cache.h @@ -44,6 +44,12 @@ struct OpMethArgs { if (inputs[i].layout.dtype != rhs.inputs[i].layout.dtype) { return false; } + if (inputs[i].layout.ndim != rhs.inputs[i].layout.ndim) { + return false; + } + if (inputs[i].value.empty() != rhs.inputs[i].value.empty()) { + return false; + } } return extras == rhs.extras; } @@ -57,12 +63,14 @@ template inline size_t OpMethArgs::hash() const { XXHash state; size_t length = 0; - size_t data[1 + 2 * inputs.size() + sizeof...(TExtraArgs)]; + size_t data[1 + 4 * inputs.size() + sizeof...(TExtraArgs)]; auto append = [&](size_t hash) { data[length++] = hash; }; append(op->hash()); for (auto&& i : inputs) { append(mgb::hash(i.layout.dtype.handle())); append(mgb::hash(i.comp_node)); + append(mgb::hash(i.layout.ndim)); + append(mgb::hash(i.value.empty())); } std::apply([&](auto&&... extras) { (append(mgb::hash(extras)), ...); }, extras); mgb_assert(length == sizeof(data) / sizeof(size_t));