Commit a605f38b authored by Megvii Engine Team

refactor(opmeth): add OpMethCache struct

GitOrigin-RevId: c1ebe156725236eda08971b0978eab7c93219953
Parent 0213dbe5
@@ -52,7 +52,7 @@ std::string get_default_device() {
 }
 void init_common(py::module m) {
-    auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
+    auto PyCompNode = py::class_<CompNode>(m, "CompNode")
         .def(py::init())
         .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
         .def_property_readonly("logical_name", [](const CompNode& cn) {
...
@@ -34,53 +34,36 @@ struct GradSlotWeakPtr {
     size_t idx;
 };
-struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
 std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
         ApplyContext& ctx, const apply_result_t& outputs) {
     // hash
-    static_assert(alignof(size_t) % alignof(bool) == 0);
-    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
-    alignas(alignof(size_t)) std::byte buf[buf_size];
-    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
-    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
-    bool* bool_ptr0 = bool_ptr;
-    *(size_t_ptr++) = ctx.op->hash();
+    using OptimizedBackwardGraphCache = OpMethResultCache<std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
+    thread_local OptimizedBackwardGraphCache cache;
+    decltype(cache)::key_t cache_key{ctx.op};
+    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
+    SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
+    input_descs.resize(ctx.nargs);
+    input_requires_grad.resize(ctx.nargs);
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
+        input_descs[i].layout.dtype = ctx.args[i]->dtype();
+        input_descs[i].comp_node = ctx.args[i]->comp_node();
+        input_requires_grad[i] = python::input_requires_grad(ctx, i);
     }
-    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
-               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
-    uint64_t key = XXHash{}.update(buf, buf_size).digest();
-    auto&& iter = backward_graph_cache.find(key);
-    if (iter != backward_graph_cache.end()) {
+    auto iter = cache.find(cache_key);
+    if (iter != cache.end()) {
         return iter->second;
     }
     // slow path
-    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
-    SmallVector<bool> input_requires_grad(ctx.nargs, false);
     SmallVector<bool> output_has_grad(outputs.size(), true);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs[i].comp_node = ctx.args[i]->comp_node();
-        inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = python::input_requires_grad(ctx, i);
-    }
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
     auto bg = OpDef::make_backward_graph(
-            *ctx.op, inputs, input_requires_grad, output_has_grad);
+            *ctx.op, input_descs, input_requires_grad, output_has_grad);
     if (!bg.graph.empty()) {
         ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }
-    backward_graph_cache.emplace(key, ret);
+    cache.emplace(cache_key, ret);
     return ret;
 }
...
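The hunk above swaps the hand-rolled byte-buffer hashing and the process-global backward_graph_cache for a structured, thread-local OpMethResultCache keyed on the op, its input descriptors, and a requires-grad vector. A minimal self-contained sketch of the same lookup-or-compute pattern, using hypothetical FakeOpKey/FakeGraph stand-ins rather than the real MegEngine types:

#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the real op / backward-graph types.
struct FakeOpKey {
    std::string op_name;
    std::string input_dtypes;  // e.g. "f32,f32"
    bool operator==(const FakeOpKey& rhs) const {
        return op_name == rhs.op_name && input_dtypes == rhs.input_dtypes;
    }
};

struct FakeOpKeyHash {
    size_t operator()(const FakeOpKey& k) const {
        // Cheap combine for the sketch; the real cache hashes all key fields with XXHash.
        return std::hash<std::string>{}(k.op_name) ^ (std::hash<std::string>{}(k.input_dtypes) << 1);
    }
};

struct FakeGraph {};  // stands in for OptimizedBackwardGraphResult

std::shared_ptr<FakeGraph> make_graph_cached(const FakeOpKey& key) {
    // One cache per thread, so the hot path needs no locking.
    thread_local std::unordered_map<FakeOpKey, std::shared_ptr<FakeGraph>, FakeOpKeyHash> cache;
    auto iter = cache.find(key);
    if (iter != cache.end()) {
        return iter->second;  // fast path: reuse the graph built earlier
    }
    auto result = std::make_shared<FakeGraph>();  // slow path: build once, then memoize
    cache.emplace(key, result);
    return result;
}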
@@ -85,7 +85,14 @@ EncodedSubraph OpDef::make_backward_graph(
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
+    using BackwardGraphCache = OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>;
+    thread_local BackwardGraphCache cache;
+    decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs, {input_requires_grad, output_has_grad}};
+    auto iter = cache.find(cache_key);
+    if (iter == cache.end()) {
+        iter = cache.insert({cache_key, def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad)}).first;
+    }
+    return iter->second;
 }
 std::vector<std::pair<const char*, std::string>> OpDef::props(
@@ -94,7 +101,7 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
 }
 std::string OpDef::to_string() const {
-    std::string builder = "{";
+    std::string builder = trait()->make_name(*this) + "{";
     for (auto&& [name, value]: props(*this)) {
         builder += name;
         builder += ": ";
@@ -170,7 +177,7 @@ std::string Subgraph::repr() const {
         if (auto* p = op->try_cast_final<OprAttr>()) {
             buf << p->type;
         } else {
-            buf << op->dyn_typeinfo()->name;
+            buf << op->make_name();
         }
         for (size_t i : ins) {
             buf << " ";
@@ -196,6 +203,26 @@ std::string Subgraph::repr() const {
     return buf.str();
 }
+bool Subgraph::is_single() const {
+    if (exprs.size() != 1) {
+        return false;
+    }
+    auto& expr = exprs.at(0);
+    return expr.inputs == inputs && expr.outputs == outputs;
+}
+
+std::shared_ptr<OpDef> Subgraph::as_single() const {
+    if (is_single()) {
+        return exprs.at(0).op;
+    } else {
+        return nullptr;
+    }
+}
+
+bool Subgraph::operator==(const Subgraph& rhs) const {
+    mgb_assert(false, "Not Implemented");
+}
 } // namespace imperative
 } // namespace mgb
...
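The new is_single/as_single helpers let callers detect a subgraph that is nothing more than a thin wrapper around one op and unwrap it. A standalone analogue of the same check, with hypothetical Toy* types rather than the real imperative::Subgraph:

#include <cstddef>
#include <memory>
#include <vector>

struct ToyOp {};

struct ToyExpr {
    std::shared_ptr<ToyOp> op;
    std::vector<size_t> inputs;
    std::vector<size_t> outputs;
};

struct ToySubgraph {
    std::vector<size_t> inputs;
    std::vector<size_t> outputs;
    std::vector<ToyExpr> exprs;

    // True when the graph is exactly one expr wired straight from the
    // subgraph inputs to the subgraph outputs.
    bool is_single() const {
        if (exprs.size() != 1) {
            return false;
        }
        const ToyExpr& expr = exprs.front();
        return expr.inputs == inputs && expr.outputs == outputs;
    }

    // Collapse such a trivial wrapper back to its underlying op; null otherwise.
    std::shared_ptr<ToyOp> as_single() const {
        return is_single() ? exprs.front().op : nullptr;
    }
};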
@@ -12,6 +12,7 @@
 #pragma once
 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/graph_cache.h"
 namespace mgb {
 namespace imperative {
...
@@ -113,49 +113,12 @@ void execute(const OpDef& def,
 //     return graph->infer_output_attrs_fallible(def, inputs);
 // }
-namespace {
-size_t get_backward_graph_hash_key(const OpDef& def,
-        const SmallVector<LogicalTensorDesc>& inputs,
-        const SmallVector<bool>& input_requires_grad,
-        const SmallVector<bool>& output_has_grad) {
-    XXHash state;
-    size_t length = 0, data[3 + 2 * inputs.size()];
-    data[length ++] = def.hash();
-    for (auto &&i : inputs) {
-        data[length ++] = mgb::hash(i.layout.dtype.handle());
-        data[length ++] = mgb::hash(i.comp_node);
-    }
-    data[length ++] = mgb::hash(input_requires_grad);
-    data[length ++] = mgb::hash(output_has_grad);
-    mgb_assert(length == 3 + 2 * inputs.size());
-    state.update(data, length * sizeof(size_t));
-    return state.digest();
-}
-struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
-} // anonymous namespace
 EncodedSubraph
 make_backward_graph(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
-    auto&& iter = backward_graph_cache.find(hash_key);
-    if (iter != backward_graph_cache.end()) {
-        return iter->second;
-    }
-    auto&& graph = ProxyGraph::get_default_graph();
-    auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    backward_graph_cache.emplace(hash_key, res);
-    return res;
+    return ProxyGraph::get_default_graph()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
 }
 } // namespace proxy_graph_detail
...
/**
 * \file imperative/src/include/megbrain/imperative/graph_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/op_def.h"
namespace mgb {
namespace imperative {
template <typename... TExtraArgs>
struct OpMethArgs {
std::shared_ptr<OpDef> op;
SmallVector<LogicalTensorDesc> inputs;
std::tuple<TExtraArgs...> extras;
size_t hash() const;
bool operator==(const OpMethArgs& rhs) const {
if (bool(op) ^ bool(rhs.op)) {
return false;
}
if (op && rhs.op && !op->is_same(*rhs.op)) {
return false;
}
if (inputs.size() != rhs.inputs.size()) {
return false;
}
size_t nr_inputs = inputs.size();
for (size_t i = 0; i < nr_inputs; ++i) {
if (inputs[i].comp_node != rhs.inputs[i].comp_node) {
return false;
}
if (inputs[i].layout.dtype != rhs.inputs[i].layout.dtype) {
return false;
}
}
return extras == rhs.extras;
}
struct hash_t {
size_t operator()(const OpMethArgs& key) const {
return key.hash();
}
};
};
template <typename... TExtraArgs>
inline size_t OpMethArgs<TExtraArgs...>::hash() const {
XXHash state;
size_t length = 0;
size_t data[1 + 2 * inputs.size() + sizeof...(TExtraArgs)];
auto append = [&](size_t hash) {
data[length++] = hash;
};
append(op->hash());
for (auto &&i : inputs) {
append(mgb::hash(i.layout.dtype.handle()));
append(mgb::hash(i.comp_node));
}
std::apply([&](auto&&... extras){
(append(mgb::hash(extras)), ...);
}, extras);
mgb_assert(length == sizeof(data) / sizeof(size_t));
state.update(data, sizeof(data));
return state.digest();
}
template <typename TValue, typename... TExtraArgs>
struct OpMethResultCache : std::unordered_map<OpMethArgs<TExtraArgs...>, TValue, typename OpMethArgs<TExtraArgs...>::hash_t>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override {
static_cast<std::unordered_map<OpMethArgs<TExtraArgs...>, TValue, typename OpMethArgs<TExtraArgs...>::hash_t>*>(this)->clear();
// clear();
return {};
}
using key_t = OpMethArgs<TExtraArgs...>;
};
} // namespace imperative
} // namespace mgb
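OpMethArgs::hash above flattens the op hash, the per-input dtype and comp_node hashes, and every extra tuple element into one size_t buffer before feeding it to XXHash. A self-contained sketch of the same std::apply fold over a tuple of extras, using std::hash plus a boost-style combine instead of MegEngine's XXHash (all names here are illustrative):

#include <cstddef>
#include <functional>
#include <tuple>
#include <type_traits>
#include <vector>

// Boost-style combine: fold one hash value into an accumulated seed.
inline void hash_combine(size_t& seed, size_t value) {
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Hash a fixed prefix (op hash + per-input hashes) plus an arbitrary tuple of
// extra arguments, mirroring how OpMethArgs appends the hashes of its extras.
template <typename... TExtras>
size_t hash_key(size_t op_hash, const std::vector<size_t>& input_hashes,
                const std::tuple<TExtras...>& extras) {
    size_t seed = op_hash;
    for (size_t h : input_hashes) {
        hash_combine(seed, h);
    }
    std::apply(
            [&](const auto&... extra) {
                (hash_combine(seed, std::hash<std::decay_t<decltype(extra)>>{}(extra)), ...);
            },
            extras);
    return seed;
}

// Usage: extras of different types participate in the same key.
// size_t k = hash_key(42, {1, 2, 3}, std::make_tuple(std::string("f32"), true));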