提交 ac3408bf 编写于 作者: M Megvii Engine Team

chore(mge): add VarNode.value

GitOrigin-RevId: 1dc0d0c711b3bc3e7bfb04fa933839928d579a8c
上级 0537cb74
......@@ -95,6 +95,10 @@ class VarNode(TensorBase):
def shape(self):
return self._node.shape
@property
def value(self):
return self._node.value
class OpNode:
def __init__(self, node: _imperative_rt.OperatorNode):
......
......@@ -399,7 +399,7 @@ class LazyEvalTensor(RawTensor):
return self.__varnode.shape
def numpy(self):
raise RuntimeError("cannot read value during symbolic tracing")
return self.__varnode.value
def _dev_tensor(self):
raise RuntimeError("cannot access data during symbolic tracing")
......
......@@ -58,6 +58,19 @@ void init_graph_rt(py::module m) {
return nullptr;
}
return mgr.infer_shape_fallible(v);
})
.def_property_readonly("value", [](cg::VarNode* v) -> py::object {
auto&& mgr = v->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
return py::none();
}
auto* val = mgr.infer_value_fallible(v);
if (!val) {
return py::none();
}
return py::cast(*val).attr("numpy")();
});
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册