提交 20719502 编写于 作者: M Megvii Engine Team

fix(imperative): fix matmul deduce dtype

GitOrigin-RevId: 24f4e1f9fc1fb58b3443d04c13e95e7493d119d1
上级 d733f429
......@@ -340,3 +340,19 @@ def test_conv_transpose2d():
test_func(2, 4, 3, 1, 8, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
test_func(4, 4, 16, 16, 8, 3, 3, 1, 1, 1, 1, 1, 1, 1, False)
test_func(32, 64, 36, 28, 16, 3, 2, 1, 3, 1, 0, 1, 1, 1, False)
def test_matmul():
inp_scale = np.float32(np.random.rand())
weight_scale = np.float32(np.random.rand())
inp_dtype = dtype.qint8(inp_scale)
weight_dtype = dtype.qint8(weight_scale)
inp_data = np.random.random((3, 12))
weight_data = np.random.random((5, 12))
inp_int8 = mge.tensor(dtype.convert_to_qint8(inp_data, inp_dtype))
weight_int8 = mge.tensor(dtype.convert_to_qint8(weight_data, weight_dtype))
res = F.matmul(inp_int8, weight_int8, transpose_b=True)
res_scale = dtype.get_scale(res.dtype)
np.testing.assert_allclose(inp_scale * weight_scale, res_scale)
......@@ -104,7 +104,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
DnnOprHelper<megdnn::MatrixMul> dnn_opr(matmul.param());
dnn_opr.opr().deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
dnn_opr.opr().deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
if (dim1 == 0 || dim2 == 0) {
return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
......@@ -157,7 +157,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}
DType dst_dtype;
dnn_opr.op()->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
dnn_opr.op()->deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
// only matters when layout1 has dim 2
if (matmul.transposeA)
......@@ -335,7 +335,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
DType dst_dtype;
DnnOprHelper<megdnn::MatrixMul> dnn_opr(matmul.param());
dnn_opr.opr().deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
dnn_opr.opr().deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
if (dim1 == 0 || dim2 == 0) {
return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
......@@ -378,7 +378,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn, matmul.param(), matmul.policy());
DType dst_dtype;
dnn_opr.op()->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
dnn_opr.op()->deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
TensorShape tshp, batch_shp;
size_t j = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册