提交 19043e7b 编写于 作者: W wangzhuo325

support adding bias in matmul

上级 00df47a0
......@@ -271,7 +271,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
"bias": bias_name,
})
if out_dtype == "float16":
if out_dtype == "float16" and (bias_value == None or bias_value.dtype == "float16"):
result_matmul = cast.cast(result_matmul, out_dtype)
def matmul_reshape(shape, result_matmul, *indices):
......@@ -288,10 +288,10 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
N = len(output_shape)
# reduce axis
if output_format == "zN":
bias_indices = indices[:(N - 4)] + indices[(N - 4):(N - 3)] + (0, 0) + indices[(N - 1):]
bias_indices = indices[N - 4] * cce.BLOCK_OUT + indices[N - 1]
elif output_format == "zZ":
bias_indices = indices[:(N - 4)] + (0,) + indices[(N - 3):(N - 2)] + (0,) + indices[(N - 1):]
return result(*indices) + bias(*bias_indices)
bias_indices = indices[N - 3] * cce.BLOCK_OUT + indices[N - 1]
return result(*indices) + bias(bias_indices)
if bias == 1:
if out_format == "zN":
out = akg.tvm.compute(output_shape_zN, lambda *indices: bias_compute(output_shape_zN, result, bias_value, out_format, *indices),
......@@ -299,6 +299,8 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
elif out_format == "zZ":
out = akg.tvm.compute(output_shape_zZ, lambda *indices: bias_compute(output_shape_zZ, result, bias_value, out_format, *indices),
name="output")
if out_dtype == "float16" and bias_value.dtype == "float32":
out = cast.cast(out, out_dtype)
else:
out = result
......
......@@ -171,10 +171,21 @@ class HoistL0Write : public IRMutator {
for (const auto &arg : op->args) {
args.push_back(this->Mutate(arg));
}
return Provide::make(op->func, op->value_index, op->value, args);
auto value = this->Mutate(op->value);
return Provide::make(op->func, op->value_index, value, args);
}
return IRMutator::Mutate_(op, s);
}
Expr Mutate_(const Call *op, const Expr &e) final {
if (mutate_write_) {
Array<Expr> args;
for (const auto &arg : op->args) {
args.push_back(this->Mutate(arg));
}
return Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index);
}
return IRMutator::Mutate_(op, e);
}
bool found_{false};
bool mutate_{false};
......
......@@ -121,7 +121,7 @@ def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_
def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"):
dtype="float16", bias_dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"):
shape_x, shape_y = get_shapes(batch_tuple, M, K, N, trans_data, trans_weight)
matrix_a = random_gaussian(shape_x, miu=0.1, sigma=0.01).astype(dtype)
matrix_b = random_gaussian(shape_y, miu=0.1, sigma=0.01).astype(dtype)
......@@ -137,13 +137,19 @@ def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
if dtype == "float16":
out.astype(np.float16)
bias_shape = batch_tuple + (N // cce.BLOCK_OUT, 1, 1, cce.BLOCK_OUT)
if output_format == "zZ":
bias_shape = batch_tuple + (1, N // cce.BLOCK_OUT, 1, cce.BLOCK_OUT)
bias_data = np.full(bias_shape, np.nan, out_dtype)
if bias == 1:
bias_data = random_gaussian(bias_shape, miu=0.5, sigma=0.01).astype(out_dtype)
out = out + bias_data
bias_shape = (N,)
bias_data = np.full(bias_shape, np.nan, bias_dtype)
if bias != 0:
bias_data = random_gaussian(bias_shape, miu=0.5, sigma=0.01).astype(bias_dtype)
bias_reshape = (N // cce.BLOCK_OUT, 1, 1, cce.BLOCK_OUT)
if output_format == "zZ":
bias_reshape = (1, N // cce.BLOCK_OUT, 1, cce.BLOCK_OUT)
bias_data_reshaped = bias_data.reshape(bias_reshape)
if bias_dtype != out_dtype:
out = out.astype(np.float32) + bias_data_reshaped.astype(np.float32)
out = out.astype(out_dtype)
else:
out = out + bias_data_reshaped
shape_x = ()
shape_y = ()
......@@ -185,14 +191,14 @@ def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
return fractal_a, fractal_b, out, bias_data
def matmul_data(batch_tuple, M, K, N, dtype, out_dtype, bias, adj_x, adj_y, left_format=None, right_format=None, output_format=None, debug_logging=False):
def matmul_data(batch_tuple, M, K, N, dtype, bias_dtype, out_dtype, bias, adj_x, adj_y, left_format=None, right_format=None, output_format=None, debug_logging=False):
m_x = ()
m_y = ()
bench_mark = ()
bias_data = ()
logging.debug("gen data start!")
a = datetime.now()
m_x, m_y, bench_mark, bias_data = genData(batch_tuple, M, K, N, adj_x, adj_y, dtype, out_dtype, bias, left_format, right_format, output_format)
m_x, m_y, bench_mark, bias_data = genData(batch_tuple, M, K, N, adj_x, adj_y, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format)
b = datetime.now()
logging.debug((b - a).seconds)
logging.debug("gen data end!")
......@@ -295,17 +301,13 @@ def get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias, left_format="
output_shape = batch_tuple + (m // cce.BLOCK_OUT, 1, n % cce.BLOCK_IN, cce.BLOCK_OUT)
if bias == 1:
if out_format == "zN":
bias_shape_nc1hwc0 = batch_tuple + (n // cce.BLOCK_OUT, 1, 1, cce.BLOCK_OUT)
elif out_format == "zZ":
bias_shape_nc1hwc0 = batch_tuple + (1, n // cce.BLOCK_OUT, 1, cce.BLOCK_OUT)
bias_shape_nc1hwc0 = (n,)
else:
bias_shape_nc1hwc0 = None
return shape_xx, shape_yy, bias_shape_nc1hwc0, output_shape, k
def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs):
def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs):
'''
There are four types of fractal format in Davinci core: zZ, zN, nZ, nN
general matmul format
......@@ -323,9 +325,9 @@ def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format
n = (n + 15) // 16 * 16
k = (k + 15) // 16 * 16
shape_xx, shape_yy, bias_shape, out_shape, k = get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias, left_format, right_format, out_format)
mod = matmul_compile(shape_x, shape_y, bias, left_format, right_format, out_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs)
mod = matmul_compile(shape_x, shape_y, bias, left_format, right_format, out_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs)
# Generate data
m_x, m_y, bench_mark, bias_data = matmul_data(batch_tuple, m, k, n, dtype, out_dtype, bias, adj_x, adj_y, left_format, right_format, out_format)
m_x, m_y, bench_mark, bias_data = matmul_data(batch_tuple, m, k, n, dtype, bias_dtype, out_dtype, bias, adj_x, adj_y, left_format, right_format, out_format)
# mod launch
output = np.full(out_shape, np.nan, out_dtype)
......@@ -341,7 +343,7 @@ def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format
return (m_x, m_y), output, bench_mark, compare_result
def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs, tuning=False):
def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs, tuning=False):
batch_tuple, m, k, n = extract_dim(shape_x, shape_y, adj_x, adj_y)
m = (m + 15) // 16 * 16
n = (n + 15) // 16 * 16
......@@ -349,7 +351,7 @@ def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_for
shape_xx, shape_yy, bias_shape, out_shape, k = get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias,
left_format, right_format, output_format)
input_shapes = [shape_xx, shape_yy, bias_shape]
input_types = [dtype, dtype, out_dtype]
input_types = [dtype, dtype, bias_dtype]
has_bias = False
if bias == 1:
has_bias = True
......
......@@ -31,85 +31,85 @@ class TestCase(TestBase):
self._log.info("============= {0} Setup case============".format(self.casename))
self.testarg = [
# caseflag,opfuncname,testRunArgs, dimArgs
# shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs
# shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs
# bert shape
("matmul_run_bert_00", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")),
("matmul_run_bert_01", "matmul_run", ((8192, 4096), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_02", "matmul_run", ((8192, 1024), (1024, 4096), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")),
("matmul_run_bert_03", "matmul_run", ((16, 16), (16, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_04", "matmul_run", ((1216, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_05", "matmul_run", ((8192, 4096), (4096, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")),
("matmul_run_bert_06", "matmul_run", ((8192, 1024), (4096, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")),
("matmul_run_bert_07", "matmul_run", ((8192, 1024), (8192, 4096), 0, "zN", "zN", "zN", True, False, "float16", "float16", "matmul_cce")),
("matmul_run_bert_08", "matmul_run", ((1216, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")),
("matmul_run_bert_09", "matmul_run", ((8192, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")),
("matmul_run_bert_10", "matmul_run", ((1216, 30522), (30522, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")),
("matmul_run_bert_11", "matmul_run", ((1216, 30522), (1216, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_12", "matmul_run", ((1216, 1024), (30522, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float32", "matmul_cce")),
("matmul_run_bert_13", "matmul_run", ((8192, 1024), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_14", "matmul_run", ((1216, 1024), (1216, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float16", "matmul_cce")),
("matmul_run_bert_15", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_16", "matmul_run", ((16, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float32", "matmul_cce")),
("matmul_run_bert_17", "matmul_run", ((16, 16), (16, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float32", "matmul_cce")),
("matmul_run_bert_18", "matmul_run", ((8192, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")),
("matmul_run_bert_19", "matmul_run", ((8192, 4096), (1024, 4096), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")),
("matmul_run_bert_00", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_01", "matmul_run", ((8192, 4096), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_02", "matmul_run", ((8192, 1024), (1024, 4096), 0, "zN", "zN", "zN", False, False, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_03", "matmul_run", ((16, 16), (16, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_04", "matmul_run", ((1216, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_05", "matmul_run", ((8192, 4096), (4096, 1024), 0, "zN", "zN", "zN", False, False, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_06", "matmul_run", ((8192, 1024), (4096, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_07", "matmul_run", ((8192, 1024), (8192, 4096), 0, "zN", "zN", "zN", True, False, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_08", "matmul_run", ((1216, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_09", "matmul_run", ((8192, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, False, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_10", "matmul_run", ((1216, 30522), (30522, 1024), 0, "zN", "zN", "zN", False, False, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_11", "matmul_run", ((1216, 30522), (1216, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_12", "matmul_run", ((1216, 1024), (30522, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_13", "matmul_run", ((8192, 1024), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_14", "matmul_run", ((1216, 1024), (1216, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_15", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_16", "matmul_run", ((16, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_17", "matmul_run", ((16, 16), (16, 1024), 0, "zN", "zN", "zN", False, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_bert_18", "matmul_run", ((8192, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_19", "matmul_run", ((8192, 4096), (1024, 4096), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")),
# matmul_cast
("matmul_run1", "matmul_run",
((64, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", "float32", "matmul_cast_cce")),
((64, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", None, "float32", "matmul_cast_cce")),
# ((4, 4), (16, 16), (128, 128), (16, 16), (16, 16))),
# matmul_bias
("matmul_run2", "matmul_run",
((64, 1024), (16, 1024), 1, "zZ", "nZ", "zN", False, True, "float16", "float16", "matmul_bias_cce")),
((64, 1024), (16, 1024), 1, "zZ", "nZ", "zN", False, True, "float16", "float16", "float16", "matmul_bias_cce")),
# ((4, 4), (16, 16), (128, 128), (16, 16), (16, 16))),
# matmul_trans
("matmul_run3", "matmul_run",
((1024, 64), (16, 1024), 1, "zZ", "nZ", "zN", True, True, "float16", "float16", "matmul_bias_cce")),
((1024, 64), (16, 1024), 1, "zZ", "nZ", "zN", True, True, "float16", "float16", "float16", "matmul_bias_cce")),
# ((4, 4), (16, 16), (128, 128), (16, 16), (16, 16))),
# matmul
("matmul_run4", "matmul_run",
((64, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", "float16", "matmul_cce")),
((64, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", None, "float16", "matmul_cce")),
# ((4, 4), (16, 16), (128, 128), (16, 16), (16, 16))),
("matmul_run5", "matmul_run",
((1024, 16), (16, 1024), 1, "zZ", "nZ", "zN", False, False, "float16", "float16", "matmul_cce")),
((1024, 16), (16, 1024), 1, "zZ", "nZ", "zN", False, False, "float16", "float16", "float16", "matmul_cce")),
# ((8, 8), (8, 8), (128, 128), (128, 128), (16, 16))),
("matmul_run9", "matmul_run",
((16, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", "float16", "matmul_cce")),
((16, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", None, "float16", "matmul_cce")),
# ((16, 16), (16, 16), (16, 16))),
("matmul_run16", "matmul_run",
((16, 64), (64, 1024), 0, "zZ", "nZ", "zN", False, False, "float16", "float16", "matmul_cce")),
((16, 64), (64, 1024), 0, "zZ", "nZ", "zN", False, False, "float16", None, "float16", "matmul_cce")),
# ((16, 16), (16, 16), (16, 16), (4, 4))),
# new shape for bert
# ("matmul_run29", "matmul_run",
# ((8192,2), (1024,2), 0, 0, False, True, "float16", "float16", "matmul_cce"),
# ((8192,2), (1024,2), 0, 0, False, True, "float16", None, "float16", "matmul_cce"),
# ((8, 8), (8, 8), (128, 128), (128, 128), (16, 16))),
("matmul_run30", "matmul_run",
((64, 1024), (2, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", "float16", "matmul_cce")),
((64, 1024), (2, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", None, "float16", "matmul_cce")),
# ((4, 4), (16, 16), (16, 16), (16, 16), (16, 16))),
("matmul_run31", "matmul_run",
((2, 64), (1024, 64), 0, "zZ", "nZ", "zN", False, True, "float16", "float16", "matmul_cce")),
((2, 64), (1024, 64), 0, "zZ", "nZ", "zN", False, True, "float16", None, "float16", "matmul_cce")),
# ((16, 16), (16, 16), (16, 16), (16, 16))),
# zZ case
("matmul_run1", "matmul_run",
((6272, 256), (6272, 256), 0, "zZ", "zZ", "zZ", True, False, "float16", "float32", "matmul_cast_cce")),
((6272, 256), (6272, 256), 0, "zZ", "zZ", "zZ", True, False, "float16", None, "float32", "matmul_cast_cce")),
("matmul_run2", "matmul_run",
((6272*16, 4*16), (6272*16, 4*16), 0, "zZ", "zZ", "zZ", True, False, "float16", "float32", "matmul_cce")),
((6272*16, 4*16), (6272*16, 4*16), 0, "zZ", "zZ", "zZ", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run3", "matmul_run",
((1568*16, 8*16), (1568*16, 8*16), 0, "zZ", "zZ", "zZ", True, False, "float16", "float32", "matmul_cce")),
((1568*16, 8*16), (1568*16, 8*16), 0, "zZ", "zZ", "zZ", True, False, "float16", None, "float32", "matmul_cce")),
# zN case
("matmul_run_zN_1", "matmul_run",
((32, 48), (48, 64), 0, "zN", "zN", "zN", False, False, "float16", "float32", "matmul_cce")),
((32, 48), (48, 64), 0, "zN", "zN", "zN", False, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_zN_2", "matmul_run",
((32, 48), (48, 64), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")),
((32, 48), (48, 64), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
("matmul_run_zN_3", "matmul_run",
((32, 48), (48, 64), 0, "zN", "zN", "zN", False, True, "float16", "float32", "matmul_cce")),
((32, 48), (48, 64), 0, "zN", "zN", "zN", False, True, "float16", None, "float32", "matmul_cce")),
]
self.testarg_rpc_cloud = [
......@@ -121,11 +121,11 @@ class TestCase(TestBase):
#shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs
("matmul_run29", "matmul_run",
((8192, 16), (1024, 16), 0, "zZ", "nZ", "zN", False, True, "float16", "float16", "matmul_cce"),
((8192, 16), (1024, 16), 0, "zZ", "nZ", "zN", False, True, "float16", None, "float16", "matmul_cce"),
((8, 8), (8, 8), (128, 128), (128, 128), (128, 128))),
# ("matmul_run33", "matmul_run",
# ((16, 32), (32, 32), 0, 0, False, True, "float16", "float16", "matmul_cce"),
# ((16, 32), (32, 32), 0, 0, False, True, "float16", None, "float16", "matmul_cce"),
# ((4, 8), (4,8), (16, 128), (16, 128), (16, 128))),
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册