diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py
index b44897e9905c3b171efbc4db63e28652b071ac74..0aae12931eae9891a1a538570a77a641ed242d51 100644
--- a/imperative/python/megengine/utils/network_node.py
+++ b/imperative/python/megengine/utils/network_node.py
@@ -392,6 +392,14 @@ class MatrixMul(OpNode):
     type = "MatrixMul"
     opdef = builtin.MatrixMul
 
+    @classmethod
+    def load(cls, opr):
+        obj = super(MatrixMul, cls).load(opr)
+        dim1, dim2 = len(opr.inputs[0].shape), len(opr.inputs[1].shape)
+        obj.params["dimA"] = dim1
+        obj.params["dimB"] = dim2
+        return obj
+
 
 @register_flops(MatrixMul)
 def flops_matmul(opnode: MatrixMul, inputs, outputs):
@@ -431,6 +439,12 @@ class ConvolutionBackwardData(OpNode):
     type = "ConvTranspose"
     opdef = builtin.ConvolutionBackwardData
 
+    @classmethod
+    def load(cls, opr):
+        obj = super(ConvolutionBackwardData, cls).load(opr)
+        obj.params["dtype"] = opr.outputs[0].dtype
+        return obj
+
 
 class DeformableConvForward(OpNode):
     type = "DeformableConv"
diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py
index fd2361e465217556b1be6fabbad7295c514ce8e2..57c5fd807e0ff8f8f3f0ea851b9048b250501406 100644
--- a/imperative/python/test/unit/utils/test_network_node.py
+++ b/imperative/python/test/unit/utils/test_network_node.py
@@ -1,6 +1,7 @@
 import io
 import os
 import platform
+from contextlib import contextmanager
 
 import numpy as np
 import pytest
@@ -12,6 +13,7 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.random as rand
 from megengine.core._imperative_rt.core2 import apply
+from megengine.core._trace_option import set_symbolic_shape, use_symbolic_shape
 from megengine.core._wrap import Device
 from megengine.core.ops import builtin
 from megengine.device import (
@@ -26,6 +28,16 @@ from megengine.utils.comp_graph_tools import GraphInference
 from megengine.utils.network import Network as Net
 
 
+@contextmanager
+def override_symbolic_shape(enable: bool):
+    old = use_symbolic_shape()
+    set_symbolic_shape(enable)
+    try:
+        yield
+    finally:
+        set_symbolic_shape(old)
+
+
 def check_pygraph_dump(trace_func, inp_data, expect_results, max_err=None):
     orig_model = io.BytesIO()
     inp_size = len(inp_data)
@@ -41,6 +53,16 @@ def check_pygraph_dump(trace_func, inp_data, expect_results, max_err=None):
 
     orig_model.seek(0)
     net = Net.load(orig_model)
+
+    # make a graph transform
+    with override_symbolic_shape(False):
+        old_inps = net.input_vars
+        new_inps = [
+            net.make_input_node(shape=inp.shape, dtype=inp.dtype, name=inp.name)
+            for inp in old_inps
+        ]
+        net.replace_vars(dict(zip(old_inps, new_inps)))
+
     file = io.BytesIO()
     net.dump(file, optimize_for_inference=False)
     file.seek(0)
@@ -207,6 +229,17 @@ def test_convtranspose():
     check_pygraph_dump(fwd, [data], [result], 5)
 
+
+def test_convtranspose_int8():
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(inp, weight):
+        return F.quantized.conv_transpose2d(inp, weight, dtype=dtype.qint8(scale=1.0))
+
+    inp = Tensor(np.random.random((1, 16, 64, 64)), dtype=dtype.qint8(scale=1.0))
+    weight = Tensor(np.random.random((16, 16, 4, 4)), dtype=dtype.qint8(scale=1.0))
+    result = fwd(inp, weight)
+    check_pygraph_dump(fwd, [inp, weight], [result])
+
 
 @pytest.mark.skip(reason="pytest aborted")
 def test_grouplocal():
     n = M.LocalConv2d(3, 32, 32, 32, 3)