Commit fea46ea9 authored by Megvii Engine Team

perf(imperative): add opr cache for apply_on_physical_tensor

GitOrigin-RevId: fc5d5fb34d2379905e1130d9b0572ba596fb9fe4
Parent ea4e6ab9
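In short: previously every call to `apply_on_physical_tensor` rebuilt a single-opr proxy graph from scratch; the hunks below cache the `MiniGraph` in a hash table keyed by the `OpDef` hash together with each input's dtype and comp node. The sketch below shows only the key layout; `def_hash` and the per-input hashes are hypothetical stand-ins for `OpDef::hash()` and `mgb::hash(...)`:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Stand-in key builder: one word for the operator, two words per input.
std::vector<size_t> make_key_buf(
        size_t def_hash, const std::vector<std::pair<size_t, size_t>>& inputs) {
    std::vector<size_t> buf;
    buf.reserve(2 * inputs.size() + 1);
    buf.push_back(def_hash);            // which operator
    for (auto&& [dtype_hash, cn_hash] : inputs) {
        buf.push_back(dtype_hash);      // input dtype
        buf.push_back(cn_hash);         // input comp node
    }
    return buf;  // digested with XXHash for the bucket key, and also stored
                 // verbatim in the MiniGraph for exact comparison (is_same_buf)
}
```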
......@@ -535,8 +535,9 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
->m_node->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf(
"ambiguous device: %s vs %s", cn.to_string().c_str(),
cn1.to_string().c_str()));
"ambiguous device: %s (from %s) vs %s (from %s)",
cn.to_string().c_str(), cn.to_string_logical().c_str(),
cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
}
}
}
......
......@@ -11,8 +11,9 @@
#include "megbrain/graph/operator_node.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/ops/autogen.h"
#include "../blob_manager_impl.h"
#include "./common.h"
#include "./proxy_graph_base.h"
......@@ -80,6 +81,20 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>;
template <typename T>
TensorAdaptor(T*) -> TensorAdaptor<T, void>;
SmallVector<Tensor*> to_raw_ptr_array(
const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
SmallVector<Tensor*> ret;
for (auto&& i : inputs) {
mgb_assert(i);
ret.push_back(i.get());
if (ensure_storage) {
// apply lazy allocation
i->blob()->storage();
}
}
return ret;
}
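For reference, a self-contained sketch of what the helper above relies on: ownership stays with the caller's `TensorPtr` vector while callees receive cheap raw pointers, and touching `blob()->storage()` forces the lazily allocated blob into existence before the opr dereferences it. `Blob`/`Tensor` here are toy stand-ins, not the real classes:

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct Blob {
    std::shared_ptr<float> data;
    // storage() allocates on first touch, mimicking the lazy blob
    std::shared_ptr<float>& storage() {
        if (!data) data = std::make_shared<float>(0.f);
        return data;
    }
};
struct Tensor {
    Blob m_blob;
    Blob* blob() { return &m_blob; }
};
using TensorPtr = std::shared_ptr<Tensor>;

std::vector<Tensor*> to_raw_ptrs(const std::vector<TensorPtr>& inputs) {
    std::vector<Tensor*> ret;
    for (auto&& i : inputs) {
        assert(i);
        ret.push_back(i.get());
        i->blob()->storage();  // force lazy allocation before the opr runs
    }
    return ret;
}
```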
// single opr graph, for static inference and execution
// contains static inference descs
class ProxyGraph::MiniGraph {
......@@ -146,6 +161,9 @@ protected:
virtual const DeviceTensorND* infer_value_fallible(VarNode*) { mgb_assert(0); }
};
size_t buf_size;
SmallVector<size_t> hash_buf;
OperatorNodeBase* m_opr = nullptr;
SmallVector<std::unique_ptr<OperatorNodeBase>> opr_ref_keeper;
......@@ -194,6 +212,7 @@ protected:
return nullptr;
}
}
return &storage.value();
} else {
auto& value = tensor.value();
return value.shape_valid() ? &value : nullptr;
......@@ -203,8 +222,10 @@ protected:
public:
template <typename I, typename G>
MiniGraph(G& graph, const OpDef& opdef, const I& inputs)
: input_value_storage(inputs.size()) {
MiniGraph(
G& graph, const OpDef& opdef, const I& inputs, const size_t* hash_buf_,
const size_t buf_size_)
: buf_size(buf_size_), input_value_storage(inputs.size()) {
mgb_assert(!m_opr);
auto _ = graph.scoped_attach(this);
cg::VarNodeArray vinputs(inputs.size());
......@@ -222,7 +243,8 @@ public:
}
m_opr->init_output_static_infer_desc();
// fix permuted input
// fix permuted input: the order of m_opr->input() and vinputs may
// differ; input_remap keeps the index mapping from m_opr->input() to vinputs
input_remap.reserve(m_opr->input().size());
for (auto* v : m_opr->input()) {
auto [found, i] = find_index(vinputs, v);
......@@ -248,6 +270,23 @@ public:
mgb_assert(found);
output_remap.push_back(i);
}
hash_buf.resize(buf_size);
for (size_t i = 0; i < buf_size; ++i) {
hash_buf[i] = hash_buf_[i];
}
}
bool is_same_buf(const size_t hash_buf_[], const size_t buf_size_) {
if (buf_size != buf_size_) {
return false;
}
for (size_t i = 0; i < buf_size; i++) {
if (hash_buf[i] != hash_buf_[i]) {
return false;
}
}
return true;
}
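`is_same_buf` exists because the cache below keys a `std::unordered_multimap` by an XXHash digest and must still distinguish genuinely different keys that collide on the digest. A minimal runnable sketch of that pattern (FNV-1a stands in for XXHash, `Entry` for `MiniGraph`):

```cpp
#include <cstddef>
#include <cstring>
#include <string>
#include <unordered_map>
#include <vector>

struct Entry {
    std::vector<size_t> hash_buf;  // verbatim copy of the key buffer
    std::string payload;           // stands in for the cached MiniGraph
    bool is_same_buf(const size_t* buf, size_t n) const {
        return hash_buf.size() == n &&
               std::memcmp(hash_buf.data(), buf, n * sizeof(size_t)) == 0;
    }
};

std::unordered_multimap<size_t, Entry> cache;

Entry& get_cached(const size_t* buf, size_t n) {
    size_t key = 14695981039346656037ull;  // FNV-1a as a stand-in digest
    for (size_t i = 0; i < n; ++i) key = (key ^ buf[i]) * 1099511628211ull;
    auto range = cache.equal_range(key);
    for (auto it = range.first; it != range.second; ++it) {
        if (it->second.is_same_buf(buf, n)) return it->second;  // true hit
        // same digest but different buffer: a genuine hash collision;
        // the multimap lets both entries coexist under one bucket key
    }
    auto it = cache.emplace(key, Entry{{buf, buf + n}, "fresh"});
    return it->second;
}
```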
// methods for containing graph
......@@ -264,6 +303,87 @@ public:
return m_opr;
}
void init_input_tensor(const SmallVector<Tensor*>& inputs) {
auto&& opr_inputs = m_opr->input();
mgb_assert(opr_inputs.size() == inputs.size());
size_t idx = 0;
for (auto&& input : opr_inputs) {
mgb_assert(input->owner_opr()->same_type<InputPlaceholder>());
input->m_dev_tensor.storage({});
auto&& dev_tensor = inputs[input_remap[idx]]->dev_tensor();
auto&& layout = dev_tensor.layout();
input->shape(dev_tensor.shape());
auto&& chk = input->m_mem_plan.reset_from_owner_var().chunk();
input->m_dev_tensor.reset(dev_tensor.storage(), layout);
input->m_mem_plan.layout(layout);
chk.mem_alloc_status.set_from_owner_var();
mgb_assert(input->comp_node() == dev_tensor.comp_node());
mgb_assert(input->shape().eq_shape(layout));
mgb_assert(input->dtype() == layout.dtype);
idx++;
}
}
void init_output_tensor(const SmallVector<Tensor*>& outputs) {
size_t idx = 0;
mgb_assert(m_opr->usable_output().size() == outputs.size());
for (auto&& var : m_opr->output()) {
auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
// alloc workspace
TensorLayout layout{var->shape(), var->dtype(), var->format()};
var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
var->comp_node(), layout);
} else {
mgb_assert(idx < outputs.size());
auto&& tensor = outputs[idx];
auto&& layout = tensor->layout();
mgb_assert(var->comp_node() == tensor->comp_node());
mgb_assert(var->shape().eq_shape(layout));
mgb_assert(var->dtype() == layout.dtype);
if (!tensor->layout().is_empty()) {
var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
} else {
var->m_dev_tensor.storage({var->comp_node()});
}
++idx;
}
chk.mem_alloc_status.set_from_owner_var();
}
mgb_assert(idx == outputs.size());
// Memory forwarding is bypassed in megbrain when the graph option
// imperative_proxy_graph is on; here we call mem_plan_fwd_in2out_readonly
// to initialize some oprs' (e.g. Subtensor) internal state
// TODO: implement memory forwarding
m_opr->mem_plan_fwd_in2out_readonly();
{
// some oprs (e.g. Reduce) rely on on_mem_status_changed to set
// input/output tensors correctly; since we bypass var_node_mem_mgr,
// on_mem_status_changed should be called here
auto&& cb = m_opr->get_opr_event_callback().on_mem_status_changed;
if (cb.valid()) {
cb.val()();
}
}
}
void execute(
const SmallVector<Tensor*>& inputs, const SmallVector<Tensor*>& outputs,
cg::GraphExecutable::ExecEnv& env) {
init_input_tensor(inputs);
init_output_tensor(outputs);
m_opr->execute(env);
for (auto&& i : m_opr->input()) {
i->m_dev_tensor.storage({});
}
for (auto&& i : m_opr->output()) {
i->m_dev_tensor.storage({});
}
}
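`execute()` binds caller-owned storage into the placeholder vars, launches the opr, then resets every input/output var's storage so the cached graph holds no tensor references between calls. A toy model of that bind-run-release discipline, with invented names:

```cpp
#include <memory>
#include <vector>

using Storage = std::shared_ptr<float>;

// A cached graph must not keep tensors alive across calls: storage is bound
// just before the launch and dropped right after, mirroring the
// `m_dev_tensor.storage({})` resets in execute() above.
struct CachedOpr {
    std::vector<Storage> bound;  // borrowed for exactly one launch
    void launch() { /* run the kernel against `bound` */ }
    void execute(const std::vector<Storage>& ins, const std::vector<Storage>& outs) {
        bound.assign(ins.begin(), ins.end());
        bound.insert(bound.end(), outs.begin(), outs.end());
        launch();
        bound.clear();  // release all references; the caller keeps ownership
    }
};
```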
void register_shape_infer(
VarNode* varnode, const cg::static_infer::ShapeInferDesc& desc) {
auto [found, i] = find_index(m_opr->output(), varnode);
......@@ -278,15 +398,22 @@ public:
output_data[i].value_infer.initialize(m_opr, desc.deps, desc.infer_func);
}
const TensorShape& infer_shape(VarNode* var) { return m_sess->infer_shape(var); }
const TensorShape& infer_shape(VarNode* var) {
mgb_assert(m_sess);
return m_sess->infer_shape(var);
}
const DeviceTensorND& infer_value(VarNode* var) { return m_sess->infer_value(var); }
const DeviceTensorND& infer_value(VarNode* var) {
mgb_assert(m_sess);
return m_sess->infer_value(var);
}
OperatorNodeBase* opr() { return m_opr; }
// inference routine template for type of input
template <typename I>
class InferSession : protected InferSessionBase {
public:
MiniGraph& owner;
SmallVector<OutputData>& output_data;
InputAdaptor<I> inputs;
......@@ -355,7 +482,7 @@ public:
auto [found, i] = find_index(owner.m_opr->input(), var);
mgb_assert(found);
i = owner.input_remap[i];
auto* value = inputs.value(i, false);
auto* value = inputs.value(i, true);
mgb_assert(value);
return *value;
}
......@@ -379,12 +506,18 @@ public:
const TensorShape* infer_shape(size_t i, bool sync) {
i = owner.output_remap[i];
return infer(output_data[i].shape_infer, sync);
auto* p = infer(output_data[i].shape_infer, sync);
if (sync)
mgb_assert(p, "failed to infer shape");
return p;
}
const DeviceTensorND* infer_value(size_t i, bool sync) {
i = owner.output_remap[i];
return infer(output_data[i].shape_infer, sync);
auto* p = infer(output_data[i].value_infer, sync);
if (sync)
mgb_assert(p, "failed to infer value");
return p;
}
};
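Two fixes land in the hunk above: `infer_value` previously dispatched to `output_data[i].shape_infer`, a copy-paste slip, and a null result from a synchronous inference now fails loudly instead of propagating a nullptr. Reduced to its essentials (`checked_infer` is an illustrative name, not MegEngine API):

```cpp
#include <cassert>

// sync == true callers cannot make progress without a result, so a null
// outcome is promoted from "try again at runtime" to a hard error; async
// (fallible) callers may legitimately receive nullptr.
template <typename TryInfer>
auto checked_infer(TryInfer&& try_infer, bool sync) {
    auto* p = try_infer();  // may be null when inference is not yet possible
    if (sync) assert(p && "failed to infer");
    return p;
}
```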
......@@ -499,10 +632,12 @@ class ProxyGraphTypeI : public ProxyGraphBase {
public:
void register_shape_infer(
VarNode* var, const cg::static_infer::ShapeInferDesc& desc) override {
mgb_assert(target);
target->register_shape_infer(var, desc);
};
void register_value_infer(
VarNode* var, const cg::static_infer::ValueInferDesc& desc) override {
mgb_assert(target);
target->register_value_infer(var, desc);
};
cg::static_infer::InferType get_infer_type(VarNode*) override {
......@@ -511,17 +646,22 @@ class ProxyGraphTypeI : public ProxyGraphBase {
}
// some poorly written inference funcs would call infer_{shape,value}
const TensorShape& infer_shape(VarNode* var) override {
mgb_assert(target);
return target->infer_shape(var);
}
const DeviceTensorND& infer_value(VarNode* var) override {
mgb_assert(target);
return target->infer_value(var);
}
};
ProxyGraph::MiniGraph* target = nullptr;
StaticInferManager m_static_infer_manager;
std::unordered_map<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
std::unordered_multimap<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
std::mutex m_mini_graph_cache_mtx;
size_t opr_count = 0;
ExecEnvBase m_env;
CompNode::UnorderedSet m_used_comp_node;
static thread_local std::unique_ptr<ProxyGraphTypeI> sm_instance;
......@@ -531,8 +671,12 @@ class ProxyGraphTypeI : public ProxyGraphBase {
size_t next_node_id() override { return opr_count; }
void add_used_comp_node(CompNode cn) { m_used_comp_node.insert(cn); }
std::shared_ptr<void> on_comp_node_finalize() override {
sm_instance.reset();
assert(!target);
MGB_LOCK_GUARD(m_mini_graph_cache_mtx);
m_mini_graph_cache.clear();
return {};
}
......@@ -575,38 +719,62 @@ class ProxyGraphTypeI : public ProxyGraphBase {
}
public:
~ProxyGraphTypeI() {
if (is_finalized()) {
return;
}
for (auto&& i : m_used_comp_node) {
if (i.device_type() == CompNode::DeviceType::CUDA)
continue;
i.sync();
}
}
OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
mgb_assert(target);
return target->insert_opr(std::move(opr_uniqp));
}
static ProxyGraphTypeI& inst() {
if (!sm_instance) {
if (!sm_instance || sm_instance->is_finalized()) {
sm_instance.reset(new ProxyGraphTypeI);
}
return *sm_instance;
}
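`inst()` now tolerates comp-node finalization: `on_comp_node_finalize` resets `sm_instance`, so a later call rebuilds the singleton rather than touching dead state. The general shape, assuming a `T` that exposes `is_finalized()`:

```cpp
#include <memory>

// Finalize-aware thread-local singleton, mirroring inst() above.
template <typename T>
T& lazy_inst() {
    static thread_local std::unique_ptr<T> instance;
    if (!instance || instance->is_finalized()) instance.reset(new T);
    return *instance;
}
```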
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
template <typename T>
ProxyGraph::MiniGraph& get_cached_minigraph(const OpDef& def, const T& inputs) {
mgb_assert(!is_finalized());
size_t buf_size = 2 * inputs.size() + 1;
size_t buf[buf_size];
size_t pos = 0;
buf[pos++] = def.hash();
for (auto&& desc : inputs) {
buf[pos++] = mgb::hash(desc.layout.dtype.handle());
buf[pos++] = mgb::hash(desc.comp_node);
for (auto&& inp : inputs) {
auto tensor = TensorAdaptor(inp);
buf[pos++] = mgb::hash(tensor.dtype().handle());
buf[pos++] = mgb::hash(tensor.comp_node());
}
mgb_assert(pos == buf_size);
auto key = XXHash{}.update(buf, buf_size * sizeof(size_t)).digest();
auto it = m_mini_graph_cache.find(key);
if (it == m_mini_graph_cache.end()) {
auto&& result = m_mini_graph_cache.emplace(
std::piecewise_construct, std::make_tuple(key),
std::forward_as_tuple(*this, def, inputs));
mgb_assert(result.second);
it = result.first;
}
auto& minigraph = it->second;
auto its = m_mini_graph_cache.equal_range(key);
auto it = its.first;
for (; it != its.second; ++it) {
if (it->second.is_same_buf(buf, buf_size)) {
return it->second;
}
mgb_log_warn("hash collision occurs in minigraph cache with key: %zu", key);
}
auto&& result = m_mini_graph_cache.emplace(
std::piecewise_construct, std::make_tuple(key),
std::forward_as_tuple(
*this, def, inputs, static_cast<size_t*>(buf), buf_size));
return result->second;
}
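`get_cached_minigraph` is templated so one cache serves both `LogicalTensorDesc` inputs (shape inference) and `Tensor*` inputs (execution); `TensorAdaptor` gives both a uniform `dtype()`/`comp_node()` surface. A minimal analogue of that adaptor trick, with invented types:

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

struct Desc { std::string dtype, comp_node; };   // like LogicalTensorDesc
struct RawTensor { std::string dt, cn; };        // like Tensor*

// One accessor surface over both input kinds, as TensorAdaptor provides.
struct Adaptor {
    std::string dtype, comp_node;
    explicit Adaptor(const Desc& d) : dtype(d.dtype), comp_node(d.comp_node) {}
    explicit Adaptor(const RawTensor* t) : dtype(t->dt), comp_node(t->cn) {}
};

template <typename Seq>
size_t make_key(size_t def_hash, const Seq& inputs) {
    size_t h = def_hash;
    for (auto&& i : inputs) {
        Adaptor a(i);
        h = h * 1099511628211ull + std::hash<std::string>{}(a.dtype);
        h = h * 1099511628211ull + std::hash<std::string>{}(a.comp_node);
    }
    return h;
}
```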
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto& minigraph = get_cached_minigraph(def, inputs);
auto _ = scoped_attach(&minigraph);
auto sess = minigraph.infer_session(inputs);
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
......@@ -627,6 +795,88 @@ public:
}
return ret;
}
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def, const SmallVector<Tensor*>& inputs) {
SmallVector<LogicalTensorDesc> descs;
auto& minigraph = get_cached_minigraph(def, inputs);
auto _ = scoped_attach(&minigraph);
auto sess = minigraph.infer_session(inputs);
// some output vars in minigraph.opr()->output() may not appear in
// minigraph.opr()->usable_output(), but execution may use the attrs of those
// output vars, so we infer attrs for all outputs, but only return
// LogicalTensorDesc for minigraph.opr()->usable_output()
for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
mgb_assert(shape);
minigraph.opr()->output()[i]->shape(*shape);
}
descs.reserve(minigraph.output_size());
for (size_t i = 0; i < minigraph.output_size(); ++i) {
auto* ovar = minigraph.output_var(i);
descs.emplace_back();
auto& desc = descs.back();
desc.layout.dtype = ovar->dtype();
desc.comp_node = ovar->comp_node();
mgb_assert(ovar->dtype().valid() && ovar->comp_node().valid());
mgb_assert(
ovar->shape().ndim ||
ovar->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
desc.layout.init_contiguous_stride(ovar->shape());
}
return descs;
}
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
return infer_output_attrs(def, to_raw_ptr_array(inputs));
}
void exec(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs) {
auto raw_inputs = to_raw_ptr_array(inputs),
raw_outputs = to_raw_ptr_array(outputs);
CompNode::UnorderedSet used_cns;
for (auto&& out : raw_outputs) {
auto cn = out->comp_node();
add_used_comp_node(cn);
if (used_cns.insert(cn).second) {
for (auto&& in : inputs) {
if (in->comp_node() != cn) {
auto&& e = in->get_or_create_event();
e->device_wait_by(cn);
}
}
}
}
auto& minigraph = get_cached_minigraph(def, raw_inputs);
auto _ = scoped_attach(&minigraph);
// some oprs (e.g. Subtensor) may invoke infer_value during execution,
// so we need to create an inference session here
auto sess = minigraph.infer_session(raw_inputs);
minigraph.execute(raw_inputs, raw_outputs, m_env);
for (auto&& cn : used_cns) {
for (auto&& in : inputs) {
if (in->comp_node() != cn) {
in->add_release_callback(cn);
}
}
}
}
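`exec()` orders work across comp nodes by hand because the usual graph machinery is bypassed: inputs living on another comp node publish an event the launch target waits on, and after the launch a release callback pins each such input until the consumer drains. A runnable toy of that protocol (comp nodes are plain strings here, not the real `CompNode`):

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct Event {
    std::string src;
    void device_wait_by(const std::string& cn) const {
        std::printf("%s waits on event from %s\n", cn.c_str(), src.c_str());
    }
};

struct Input {
    std::string cn;
    Event event() const { return Event{cn}; }
    void add_release_callback(const std::string& consumer) const {
        std::printf("%s kept alive until %s finishes\n", cn.c_str(), consumer.c_str());
    }
};

void exec_on(const std::string& cn, const std::vector<Input>& inputs) {
    for (auto&& in : inputs)
        if (in.cn != cn) in.event().device_wait_by(cn);  // pre-launch sync
    std::printf("launch cached opr on %s\n", cn.c_str());
    for (auto&& in : inputs)
        if (in.cn != cn) in.add_release_callback(cn);    // post-launch pinning
}
```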
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
auto&& raw_inputs = to_raw_ptr_array(inputs);
auto output_descs = infer_output_attrs(def, raw_inputs);
SmallVector<TensorPtr> outputs(output_descs.size(), {});
for (size_t i = 0; i < outputs.size(); i++) {
outputs[i] =
Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
exec(def, inputs, outputs);
return outputs;
}
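`apply_on_physical_tensor` thus becomes pure orchestration: infer output layouts from the cached graph, allocate outputs eagerly via `Tensor::make`, then run `exec`. The shape of that flow, with toy types standing in for layouts and tensors:

```cpp
#include <cstddef>
#include <vector>

struct Layout { std::size_t nelems; };  // stands in for the inferred TensorLayout
using Buffer = std::vector<float>;      // stands in for an allocated Tensor

template <typename Exec>
std::vector<Buffer> apply(const std::vector<Layout>& inferred, Exec&& exec) {
    std::vector<Buffer> outputs;
    outputs.reserve(inferred.size());
    for (auto&& l : inferred) outputs.emplace_back(l.nelems);  // eager alloc
    exec(outputs);  // runs the cached minigraph against the fresh outputs
    return outputs;
}
```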
};
} // namespace mgb::imperative::proxy_graph
......@@ -23,6 +23,7 @@ thread_local std::unique_ptr<ProxyGraphTypeI> ProxyGraphTypeI::sm_instance = {};
} // namespace mgb::imperative::proxy_graph
namespace mgb::imperative::proxy_graph_detail {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible(
......@@ -42,4 +43,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return ret;
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
auto ret =
proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor(def, inputs);
return ret;
}
} // namespace mgb::imperative::proxy_graph_detail
......@@ -17,6 +17,9 @@ namespace mgb {
namespace imperative {
namespace proxy_graph_detail {
// these functions are reimplemented with an opr cache
// in ./proxy_graph/mini_graph.h
#if 0
namespace {
SmallVector<Tensor*> to_raw_ptr_array(
const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
......@@ -83,12 +86,13 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs;
}
// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const
// OpDef& def,
// const SmallVector<LogicalTensorDesc>& inputs) {
// auto&& graph = ProxyGraph::get_default_graph();
// return graph->infer_output_attrs_fallible(def, inputs);
// }
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_attrs_fallible(def, inputs);
}
#endif
EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
......
......@@ -1009,6 +1009,7 @@ void Split::init_output_static_infer_desc() {
bool Split::infer_shape(
size_t out_idx, TensorShape& dest, const cg::static_infer::InpVal& inp) {
mgb_assert(inp.run_id > 0, "run id should be a positive number");
if (inp.run_id != m_output_shape_version) {
std::vector<size_t> partition;
auto ishp = inp.val.at(0).shape();
......