提交 14d8b709 编写于 作者: M Megvii Engine Team

perf(mge/imperative): add mini graph to partially replace proxy graph

GitOrigin-RevId: 73e2529ba53ccb6c0607f52aee40e69e2c289343
上级 c294b9d1
......@@ -258,6 +258,9 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice =
MGB_LOCK_GUARD(m_mutex);
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
// if (dest->desc.layout.ndim) {
// mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
// }
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
......@@ -363,7 +366,7 @@ void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
}
inputs.push_back(i->ptr);
}
auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
for (size_t i = 0; i < outputs.size(); i ++) {
auto out_ptr = path.outputs[i].lock();
if (out_ptr) {
......
namespace mgb::imperative::proxy_graph {
// a "namespace" struct to simplify friend declaration,
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
struct ProxyGraph {
struct InputPlaceholder;
struct MiniGraph;
};
} // namespace mgb::imperative::proxy_graph
#include "megbrain/graph/operator_node.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/op_def.h"
#include "./common.h"
#include "./proxy_graph_base.h"
#include <optional>
#include "range/v3/all.hpp"
namespace mgb::imperative::proxy_graph {
using cg::OperatorNodeBase;
template<typename C, typename E>
std::pair<bool, size_t> find_index(const C& container, const E& item) {
auto&& it = std::find(container.begin(), container.end(), item);
return {it != container.end(), it - container.begin()};
}
template <typename T, typename = void> class TensorAdaptor;
template <typename T, typename U>
using enable_if_same_upto_cv_t = std::enable_if_t<std::is_same_v<std::remove_cv_t<T>, std::remove_cv_t<U>>>;
template<typename T>
class TensorAdaptor<T, enable_if_same_upto_cv_t<T, LogicalTensorDesc>> {
T& wrapped;
template <typename U>
using maybe_add_const_t = std::conditional_t<std::is_const_v<T>, const U, U>;
public:
using type = T;
TensorAdaptor(T& desc) : wrapped(desc) {}
TensorAdaptor(T* desc) : wrapped(*desc) {}
DType dtype() {return wrapped.layout.dtype;}
CompNode comp_node() {return wrapped.comp_node;}
maybe_add_const_t<TensorShape>& shape() {return wrapped.layout;}
bool has_value() {return wrapped.value.shape_valid();}
auto& value() {return wrapped.value;}
auto* operator->() {return &wrapped;}
};
template<typename T>
class TensorAdaptor<T, enable_if_same_upto_cv_t<T, Tensor>> {
Tensor& wrapped;
public:
using type = Tensor;
TensorAdaptor(Tensor& tensor) : wrapped(tensor) {}
TensorAdaptor(Tensor* tensor) : wrapped(*tensor) {}
DType dtype() {return wrapped.dtype();}
CompNode comp_node() {return wrapped.comp_node();}
const TensorShape& shape() {return wrapped.shape();}
type* operator->() {return &wrapped;}
};
// deduction guides
template <typename T> TensorAdaptor(T&) -> TensorAdaptor<T, void>;
template <typename T> TensorAdaptor(T*) -> TensorAdaptor<T, void>;
// single opr graph, for static inference and execution
// contains static inference descs
class ProxyGraph::MiniGraph {
protected:
struct InferDepItem {
bool is_input : 1;
size_t idx : 63;
cg::static_infer::DepType type;
};
enum class InferStatus {
UNKOWN,
READY,
FAILED
};
// inference desc and pre-allocated storage for a single var
template <typename T>
struct InferData {
SmallVector<InferDepItem> deps;
thin_function<bool(T&, const cg::static_infer::InpVal&)> infer_func;
// pre-allocated infer states
InferStatus status = InferStatus::UNKOWN;
cg::static_infer::InpVal inp_val;
T dest;
void initialize(OperatorNodeBase* opr, const cg::static_infer::DepVal& dep_val,
const thin_function<bool(T&, const cg::static_infer::InpVal&)>& func) {
mgb_assert(!infer_func);
infer_func = func;
inp_val.val.resize(dep_val.size());
deps.reserve(dep_val.size());
for (auto&& dep : dep_val) {
auto [found, i] = find_index(opr->input(), dep.dest);
if (found) {
deps.push_back({true, i, dep.type});
} else {
auto [found, i] = find_index(opr->output(), dep.dest);
mgb_assert(found);
deps.push_back({false, i, dep.type});
}
}
}
void reset() {
status = InferStatus::UNKOWN;
if constexpr (std::is_same_v<T, TensorShape>) {
dest.ndim = 0;
} else {
static_assert(std::is_same_v<T, DeviceTensorND>);
dest.storage({});
}
}
};
struct OutputData {
InferData<TensorShape> shape_infer;
InferData<DeviceTensorND> value_infer;
};
struct InferSessionBase {
virtual const TensorShape& infer_shape(VarNode*) {mgb_assert(0);}
virtual const TensorShape* infer_shape_fallible(VarNode*) {mgb_assert(0);}
virtual const DeviceTensorND& infer_value(VarNode*) {mgb_assert(0);}
virtual const DeviceTensorND* infer_value_fallible(VarNode*) {mgb_assert(0);}
};
OperatorNodeBase* m_opr = nullptr;
SmallVector<std::unique_ptr<OperatorNodeBase>> opr_ref_keeper;
size_t run_id = 0;
SmallVector<OutputData> output_data;
SmallVector<size_t> input_remap;
SmallVector<size_t> output_remap;
// pre-allocated buffer for converted inputs
SmallVector<std::optional<DeviceTensorND>> input_value_storage;
InferSessionBase* m_sess = nullptr;
template <typename T>
struct InputAdaptor {
T& wrapped;
SmallVector<std::optional<DeviceTensorND>>& value_storage;
InputAdaptor(MiniGraph& owner, T& inputs) : wrapped(inputs), value_storage(owner.input_value_storage) {}
~InputAdaptor() {
for (auto& i : value_storage) {
i.reset();
}
}
const TensorShape* shape(size_t i) {
TensorAdaptor tensor(wrapped[i]);
auto& shape = tensor.shape();
return shape.ndim ? &shape : nullptr;
}
const DeviceTensorND* value(size_t i, bool sync) {
TensorAdaptor tensor(wrapped[i]);
using tensor_t = std::remove_cv_t<typename decltype(tensor)::type>;
if constexpr (std::is_same_v<tensor_t, Tensor>) {
auto& storage = value_storage[i];
if (!storage) {
if (sync) {
return &storage.emplace(tensor->get_value().proxy_to_default_cpu());
} else {
if (auto* hv = tensor->try_get_value()) {
return &storage.emplace(hv->proxy_to_default_cpu());
}
return nullptr;
}
}
} else {
auto& value = tensor.value();
return value.shape_valid() ? &value : nullptr;
}
}
};
public:
template <typename I, typename G>
MiniGraph(G& graph, const OpDef& opdef, const I& inputs) : input_value_storage(inputs.size()) {
mgb_assert(!m_opr);
auto _ = graph.scoped_attach(this);
cg::VarNodeArray vinputs(inputs.size());
for (auto&& [i, t] : ranges::views::enumerate(inputs)) {
auto tensor = TensorAdaptor(t);
opr_ref_keeper.emplace_back(new InputPlaceholder(graph, tensor.dtype(), tensor.comp_node()));
vinputs[i] = opr_ref_keeper.back()->output(0);
}
auto ovars = OpDef::apply_on_var_node(opdef, vinputs);
mgb_assert(m_opr);
output_data.resize(m_opr->output().size());
for (auto* v : ovars) {
mgb_assert(v->owner_opr() == m_opr);
}
m_opr->init_output_static_infer_desc();
// fix permuted input
input_remap.reserve(m_opr->input().size());
for (auto* v : m_opr->input()) {
auto [found, i] = find_index(vinputs, v);
mgb_assert(found);
input_remap.push_back(i);
}
auto fix_dep_idx = [&](SmallVector<InferDepItem>& deps) {
for (auto& dep : deps) {
if (dep.is_input) {
dep.idx = input_remap[dep.idx];
}
}
};
for (auto& data : output_data) {
fix_dep_idx(data.shape_infer.deps);
fix_dep_idx(data.value_infer.deps);
}
// fix permuted output
output_remap.reserve(ovars.size());
for (auto* v : ovars) {
auto [found, i] = find_index(m_opr->output(), v);
mgb_assert(found);
output_remap.push_back(i);
}
}
// methods for containing graph
OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) {
mgb_assert(!m_opr);
m_opr = opr_uniqp.get();
mgb_assert(opr_ref_keeper.back()->owner_graph() == m_opr->owner_graph());
mgb_assert(!m_opr->inserted_in_graph());
opr_ref_keeper.push_back(std::move(opr_uniqp));
m_opr->set_inserted_in_graph();
m_opr->init_output_comp_node();
m_opr->init_output_dtype();
return m_opr;
}
void register_shape_infer(VarNode* varnode, const cg::static_infer::ShapeInferDesc& desc) {
auto [found, i] = find_index(m_opr->output(), varnode);
mgb_assert(found);
output_data[i].shape_infer.initialize(m_opr, desc.deps, desc.infer_func);
}
void register_value_infer(VarNode* varnode, const cg::static_infer::ValueInferDesc& desc) {
auto [found, i] = find_index(m_opr->output(), varnode);
mgb_assert(found);
output_data[i].value_infer.initialize(m_opr, desc.deps, desc.infer_func);
}
const TensorShape& infer_shape(VarNode* var) {
return m_sess->infer_shape(var);
}
const DeviceTensorND& infer_value(VarNode* var) {
return m_sess->infer_value(var);
}
OperatorNodeBase* opr() {
return m_opr;
}
// inference routine template for type of input
template<typename I>
class InferSession : protected InferSessionBase {
MiniGraph& owner;
SmallVector<OutputData>& output_data;
InputAdaptor<I> inputs;
template<typename T>
const T* infer(InferData<T>& target, bool sync) {
bool ret;
if (target.status != InferStatus::UNKOWN) {
ret = target.status == InferStatus::READY;
} else {
ret = target.infer_func && do_infer(target, sync);
target.status = ret ? InferStatus::READY : InferStatus::FAILED;
}
return ret ? &target.dest : nullptr;
}
template<typename T>
bool do_infer(InferData<T>& target, bool sync) {
for (size_t i = 0; i < target.deps.size(); ++i) {
target.inp_val.run_id = owner.run_id;
auto& dep = target.deps[i];
if (dep.is_input) {
if (dep.type == cg::static_infer::DepType::SHAPE) {
if (auto* val = inputs.shape(dep.idx)) {
target.inp_val.val[i].m_shape = val;
} else return false;
} else {
if (auto* val = inputs.value(dep.idx, sync)) {
target.inp_val.val[i].m_value = val;
} else return false;
}
} else {
if (dep.type == cg::static_infer::DepType::SHAPE) {
if (auto* val = infer(output_data[dep.idx].shape_infer, sync)) {
target.inp_val.val[i].m_shape = val;
} else return false;
} else {
if (auto* val = infer(output_data[dep.idx].value_infer, sync)) {
target.inp_val.val[i].m_value = val;
} else return false;
}
}
}
return target.infer_func(target.dest, target.inp_val);
}
// methods for owner mini graph
// corresponding methods of containing ComputingGraph will be redirected here
const TensorShape& infer_shape(VarNode* var) override {
mgb_assert(owner.m_opr);
auto [found, i] = find_index(owner.m_opr->input(), var);
mgb_assert(found);
i = owner.input_remap[i];
auto* shape = inputs.shape(i);
mgb_assert(shape);
return *shape;
}
const DeviceTensorND& infer_value(VarNode* var) override {
mgb_assert(owner.m_opr);
auto [found, i] = find_index(owner.m_opr->input(), var);
mgb_assert(found);
i = owner.input_remap[i];
auto* value = inputs.value(i, false);
mgb_assert(value);
return *value;
}
public:
InferSession(MiniGraph& mgraph, I& inputs_)
: owner(mgraph), output_data(mgraph.output_data), inputs(mgraph, inputs_) {
mgraph.run_id++;
mgb_assert(!owner.m_sess);
owner.m_sess = this;
}
~InferSession() {
owner.m_sess = nullptr;
for (auto& i : output_data) {
i.shape_infer.reset();
i.value_infer.reset();
}
}
const TensorShape* infer_shape(size_t i, bool sync) {
i = owner.output_remap[i];
return infer(output_data[i].shape_infer, sync);
}
const DeviceTensorND* infer_value(size_t i, bool sync) {
i = owner.output_remap[i];
return infer(output_data[i].shape_infer, sync);
}
};
template <typename T>
InferSession<T> infer_session(T& inputs) {return InferSession(*this, inputs);}
size_t output_size() {
return output_remap.size();
}
VarNode* output_var(size_t i) {
i = output_remap[i];
return m_opr->output(i);
}
};
class CompNodeTracker {
static constexpr size_t bucket_size = 100;
static constexpr size_t bucket_count = 10;
CompNode comp_node;
std::array<std::unique_ptr<CompNode::Event>, bucket_count> events;
size_t free_slots = bucket_size;
size_t head = 0; // events[head] is not recorded
size_t tail = 0; // events[tail] is not finished
void rotate() {
while (tail < head && events[tail % bucket_count]->finished()) {
++tail;
}
auto& ev = events[head % bucket_count];
if (head == tail + bucket_count) {
// do not wait if head == tail
ev->host_wait();
++tail;
}
ev->record();
++head;
free_slots = bucket_size;
}
public:
CompNodeTracker(CompNode cn) : comp_node(cn) {
for (auto& e : events) {
e = cn.create_event();
}
}
size_t add_opr() {
if (!free_slots) rotate();
--free_slots;
return head;
}
size_t progress() {
return tail;
}
};
class ExecMiniGraph : public ProxyGraph::MiniGraph {
union BusyListItem {
size_t finish_time;
OperatorNodeBase* opr;
};
SmallVector<CompNodeTracker*> comp_node_trackers;
std::deque<BusyListItem> busy_oprs;
SmallVector<OperatorNodeBase*> idle_oprs;
OperatorNodeBase* acquire_opr() {
mgb_assert(!m_opr);
if (!idle_oprs.empty()) {
m_opr = idle_oprs.back();
idle_oprs.pop_back();
return m_opr;
}
mgb_assert(busy_oprs.size() > comp_node_trackers.size());
bool can_pop = true;
for (auto [item, tracker] : ranges::views::zip(busy_oprs, comp_node_trackers)) {
if (item.finish_time >= tracker->progress()) {
can_pop = false;
break;
}
}
if (can_pop) {
for (auto _ : comp_node_trackers) {
busy_oprs.pop_front();
}
m_opr = busy_oprs.front().opr;
busy_oprs.pop_front();
return m_opr;
}
}
template <bool in_use>
void release_opr() {
if constexpr (in_use) {
for (auto tracker : comp_node_trackers) {
tracker->add_opr();
}
}
}
};
class ProxyGraphTypeI : public ProxyGraphBase {
class StaticInferManager : public StaticInferManagerBase {
ProxyGraph::MiniGraph* target = nullptr;
friend class ProxyGraphTypeI;
public:
void register_shape_infer(VarNode* var, const cg::static_infer::ShapeInferDesc& desc) override {
target->register_shape_infer(var, desc);
};
void register_value_infer(VarNode* var, const cg::static_infer::ValueInferDesc& desc) override {
target->register_value_infer(var, desc);
};
cg::static_infer::InferType get_infer_type(VarNode*) override {
return {cg::static_infer::InferType::MISSING_INP, cg::static_infer::InferType::MISSING_INP};
}
// some poorly written inference func would call infer_{shape,value}
const TensorShape& infer_shape(VarNode* var) override {
return target->infer_shape(var);
}
const DeviceTensorND& infer_value(VarNode* var) override {
return target->infer_value(var);
}
};
ProxyGraph::MiniGraph* target = nullptr;
StaticInferManager m_static_infer_manager;
std::unordered_map<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
size_t opr_count = 0;
static thread_local std::unique_ptr<ProxyGraphTypeI> sm_instance;
friend class ProxyGraph::MiniGraph;
size_t nr_oprs_in_graph() const override {
return opr_count;
}
size_t next_node_id() override {
return opr_count;
}
std::shared_ptr<void> on_comp_node_finalize() override {
sm_instance.reset();
return {};
}
cg::static_infer::StaticInferManager& static_infer_manager() override {
return m_static_infer_manager;
}
void attach(ProxyGraph::MiniGraph* target_) {
target = target_;
m_static_infer_manager.target = target_;
}
struct AttachGuard {
ProxyGraphTypeI* owner = nullptr;
ProxyGraph::MiniGraph* target = nullptr;
AttachGuard(ProxyGraphTypeI* owner_ = nullptr, ProxyGraph::MiniGraph* target_ = nullptr)
: owner(owner_), target(target_) {}
AttachGuard(AttachGuard&) = delete;
AttachGuard& operator=(AttachGuard&) = delete;
AttachGuard(AttachGuard&& rhs) : owner(rhs.owner), target(rhs.target) {rhs.owner = nullptr;}
AttachGuard& operator=(AttachGuard&& rhs) = delete;
~AttachGuard() {if (owner) owner->attach(target);}
};
[[nodiscard]]
AttachGuard scoped_attach(ProxyGraph::MiniGraph* target_) {
attach(target_);
return attach_guard();
}
[[nodiscard]]
AttachGuard attach_guard(ProxyGraph::MiniGraph* target_ = nullptr) {
return {this, target_};
}
public:
OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
return target->insert_opr(std::move(opr_uniqp));
}
static ProxyGraphTypeI& inst() {
if (!sm_instance) {
sm_instance.reset(new ProxyGraphTypeI);
}
return *sm_instance;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
size_t buf_size = 2 * inputs.size() + 1;
size_t buf[buf_size];
size_t pos = 0;
buf[pos++] = def.hash();
for (auto&& desc : inputs) {
buf[pos++] = mgb::hash(desc.layout.dtype.handle());
buf[pos++] = mgb::hash(desc.comp_node);
}
mgb_assert(pos == buf_size);
auto key = XXHash{}.update(buf, buf_size*sizeof(size_t)).digest();
auto it = m_mini_graph_cache.find(key);
if (it == m_mini_graph_cache.end()) {
auto&& result = m_mini_graph_cache.emplace(
std::piecewise_construct,
std::make_tuple(key),
std::forward_as_tuple(*this, def, inputs));
mgb_assert(result.second);
it = result.first;
}
auto& minigraph = it->second;
auto _ = scoped_attach(&minigraph);
auto sess = minigraph.infer_session(inputs);
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
auto& [descs, noerr] = ret;
descs.reserve(minigraph.output_size());
for (size_t i = 0; i < minigraph.output_size(); ++i) {
descs.emplace_back();
auto& desc = descs.back();
desc.layout.dtype = minigraph.output_var(i)->dtype();
desc.comp_node = minigraph.output_var(i)->comp_node();
if (auto* shape = sess.infer_shape(i, false)) {
desc.layout.init_contiguous_stride(*shape);
} else {
noerr = false;
}
}
return ret;
}
};
} // namespace mgb::imperative::proxy_graph
#include "./mini_graph.h"
// #include "../proxy_graph.h"
namespace mgb::imperative::proxy_graph {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder);
thread_local std::unique_ptr<ProxyGraphTypeI> ProxyGraphTypeI::sm_instance = {};
} // namespace mgb::imperative::proxy_graph
namespace mgb::imperative::proxy_graph_detail {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible(def, inputs);
// auto ref = ProxyGraph::get_default_graph()->infer_output_attrs_fallible(def, inputs);
// auto& [a, _1] = ret;
// auto& [b, _2] = ref;
// if (a.size() != b.size()) mgb_trap();
// for (size_t i = 0; i < a.size(); ++i) {
// if (a[i].layout.dtype != b[i].layout.dtype) mgb_trap();
// if (a[i].comp_node != b[i].comp_node) mgb_trap();
// if (!a[i].layout.eq_shape(b[i].layout)) mgb_trap();
// }
return ret;
}
} // namespace mgb::imperative::proxy_graph_detail
#include "megbrain/graph/cg.h"
namespace mgb::imperative::proxy_graph {
using cg::VarNode;
struct ExecEnvBase : cg::GraphExecutable::ExecEnv {
void dispatch_on_comp_node(CompNode, Task&& task) override {
task();
}
void dispatch_on_comp_node_with_mask(CompNode, Task&&, cg::ExecutionMask*) override {mgb_assert(0);}
void pause_exec() override {mgb_assert(0);}
void resume_exec() override {mgb_assert(0);}
};
struct StaticInferManagerBase : cg::static_infer::StaticInferManager {
protected:
void register_shape_infer(VarNode*, const cg::static_infer::ShapeInferDesc&) override {mgb_assert(0);};
void register_value_infer(VarNode*, const cg::static_infer::ValueInferDesc&) override {mgb_assert(0);};
cg::static_infer::InferType get_infer_type(VarNode*) override {mgb_assert(0);};
const TensorShape& infer_shape(VarNode*) override {mgb_assert(0);}
const TensorShape* infer_shape_fallible(VarNode*) override {mgb_assert(0);}
const DeviceTensorND& infer_value(VarNode*) override {mgb_assert(0);}
const DeviceTensorND* infer_value_fallible(VarNode*) override {mgb_assert(0);}
cg::static_infer::DepVal get_rt_static_source_deps(const cg::static_infer::DepElement&) override {mgb_assert(0);}
};
struct SeqCompNodeOptimizerBase : cg::SeqCompNodeOptimizer {
protected:
void register_stream_var(VarNode*, StreamPropType) override {}
void register_propagate_function(VarNode*, PropFunction) override {}
StreamPropType stream_prop_type(VarNode*) override {mgb_assert(0);}
};
struct ProxyGraphBase : cg::ComputingGraph {
private:
VarReceiverInfo m_var_receiver_info;
SeqCompNodeOptimizerBase m_seq_comp_node_optimizer;
StaticInferManagerBase m_static_infer_manager;
protected:
MemPool<VarNode> m_var_node_pool;
ProxyGraphBase() {
options().imperative_proxy_graph = true;
options().no_force_inplace = true;
options().log_level = 0;
m_var_receiver_info.dev_value = 1;
m_var_receiver_info.allow_empty_value = 1;
}
void* alloc_varnode_storage() override {
return m_var_node_pool.alloc_raw();
}
void free_varnode_storage(void* ptr) override {
m_var_node_pool.free_raw(ptr);
}
const VarReceiverInfo& var_receiver_in_current_comp_seq(const VarNode *var) const override {
return m_var_receiver_info;
}
cg::static_infer::StaticInferManager& static_infer_manager() override {
return m_static_infer_manager;
}
cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
return m_seq_comp_node_optimizer;
}
std::shared_ptr<void> on_comp_node_finalize() override {
return {};
}
std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec&) override {mgb_assert(0);}
SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(const SmallVector<OutputSpec>&) override {mgb_assert(0);}
cg::AsyncExecutable* current_comp_seq() override {mgb_assert(0);}
std::string get_mem_allocation_info() const override {mgb_assert(0);}
VarNode* find_var_by_id(size_t) const override {mgb_assert(0);}
void share_device_memory_with(ComputingGraph&) override {mgb_assert(0);}
void set_device_memory_allocator(std::shared_ptr<cg::DeviceMemoryAllocator>) override {mgb_assert(0);}
size_t get_device_memory_size(CompNode) override {mgb_assert(0);}
size_t clear_device_memory() override {mgb_assert(0);}
void set_as_subgraph(ComputingGraph&) override {mgb_assert(0);}
void record_async_error(std::unique_ptr<MegBrainError>) override {mgb_assert(0);}
};
MGB_DEFINE_OPR_CLASS(
ProxyGraph::InputPlaceholder,
cg::OperatorNodeBase) // {
void on_output_comp_node_stream_changed() override {mgb_assert(0);}
void init_output_comp_node() override {}
void init_output_format() override {}
void init_output_dtype() override {}
void init_output_static_infer_desc() override {}
void init_output_mem_plan(bool) override {mgb_assert(0);}
void do_execute(ExecEnv&) override {mgb_assert(0);}
public:
InputPlaceholder(cg::ComputingGraph& graph)
: Super(&graph, {}, "placeholder", {}) {
add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
// never dedup
add_equivalence_component<ScalarHash<void*>>(this);
}
InputPlaceholder(cg::ComputingGraph& graph, DType dtype, CompNode cn)
: InputPlaceholder(graph) {
output(0)->dtype(dtype).comp_node(cn);
}
};
using InputPlaceholder = ProxyGraph::InputPlaceholder;
} // namespace mgb::imperative::proxy_graph
......@@ -80,11 +80,11 @@ apply_on_physical_tensor(const OpDef& def,
return outputs;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_attrs_fallible(def, inputs);
}
// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
// const SmallVector<LogicalTensorDesc>& inputs) {
// auto&& graph = ProxyGraph::get_default_graph();
// return graph->infer_output_attrs_fallible(def, inputs);
// }
namespace {
......
......@@ -89,10 +89,18 @@ public:
return m_blob->comp_node();
}
DType dtype() const {
return m_layout.dtype;
}
TensorLayout layout() const {
return m_layout;
}
const TensorShape& shape() const {
return m_layout;
}
DeviceTensorND dev_tensor();
static TensorPtr make_scalar(DTypeScalar value, CompNode cn);
......
......@@ -16,7 +16,10 @@
namespace mgb {
namespace imperative {
class ProxyGraph;
class ProxyGraph;
namespace proxy_graph {
class ProxyGraph;
} // namespace proxy_graph
} // namespace imperative
namespace cg {
......@@ -56,6 +59,7 @@ namespace static_infer {
friend class StaticInferManagerImpl;
friend class imperative::ProxyGraph;
friend class imperative::proxy_graph::ProxyGraph;
public:
/*!
......@@ -342,4 +346,3 @@ using StaticInferInpVal = static_infer::InpVal;
} // mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -23,7 +23,10 @@
namespace mgb {
namespace imperative {
class ProxyGraph;
class ProxyGraph;
namespace proxy_graph {
class ProxyGraph;
}
} // namespace imperative
namespace cg {
......@@ -587,6 +590,7 @@ class VarNode final: public GraphNodeBase {
friend class EagerEvalManager;
friend class MemAllocPlan;
friend class imperative::ProxyGraph;
friend class imperative::proxy_graph::ProxyGraph;
};
enum class VarNode::Flag : uint32_t {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册