From 58ebb26156c2b3ce448aca1c61b545dbd6e3256a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Apr 2021 14:32:32 +0800 Subject: [PATCH] fix(imperative/tensor): fix ConstTensorCache GitOrigin-RevId: 0767bcfa281dc10969c320c5717dabcb9b60c15f --- imperative/src/impl/physical_tensor.cpp | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index 2fb720109..3a895afe8 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -125,6 +125,7 @@ public: size_t size; BlobPtr blob; + Entry() = default; Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_) : data(new dt_byte[size_]), size(size_), blob(blob_) { memcpy(data.get(), ptr, size); @@ -136,6 +137,8 @@ public: } }; + using KV = std::pair; + bool check(const HostTensorND& hv) { auto&& layout = hv.layout(); auto&& span = layout.span(); @@ -190,7 +193,7 @@ public: } std::mutex mtx; - size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536; + const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536; private: void maybe_collect_g0() { @@ -200,25 +203,37 @@ private: } } void maybe_collect_g1() { - if (g1.size() <= hwm) return; + if (g1.size() < hwm) return; - using KV = std::pair; - std::vector tmp; - tmp.reserve(g1.size()); + tmp.clear(); for (auto&& kv : g1) { tmp.emplace_back(kv.first, std::move(kv.second)); } std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) { return lhs.second.hitcnt > rhs.second.hitcnt; }); + tmp.resize(lwm); g1.clear(); for (auto&& kv : tmp) { kv.second.hitcnt = 0; g1.emplace(std::move(kv)); } } + + // g0: records blobs which have been seen at least once (within a window) + // g0b: backup of g0 + // g1: records the most frequently used blobs which have been seen at least + // twice. 
When `g1.size() == hwm`, it will be refreshed and only the top + // `lwm` frequently used blobs will be kept. std::unordered_set g0, g0b; std::unordered_map g1; + std::vector tmp; + +public: + ConstTensorCache() { + g0.reserve(window), g0b.reserve(window); + g1.reserve(hwm), tmp.reserve(hwm); + } }; struct MultiCNConstTensorCache : CompNodeDepedentObject { -- GitLab