Commit 1fe8a212 authored by Megvii Engine Team

fix(mge): fix sublinear memory in jit.trace

GitOrigin-RevId: 190a330a8ce2016d24faed0246cc0c294c755946
Parent 2df1ab96
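
For context, the feature being fixed is sublinear memory planning under jit.trace. Below is a minimal usage sketch, assuming the MegEngine 1.x Python API (SublinearMemoryConfig and the sublinear_memory_config parameter of trace); it is an illustration, not part of this commit.

# Minimal sketch, assuming the MegEngine 1.x API; not part of this commit.
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.jit import SublinearMemoryConfig, trace

# Sublinear planning trades compute for memory by recomputing
# intermediates; it needs static shapes on the traced graph inputs,
# which is exactly what this commit propagates.
@trace(symbolic=True, sublinear_memory_config=SublinearMemoryConfig())
def step(x):
    for _ in range(4):
        x = F.relu(x * 2 + 1)
    return x

out = step(mge.tensor(np.ones((8, 8), dtype="float32")))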
@@ -210,6 +210,7 @@ class trace:
                 info.external = True
                 info.device = x.device
                 info.dtype = x.dtype
+                info.shape = x.shape
                 if self._capture_as_const:
                     info.bound_data = x
@@ -338,7 +339,7 @@ class trace:
         for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
             info = self._tinfo[h]
             opnode = info.data_setter = G.InputNode(
-                device=info.device, dtype=info.dtype, graph=graph
+                device=info.device, dtype=info.dtype, shape=info.shape, graph=graph
             )
             need_reset_nodes.append(opnode)
             info.varnode = opnode.outputs[0]
@@ -355,7 +356,11 @@ class trace:
                 info.varnode = graph.make_const(info.bound_data._dev_tensor())
             else:
                 opnode = info.data_setter = G.InputNode(
-                    *links, device=info.device, dtype=info.dtype, graph=graph
+                    *links,
+                    device=info.device,
+                    dtype=info.dtype,
+                    shape=info.shape,
+                    graph=graph,
                 )
                 need_reset_nodes.append(opnode)
                 info.varnode, *links = opnode.outputs
......
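
The three hunks above serve one purpose: the shape captured at trace time now travels with the rest of the tensor metadata into every G.InputNode, so the compiled graph's inputs expose a static shape. A toy illustration of the pattern (hypothetical names, not MegEngine code):

# Toy illustration with hypothetical names; not MegEngine code.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class TensorInfo:
    device: str
    dtype: str
    shape: Optional[Tuple[int, ...]] = None  # the newly captured field

def make_input_node(info: TensorInfo) -> dict:
    # Stands in for G.InputNode(device=..., dtype=..., shape=..., graph=...):
    # with shape=None the input's size stays unknown until runtime, and a
    # shape-based memory planner has nothing to work with.
    return {"device": info.device, "dtype": info.dtype, "shape": info.shape}

node = make_input_node(TensorInfo("xpux", "float32", (8, 8)))
assert node["shape"] == (8, 8)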
../../../src/core/impl
\ No newline at end of file
@@ -10,6 +10,7 @@
  */
 #include "megbrain/imperative/opr_utility.h"
+#include "./mgb_core_impl/graph/cg_impl.h"
 // FIXME: setup_config_cn is copied from src/opr/impl/utility.cpp
 namespace {
@@ -64,14 +65,18 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
 void InputCallback::init_output_static_infer_desc() {
     if (m_output_shape.ndim) {
+        // Write this shape to the static infer manager. The effect is
+        // that infer_shape_fallible() will return a non-empty shape
+        // while get_infer_type() remains NO_DESC. Most places check
+        // the infer type before relying on an inferred shape, so nothing
+        // breaks. The memory optimizer, however, deliberately omits the
+        // infer type check, so it can use this shape as a hint.
-        using namespace cg::static_infer;
-        auto&& mgr = owner_graph()->static_infer_manager();
-        auto infer_shape = [this](TensorShape& dest, const InpVal&) {
-            dest = m_output_shape;
-            return true;
-        };
-        mgr.register_shape_infer(output(0),
-                                 {SourceType::CONSTANT, {}, infer_shape});
+        auto* var = output(0);
+        var->shape(m_output_shape);
+        auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
+        auto* handle = mgr.get_tag_handler_for_shape(var);
+        handle->sync_from_var();
     }
 }
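
The comment in this hunk describes a deliberately asymmetric contract: the shape becomes observable through infer_shape_fallible() while get_infer_type() still reports NO_DESC. A toy Python model of that contract (hypothetical classes, not the MegBrain API):

# Toy model of the contract; hypothetical, not the MegBrain API.
NO_DESC = "NO_DESC"

class ToyInferManager:
    def __init__(self):
        self._shape = {}

    def sync_from_var(self, var, shape):
        # Analog of handle->sync_from_var(): publish the shape already
        # written onto the variable, without registering an infer desc.
        self._shape[var] = shape

    def get_infer_type(self, var):
        return NO_DESC  # no shape-infer descriptor was registered

    def infer_shape_fallible(self, var):
        return self._shape.get(var)  # non-empty despite NO_DESC

mgr = ToyInferManager()
mgr.sync_from_var("v0", (8, 8))

# A conservative caller checks the infer type first and backs off:
assert mgr.get_infer_type("v0") == NO_DESC

# The memory optimizer skips the type check and picks up the hint:
assert mgr.infer_shape_fallible("v0") == (8, 8)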
@@ -86,9 +91,6 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
 void InputCallback::scn_do_execute() {
     auto dev_tensor = m_callback();
-    if (m_output_shape.ndim) {
-        mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
-    }
     output(0)->reset_dev_tensor_from_tensor(dev_tensor);
 }
......
@@ -99,7 +99,9 @@ MemoryOptimizerHelper::split_into_cn2oprseq(const OprNodeArray& oprseq,
         auto&& infer_mgr = m_owner_graph->static_infer_manager();
         for (auto j : i->output()) {
-            if (!j->contain_flag(BAD_VAR_FLAG) && is_static_var_shape(j)) {
+            if (!j->contain_flag(BAD_VAR_FLAG)) {
+                // omit the infer type check;
+                // the inferred shape is used as-is
                 if (auto shape = infer_mgr.infer_shape_fallible(j)) {
                     have_static_shape_out = true;
                     m_var_memsize[j] = j->dtype().size(shape->total_nr_elems());
......
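
With the relaxed check, the optimizer can now price variables whose shape is only available through the fallible query, which is what makes the InputNode shapes above usable as hints. The per-variable size mirrors j->dtype().size(shape->total_nr_elems()); a sketch of that arithmetic (hypothetical helper, not MegBrain code):

# Hypothetical helper mirroring
#     m_var_memsize[j] = j->dtype().size(shape->total_nr_elems());
import numpy as np

def var_mem_size(shape, dtype):
    # Bytes needed by a variable once its static shape is known.
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize

# A float32 variable of shape (8, 8): 64 elements * 4 bytes = 256 bytes.
assert var_mem_size((8, 8), "float32") == 256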