diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 722b582b3e9ea2cdd2d361653aeefdd929cc9773..9a55ae52c14573692f9c80fec33c2ca9a84aa695 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -16,6 +16,7 @@ import megengine import megengine.core.tensor.megbrain_graph as G import megengine.module as M from megengine import cgtools, tensor +from megengine.core._trace_option import set_tensor_shape from megengine.core.ops import builtin as ops from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.core import apply @@ -274,3 +275,15 @@ def test_optimize_for_inference(): res = G.load_comp_graph_from_file(out) computing_input = res.output_vars_list[0].owner.inputs[0] assert computing_input.dtype == np.float16 + + +def test_trace_cvt_bool(): + set_tensor_shape(True) + x = tensor([0], dtype=np.int32) + + @trace(symbolic=True) + def f(x): + return x.shape[0] == 0 + + for i in range(3): + np.testing.assert_equal(f(x).numpy()[0], False) diff --git a/src/core/impl/dtype.cpp b/src/core/impl/dtype.cpp index a91a9f9f7f01fe900658c87a170d2380c74d34b3..6c8e9c9b9a27cd05b4be79c08e41bcfdada951b4 100644 --- a/src/core/impl/dtype.cpp +++ b/src/core/impl/dtype.cpp @@ -136,6 +136,7 @@ void mgb::static_cast_dtype(T* dest, DType src_type, const void* storage, nr_elem, src_type); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb #define cb(_name, _bits) \ case DTypeTrait::enumv: \