提交 b622064a 编写于 作者: M Megvii Engine Team

fix(mge/network): add init parameters for some op nodes

GitOrigin-RevId: 6515c961b7745593ecc280b659527527f362c8fa
上级 c0b9e071
......@@ -392,6 +392,14 @@ class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul
@classmethod
def load(cls, opr):
obj = super(MatrixMul, cls).load(opr)
dim1, dim2 = len(opr.inputs[0].shape), len(opr.inputs[1].shape)
obj.params["dimA"] = dim1
obj.params["dimB"] = dim2
return obj
@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
......@@ -431,6 +439,12 @@ class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
opdef = builtin.ConvolutionBackwardData
@classmethod
def load(cls, opr):
obj = super(ConvolutionBackwardData, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
class DeformableConvForward(OpNode):
type = "DeformableConv"
......
import io
import os
import platform
from contextlib import contextmanager
import numpy as np
import pytest
......@@ -12,6 +13,7 @@ import megengine.functional as F
import megengine.module as M
import megengine.random as rand
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import set_symbolic_shape, use_symbolic_shape
from megengine.core._wrap import Device
from megengine.core.ops import builtin
from megengine.device import (
......@@ -26,6 +28,14 @@ from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
@contextmanager
def override_symbolic_shape(enable: bool):
old = use_symbolic_shape()
set_symbolic_shape(enable)
yield
set_symbolic_shape(old)
def check_pygraph_dump(trace_func, inp_data, expect_results, max_err=None):
orig_model = io.BytesIO()
inp_size = len(inp_data)
......@@ -41,6 +51,16 @@ def check_pygraph_dump(trace_func, inp_data, expect_results, max_err=None):
orig_model.seek(0)
net = Net.load(orig_model)
# make a graph transform
with override_symbolic_shape(False):
old_inps = net.input_vars
new_inps = [
net.make_input_node(shape=inp.shape, dtype=inp.dtype, name=inp.name)
for inp in old_inps
]
net.replace_vars(dict(zip(old_inps, new_inps)))
file = io.BytesIO()
net.dump(file, optimize_for_inference=False)
file.seek(0)
......@@ -207,6 +227,17 @@ def test_convtranspose():
check_pygraph_dump(fwd, [data], [result], 5)
def test_convtranspose_int8():
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, weight):
return F.quantized.conv_transpose2d(inp, weight, dtype=dtype.qint8(scale=1.0))
inp = Tensor(np.random.random((1, 16, 64, 64)), dtype=dtype.qint8(scale=1.0))
weight = Tensor(np.random.random((16, 16, 4, 4)), dtype=dtype.qint8(scale=1.0))
result = fwd(inp, weight)
check_pygraph_dump(fwd, [inp, weight], [result])
@pytest.mark.skip(reason="pytest aborted")
def test_grouplocal():
n = M.LocalConv2d(3, 32, 32, 32, 3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册