From a605f38b2606788c36243ecf8664830de13bbd6b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 2 Aug 2021 15:56:17 +0800
Subject: [PATCH] refactor(opmeth): add OpMethCache struct

GitOrigin-RevId: c1ebe156725236eda08971b0978eab7c93219953
---
 imperative/python/src/common.cpp               |  2 +-
 imperative/python/src/grad.cpp                 | 45 +++-------
 imperative/src/impl/op_def.cpp                 | 33 ++++++-
 imperative/src/impl/op_trait.h                 |  1 +
 imperative/src/impl/proxy_graph_detail.cpp     | 39 +-------
 .../include/megbrain/imperative/graph_cache.h  | 90 +++++++++++++++++++
 6 files changed, 137 insertions(+), 73 deletions(-)
 create mode 100644 imperative/src/include/megbrain/imperative/graph_cache.h

diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp
index 96d64b9e6..32a1cb68c 100644
--- a/imperative/python/src/common.cpp
+++ b/imperative/python/src/common.cpp
@@ -52,7 +52,7 @@ std::string get_default_device() {
 }
 
 void init_common(py::module m) {
-    auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
         .def(py::init())
         .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
         .def_property_readonly("logical_name", [](const CompNode& cn) {
diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index 991d80bd2..ba43bd8ec 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -34,53 +34,36 @@ struct GradSlotWeakPtr {
     size_t idx;
 };
 
-struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<OptimizedBackwardGraph>>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
-
 std::shared_ptr<OptimizedBackwardGraph> make_backward_graph(
         ApplyContext& ctx, const apply_result_t& outputs) {
     // hash
-    static_assert(alignof(size_t) % alignof(bool) == 0);
-    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
-    alignas(alignof(size_t)) std::byte buf[buf_size];
-    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
-    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
-    bool* bool_ptr0 = bool_ptr;
-    *(size_t_ptr++) = ctx.op->hash();
+    using OptimizedBackwardGraphCache = OpMethResultCache<std::shared_ptr<OptimizedBackwardGraph>, SmallVector<bool>>;
+    thread_local OptimizedBackwardGraphCache cache;
+    decltype(cache)::key_t cache_key{ctx.op};
+    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
+    SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
+    input_descs.resize(ctx.nargs);
+    input_requires_grad.resize(ctx.nargs);
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
+        input_descs[i].layout.dtype = ctx.args[i]->dtype();
+        input_descs[i].comp_node = ctx.args[i]->comp_node();
+        input_requires_grad[i] = python::input_requires_grad(ctx, i);
     }
-    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
-               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
-    uint64_t key = XXHash{}.update(buf, buf_size).digest();
-    auto&& iter = backward_graph_cache.find(key);
-    if (iter != backward_graph_cache.end()) {
+    auto iter = cache.find(cache_key);
+    if (iter != cache.end()) {
         return iter->second;
     }
 
     // slow path
-    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
-    SmallVector<bool> input_requires_grad(ctx.nargs, false);
     SmallVector<bool> output_has_grad(outputs.size(), true);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs[i].comp_node = ctx.args[i]->comp_node();
-        inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = python::input_requires_grad(ctx, i);
-    }
     std::shared_ptr<OptimizedBackwardGraph> ret;
     auto bg = OpDef::make_backward_graph(
-            *ctx.op, inputs, input_requires_grad, output_has_grad);
+            *ctx.op, input_descs, input_requires_grad, output_has_grad);
     if (!bg.graph.empty()) {
         ret = std::make_shared<OptimizedBackwardGraph>(bg);
     }
-    backward_graph_cache.emplace(key, ret);
+    cache.emplace(cache_key, ret);
     return ret;
 }
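The grad.cpp change above replaces the hand-rolled XXHash key buffer and the process-global cache with a structured, thread_local OpMethResultCache lookup: build the key in place, try the fast path, and run OpDef::make_backward_graph only on a miss. The following is a minimal self-contained sketch of that fast-path/slow-path memoization pattern, assuming hypothetical stand-ins (Key, KeyHash, compute_expensive, cached_compute) rather than the MegEngine types:

// Minimal sketch (not MegEngine code) of a thread_local result cache with a
// fast path (lookup) and a slow path (compute + insert).
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct Key {
    std::string op_name;            // stands in for the OpDef hash
    std::vector<int> input_dtypes;  // stands in for per-input dtype/comp_node

    bool operator==(const Key& rhs) const {
        return op_name == rhs.op_name && input_dtypes == rhs.input_dtypes;
    }
};

struct KeyHash {
    std::size_t operator()(const Key& k) const {
        std::size_t h = std::hash<std::string>{}(k.op_name);
        for (int d : k.input_dtypes) {
            // simple hash combine; the patch uses XXHash over a flat buffer
            h = h * 31 + std::hash<int>{}(d);
        }
        return h;
    }
};

// the "slow path": stands in for OpDef::make_backward_graph
std::string compute_expensive(const Key& key) {
    return key.op_name + "-backward";
}

std::string cached_compute(const Key& key) {
    // one cache per thread, so the fast path needs no locking
    thread_local std::unordered_map<Key, std::string, KeyHash> cache;
    auto iter = cache.find(key);
    if (iter != cache.end()) {
        return iter->second;               // fast path: reuse the memoized result
    }
    auto result = compute_expensive(key);  // slow path
    cache.emplace(key, result);
    return result;
}

int main() {
    Key k{"Elemwise", {0, 0}};
    return cached_compute(k) == cached_compute(k) ? 0 : 1;  // second call hits the cache
}

Making the cache thread_local also removes the need for the old CompNodeDepedentObject-based global that had to be cleared on comp node finalization from other threads.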
diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp
index 2f856c02e..98ddaa39c 100644
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -85,7 +85,14 @@ EncodedSubraph OpDef::make_backward_graph(
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
+    using BackwardGraphCache = OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>;
+    thread_local BackwardGraphCache cache;
+    decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs, {input_requires_grad, output_has_grad}};
+    auto iter = cache.find(cache_key);
+    if (iter == cache.end()) {
+        iter = cache.insert({cache_key, def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad)}).first;
+    }
+    return iter->second;
 }
 
 std::vector<std::pair<const char*, std::string>> OpDef::props(
@@ -94,7 +101,7 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
 }
 
 std::string OpDef::to_string() const {
-    std::string builder = "{";
+    std::string builder = trait()->make_name(*this) + "{";
     for (auto&& [name, value]: props(*this)) {
         builder += name;
         builder += ": ";
@@ -170,7 +177,7 @@ std::string Subgraph::repr() const {
         if (auto* p = op->try_cast_final<OprAttr>()) {
             buf << p->type;
         } else {
-            buf << op->dyn_typeinfo()->name;
+            buf << op->make_name();
         }
         for (size_t i : ins) {
             buf << " ";
@@ -196,6 +203,26 @@ std::string Subgraph::repr() const {
     return buf.str();
 }
 
+bool Subgraph::is_single() const {
+    if (exprs.size() != 1) {
+        return false;
+    }
+    auto& expr = exprs.at(0);
+    return expr.inputs == inputs && expr.outputs == outputs;
+}
+
+std::shared_ptr<OpDef> Subgraph::as_single() const {
+    if (is_single()) {
+        return exprs.at(0).op;
+    } else {
+        return nullptr;
+    }
+}
+
+bool Subgraph::operator==(const Subgraph& rhs) const {
+    mgb_assert(false, "Not Implemented");
+}
+
 } // namespace imperative
 } // namespace mgb
diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h
index 850f23466..7c00ae0d5 100644
--- a/imperative/src/impl/op_trait.h
+++ b/imperative/src/impl/op_trait.h
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/graph_cache.h"
 
 namespace mgb {
 namespace imperative {
diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp
index b49e67206..7546d2901 100644
--- a/imperative/src/impl/proxy_graph_detail.cpp
+++ b/imperative/src/impl/proxy_graph_detail.cpp
@@ -113,49 +113,12 @@ void execute(const OpDef& def,
 //     return graph->infer_output_attrs_fallible(def, inputs);
 // }
 
-namespace {
-
-size_t get_backward_graph_hash_key(const OpDef& def,
-        const SmallVector<LogicalTensorDesc>& inputs,
-        const SmallVector<bool>& input_requires_grad,
-        const SmallVector<bool>& output_has_grad) {
-    XXHash state;
-    size_t length = 0, data[3 + 2 * inputs.size()];
-    data[length ++] = def.hash();
-    for (auto &&i : inputs) {
-        data[length ++] = mgb::hash(i.layout.dtype.handle());
-        data[length ++] = mgb::hash(i.comp_node);
-    }
-    data[length ++] = mgb::hash(input_requires_grad);
-    data[length ++] = mgb::hash(output_has_grad);
-    mgb_assert(length == 3 + 2 * inputs.size());
-    state.update(data, length * sizeof(size_t));
-    return state.digest();
-}
-
-struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
-
-} // anonymous namespace
-
 EncodedSubraph make_backward_graph(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
-    auto&& iter = backward_graph_cache.find(hash_key);
-    if (iter != backward_graph_cache.end()) {
-        return iter->second;
-    }
-    auto&& graph = ProxyGraph::get_default_graph();
-    auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    backward_graph_cache.emplace(hash_key, res);
-    return res;
+    return ProxyGraph::get_default_graph()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
 }
 
 } // namespace proxy_graph_detail
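With caching now done inside OpDef::make_backward_graph, proxy_graph_detail::make_backward_graph above shrinks to a plain forwarding call. The find-then-insert(...).first idiom in op_def.cpp invokes the expensive trait builder only on a miss, at the cost of hashing the key twice on that path. A minimal sketch of the same idiom, with hypothetical names (build_graph, memoized_build) rather than the MegEngine API:

// Minimal sketch (not MegEngine code) of the insert-if-absent idiom above.
#include <string>
#include <unordered_map>

// expensive builder, run only on a cache miss (stands in for the trait call)
std::string build_graph(int key) {
    return "graph-" + std::to_string(key);
}

const std::string& memoized_build(int key) {
    thread_local std::unordered_map<int, std::string> cache;
    auto iter = cache.find(key);
    if (iter == cache.end()) {
        // insert() returns {iterator, inserted}; .first is the iterator either way
        iter = cache.insert({key, build_graph(key)}).first;
    }
    return iter->second;
}

int main() {
    return memoized_build(7) == memoized_build(7) ? 0 : 1;  // second call is a cache hit
}

The two-step find/insert is preferred here over try_emplace because the builder's result is an argument to the insertion, so a single-call form would compute the graph even when the key is already cached.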
diff --git a/imperative/src/include/megbrain/imperative/graph_cache.h b/imperative/src/include/megbrain/imperative/graph_cache.h
new file mode 100644
index 000000000..153e60ab7
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/graph_cache.h
@@ -0,0 +1,90 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/graph_cache.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/subgraph.h"
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb {
+namespace imperative {
+
+template <typename... TExtraArgs>
+struct OpMethArgs {
+    std::shared_ptr<OpDef> op;
+    SmallVector<LogicalTensorDesc> inputs;
+    std::tuple<TExtraArgs...> extras;
+
+    size_t hash() const;
+
+    bool operator==(const OpMethArgs& rhs) const {
+        if (bool(op) ^ bool(rhs.op)) {
+            return false;
+        }
+        if (op && rhs.op && !op->is_same(*rhs.op)) {
+            return false;
+        }
+        if (inputs.size() != rhs.inputs.size()) {
+            return false;
+        }
+        size_t nr_inputs = inputs.size();
+        for (size_t i = 0; i < nr_inputs; ++i) {
+            if (inputs[i].comp_node != rhs.inputs[i].comp_node) {
+                return false;
+            }
+            if (inputs[i].layout.dtype != rhs.inputs[i].layout.dtype) {
+                return false;
+            }
+        }
+        return extras == rhs.extras;
+    }
+
+    struct hash_t {
+        size_t operator()(const OpMethArgs& key) const {
+            return key.hash();
+        }
+    };
+};
+
+template <typename... TExtraArgs>
+inline size_t OpMethArgs<TExtraArgs...>::hash() const {
+    XXHash state;
+    size_t length = 0;
+    size_t data[1 + 2 * inputs.size() + sizeof...(TExtraArgs)];
+    auto append = [&](size_t hash) {
+        data[length++] = hash;
+    };
+    append(op->hash());
+    for (auto &&i : inputs) {
+        append(mgb::hash(i.layout.dtype.handle()));
+        append(mgb::hash(i.comp_node));
+    }
+    std::apply([&](auto&&... extras) {
+        (append(mgb::hash(extras)), ...);
+    }, extras);
+    mgb_assert(length == sizeof(data) / sizeof(size_t));
+    state.update(data, sizeof(data));
+    return state.digest();
+}
+
+template <typename TValue, typename... TExtraArgs>
+struct OpMethResultCache : std::unordered_map<OpMethArgs<TExtraArgs...>, TValue, typename OpMethArgs<TExtraArgs...>::hash_t>, CompNodeDepedentObject {
+    std::shared_ptr<void> on_comp_node_finalize() override {
+        static_cast<std::unordered_map<OpMethArgs<TExtraArgs...>, TValue, typename OpMethArgs<TExtraArgs...>::hash_t>*>(this)->clear();
+        // clear();
+        return {};
+    }
+
+    using key_t = OpMethArgs<TExtraArgs...>;
+};
+
+} // namespace imperative
+} // namespace mgb
--
GitLab
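OpMethArgs::hash() above flattens the op hash, the per-input dtype/comp_node hashes, and every element of the variadic extras tuple into one array before feeding it to XXHash, using std::apply plus a fold expression. A small self-contained sketch of the same technique follows; VarKey and its members are hypothetical, and std::hash with a plain mix stands in for mgb::hash and XXHash:

// Minimal sketch (not MegEngine code): hashing a key with a variadic extras
// tuple by flattening every component into a buffer, as OpMethArgs::hash() does.
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

template <typename... TExtras>
struct VarKey {
    std::string op;                 // stands in for the OpDef
    std::vector<int> inputs;        // stands in for the input descriptors
    std::tuple<TExtras...> extras;  // e.g. requires_grad / has_grad flags

    std::size_t hash() const {
        std::vector<std::size_t> data;
        data.reserve(1 + inputs.size() + sizeof...(TExtras));
        auto append = [&](std::size_t h) { data.push_back(h); };
        append(std::hash<std::string>{}(op));
        for (int i : inputs) {
            append(std::hash<int>{}(i));
        }
        // fold expression over the tuple members, mirroring the std::apply call above
        std::apply([&](auto&&... e) {
            (append(std::hash<std::decay_t<decltype(e)>>{}(e)), ...);
        }, extras);
        // plain hash combine standing in for XXHash over the flat buffer
        std::size_t seed = 0;
        for (std::size_t h : data) {
            seed ^= h + 0x9e3779b97f4a7c15ull + (seed << 6) + (seed >> 2);
        }
        return seed;
    }
};

int main() {
    VarKey<bool, bool> a{"Elemwise", {1, 2}, {true, false}};
    VarKey<bool, bool> b = a;
    assert(a.hash() == b.hash());  // equal keys produce equal hashes
    return 0;
}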