Commit ddf51155 authored by hanhuifeng2020

Fix failed test cases apply_adagrad, apply_rms_prop, and four2five: compute expected values in float32 and cast back to the input dtype, keep accum non-negative, replace the mod/div-heavy reshape in four2five, and re-enable the affected cases.

Parent 4452f2eb
@@ -270,10 +270,10 @@ def four2five(data, format_, dst_dtype='float16'):
             pad_before.append(0)
             pad_after.append(0)
         pad_after[-1] = last_channel - c
-        output = akg.topi.reshape(cast_data, (bs, c1, h, w, c))
+        # As c < last_channel, c1 is 1
+        output = akg.tvm.compute((bs, c1, h, w, c), lambda bs_i, _, h_i, w_i, c_i: cast_data[
+            bs_i, h_i, w_i, c_i], name="output")
         output = tvm_pad(output, pad_before, pad_after=pad_after, name='pad_output')
-        # In this case, reshape will create mod/div ops, which needs loop partition for tiling
-        attrs["enable_pre_poly_loop_partition"] = True
     else:
         output = nhwc_to_nc1hwc0(
             cast_data,
......
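For intuition, a minimal NumPy sketch of the path this hunk rewrites: when the NHWC channel count c is below last_channel, c1 is 1, so the data lifts directly into (bs, 1, h, w, c) and is zero-padded along the last axis, with none of the mod/div index expressions that akg.topi.reshape generated. The helper name and shapes below are illustrative, not from the repository:

```python
import numpy as np

def four2five_nhwc_small_c(data, last_channel=16):
    """NumPy reference for the padded NHWC -> NC1HWC0 path when c < last_channel."""
    bs, h, w, c = data.shape
    assert c < last_channel  # the rewritten branch only handles this case
    # Direct index mapping, mirroring the new tvm.compute: since c1 == 1,
    # out[bs, 0, h, w, c] == data[bs, h, w, c] with no mod/div in the index.
    out = data.reshape(bs, 1, h, w, c)
    # Mirror tvm_pad: zero-pad the C0 axis from c up to last_channel.
    return np.pad(out, [(0, 0)] * 4 + [(0, last_channel - c)])

x = np.random.rand(1, 59, 121, 15).astype(np.float32)  # the four2five_028 shape
print(four2five_nhwc_small_c(x).shape)  # (1, 1, 59, 121, 16)
```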
@@ -25,10 +25,14 @@ from akg.ops.math.rsqrt import rsqrt
 def _apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon):
     """Compute apply_rms_prop"""
-    shape = get_shape(var)
+    compute_dtype = "float32"
     dtype = var.dtype
-    cons_eps = akg.tvm.const(epsilon, dtype=dtype)
-    one_minus_rho = akg.tvm.compute((1, ), lambda *indice: akg.tvm.const(1.0, dtype) - rho[0], name="one_minus_rho")
+    if dtype != compute_dtype:
+        var, ms, mom, grad, lr, momentum, rho = [topi.cast(t, compute_dtype) for t in [
+            var, ms, mom, grad, lr, momentum, rho]]
+    shape = get_shape(var)
+    cons_eps = akg.tvm.const(epsilon, dtype=compute_dtype)
+    one_minus_rho = akg.tvm.compute((1, ), lambda *indice: akg.tvm.const(1.0, compute_dtype) - rho[0], name="one_minus_rho")
     # var_update = var - (momentum * mom + lr * grad / sqrt(rho * ms + (1 - rho) * grad * grad + epsilon))
     mom_1 = akg.tvm.compute(shape, lambda *indice: momentum[0] * mom(*indice), name="mom_1")
@@ -42,6 +46,8 @@ def _apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon):
     mom_2 = akg.tvm.compute(shape, lambda *indice: lr_grad(*indice) * rsq(*indice), name="mom_2")
     mom_update = akg.tvm.compute(shape, lambda *indice: mom_1(*indice) + mom_2(*indice), name="mom_update")
     var_update = akg.tvm.compute(shape, lambda *indice: var(*indice) - mom_update(*indice), name="var_update")
+    if var_update.dtype != dtype:
+        var_update, ms_update, mom_update = [topi.cast(t, dtype) for t in [var_update, ms_update, mom_update]]
     return var_update, ms_update, mom_update
......
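The float32 compute_dtype matters because chains like (1 - rho) * grad * grad round at every step in half precision. A small NumPy illustration (values are arbitrary; rho=0.6 matches the test case):

```python
import numpy as np

# Same float16 inputs, two arithmetic precisions: the all-float16 chain
# drifts from the float32 reference, illustrating the kind of error that
# can fail an elementwise comparison against a half-precision kernel.
rho, grad = np.float16(0.6), np.float16(0.1)
fp16 = (np.float16(1.0) - rho) * grad * grad  # every step rounds to float16
fp32 = (np.float32(1.0) - np.float32(rho)) * np.float32(grad) ** 2
print(fp16, fp32, abs(np.float32(fp16) - fp32))
```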
@@ -34,13 +34,12 @@ def apply_adagrad_execute(shape, dtype, update_slots, attrs=None):
 def gen_data(dtype, update_slots, shape):
     var = random_gaussian(shape, miu=1, sigma=0.1).astype(dtype)
-    accum = random_gaussian(shape, miu=1, sigma=0.1).astype(dtype)
+    # accum must be greater than or equal to 0
+    accum = np.abs(random_gaussian(shape, miu=1, sigma=0.1).astype(dtype))
     lr = random_gaussian((1,), miu=1, sigma=0.1).astype(dtype)
     grad = random_gaussian(shape, miu=1, sigma=0.1).astype(dtype)
     inputs = [var, accum, lr, grad]
-    accum_out = accum + grad * grad if update_slots else accum
-    var_out = var - (lr * grad / np.sqrt(accum_out))
-    exp_output = (var_out, accum_out)
+    exp_output = apply_adagrad_compute(var, accum, lr, grad, update_slots)
     outputs = [np.full(e.shape, np.nan, dtype) for e in exp_output]
     args = [*inputs, *outputs]
@@ -52,3 +51,16 @@ def apply_adagrad_compile(shape, dtype, update_slots, attrs, kernel_name="apply_
     dtypes = [dtype] * len(shapes)
     return utils.op_build_test(apply_adagrad.apply_adagrad, shapes, dtypes, [update_slots],
                                kernel_name=kernel_name, attrs=attrs, tuning=tuning)
+
+
+def apply_adagrad_compute(var, accum, lr, grad, update_slots):
+    dtype = var.dtype
+    compute_dtype = "float32"
+    if dtype != compute_dtype:
+        var, accum, lr, grad = [t.astype(compute_dtype) for t in [var, accum, lr, grad]]
+    accum_out = accum + grad * grad if update_slots else accum
+    var_out = var - (lr * grad / np.sqrt(accum_out))
+    exp_output = [var_out, accum_out]
+    if compute_dtype != dtype:
+        exp_output = [t.astype(dtype) for t in exp_output]
+    return exp_output
@@ -61,13 +61,22 @@ def gen_data(shape, dtype, lr, momentum, rho, epsilon):
     lr = np.array([lr]).astype(dtype)
     momentum = np.array([momentum]).astype(dtype)
     rho = np.array([rho]).astype(dtype)
     inputs = [var, ms, mom, grad, lr, momentum, rho]
+    expects = apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon)
+    args = inputs
+    return inputs, expects, args
+
+
+def apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon):
+    compute_dtype = "float32"
+    dtype = var.dtype
+    if dtype != compute_dtype:
+        var, ms, mom, grad, lr, momentum, rho = [t.astype(compute_dtype) for t in [
+            var, ms, mom, grad, lr, momentum, rho]]
     # ms = rho * ms + (1-rho) * grad * grad
     # mom = momentum * mom + lr * grad / sqrt(ms + epsilon)
     # var = var - mom
-    one = np.array([1.0]).astype(dtype)
+    one = np.array([1.0]).astype(compute_dtype)
     ms_1 = rho * ms
     ms_2 = (one - rho) * grad * grad
     ms_update = ms_1 + ms_2
@@ -77,7 +86,7 @@ def gen_data(shape, dtype, lr, momentum, rho, epsilon):
     mom_3 = mom_2_1 * mom_2_2
     mom_update = mom_1 + mom_3
     var_update = var - mom_update
-    expects = (var_update, ms_update, mom_update)
-    args = inputs
-    return inputs, expects, args
+    expects = [var_update, ms_update, mom_update]
+    if var_update.dtype != dtype:
+        expects = [t.astype(dtype) for t in expects]
+    return expects
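Likewise, a hypothetical call to the apply_rms_prop_compute reference, using the parameters of the float16 case re-enabled below ((1024,), lr=0.5, momentum=0.9, rho=0.6, epsilon=1e-4); the non-negative ms here is an assumption to keep sqrt(ms + epsilon) real:

```python
import numpy as np

shape, dtype = (1024,), "float16"
var, mom, grad = (np.random.randn(*shape).astype(dtype) for _ in range(3))
ms = np.abs(np.random.randn(*shape)).astype(dtype)  # assumed non-negative
lr, momentum, rho = (np.array([v]).astype(dtype) for v in (0.5, 0.9, 0.6))
var_u, ms_u, mom_u = apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, 1e-4)
assert all(t.dtype == np.float16 for t in (var_u, ms_u, mom_u))  # cast back to input dtype
```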
@@ -29,7 +29,7 @@ class TestCase(TestBase):
         self.testarg = [
             ("apply_adagrad_001", "apply_adagrad_run", ((16, 16), "float16", True)),
             ("apply_adagrad_002", "apply_adagrad_run", ((16, 16), "float32", True)),
-            # ("apply_adagrad_003", "apply_adagrad_run", ((16, 16), "float16", False)),
+            ("apply_adagrad_003", "apply_adagrad_run", ((16, 16), "float16", False)),
             ("apply_adagrad_004", "apply_adagrad_run", ((16, 16), "float32", False)),
         ]
......
@@ -46,7 +46,7 @@ class TestCase(TestBase):
         self.testarg = [
             # testflag, opfuncname, testRunArgs, dimArgs
             # testRunArgs: (shape, dtype, lr, momentum, rho, epsilon, attrs)
-            # ("apply_rms_prop_1", apply_rms_prop_run, ((1024,), "float16", 0.5, 0.9, 0.6, 1e-4)),
+            ("apply_rms_prop_1", apply_rms_prop_run, ((1024,), "float16", 0.5, 0.9, 0.6, 1e-4)),
             ("apply_rms_prop_2", apply_rms_prop_run, ((16, 16), "float32", 0.5, 0.9, 0.6, 1e-6)),
         ]
         self.testarg_cloud = [
......
@@ -72,7 +72,7 @@ class TestCase(TestBase):
             #("four2five_025", four2five_run, ([8, 64, 16, 16], "float32", 'NHWC', 'float16')),
             #("four2five_026", four2five_run, ([1, 64, 15, 15], "float32", 'NHWC', 'float16')),
             #("four2five_027", four2five_run, ([1, 24, 16, 16], "float32", 'NHWC', 'float16')),
-            #("four2five_028", "four2five_run", ([1, 59, 121, 15], "float32", 'NHWC', 'float32')),
+            ("four2five_028", "four2five_run", ([1, 59, 121, 15], "float32", 'NHWC', 'float32')),
             ("four2five_017", "four2five_run", ([32, 2048, 7, 7], "float32", 'NCHW', 'float32')),
         ]
......