提交 dd39265e 编写于 作者: M Megvii Engine Team

fix(mgb/dtype): enable TypeCvt for bool when trace(symbolic=True)

GitOrigin-RevId: 4e0fc63369b623e6e9e9eca396ec03f87b56452f
上级 5c37a64e
......@@ -16,6 +16,7 @@ import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.module as M
from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape
from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply
......@@ -274,3 +275,15 @@ def test_optimize_for_inference():
res = G.load_comp_graph_from_file(out)
computing_input = res.output_vars_list[0].owner.inputs[0]
assert computing_input.dtype == np.float16
def test_trace_cvt_bool():
set_tensor_shape(True)
x = tensor([0], dtype=np.int32)
@trace(symbolic=True)
def f(x):
return x.shape[0] == 0
for i in range(3):
np.testing.assert_equal(f(x).numpy()[0], False)
......@@ -136,6 +136,7 @@ void mgb::static_cast_dtype(T* dest, DType src_type, const void* storage,
nr_elem, src_type);
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
#define cb(_name, _bits) \
case DTypeTrait<dtype::_name##_bits>::enumv: \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册