提交 dc28a434 编写于 作者: M Megvii Engine Team

fix(mgb/bn): fix empty tensor input problem and other minor problems

GitOrigin-RevId: 1708b76cb83e90cb91e34c58d77329a67cd4a792
上级 270f1aa2
......@@ -107,8 +107,9 @@ def test_training_converge(test_traced_module):
optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
opt.step()
losses.append(loss.numpy())
print(np.mean(losses[-100:]))
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
assert (
np.mean(losses[-100:]) < 0.1
), "Final training Loss must be low enough, get {}".format(np.mean(losses[-100:]))
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
......@@ -118,7 +119,6 @@ def test_training_converge(test_traced_module):
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
pred = infer(data)
precision = calculate_precision(data.numpy(), pred.numpy())
print("precision=", precision)
assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision
)
......@@ -74,7 +74,6 @@ def test_save_load():
optim.step()
model_name = "simple.pkl"
print("save to {}".format(model_name))
mge.save(
{
......@@ -93,7 +92,6 @@ def test_save_load():
net.load_state_dict(checkpoint["state_dict"])
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
optim.load_state_dict(checkpoint["opt_state"])
print("load done")
os.remove("simple.pkl")
with gm:
......
......@@ -165,7 +165,6 @@ def test_dtype_int4_ffi_handle():
device = "xpux"
shape = (3, 3, 3)
data = np.random.random(shape).astype(np.float32) * 5 - 1
print(data)
def identity(x):
return x
......
......@@ -25,10 +25,6 @@ def test_basic_interface():
cf.name = "megengine.core"
cf.dtype = "float32"
cf.comp_node_arr = ["xpux"]
print(cf.name)
print(cf.dtype)
print(cf.comp_node_arr)
print(cf.comp_node)
cf.comp_node_arr = ["xpux", "xpux:1"]
with pytest.raises(ValueError):
cf.comp_node
......
......@@ -203,7 +203,7 @@ def test_dataloader_parallel_worker_exception():
pass
def apply(self, input):
y = x + 1
raise RuntimeError("test raise error")
return input
dataloader = DataLoader(
......
......@@ -209,7 +209,6 @@ def test_dataloader_parallel_timeout():
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_worker_exception():
print("in target")
dataset = init_dataset()
class FakeErrorTransform(Transform):
......@@ -217,7 +216,7 @@ def test_dataloader_parallel_worker_exception():
pass
def apply(self, input):
y = x + 1
raise RuntimeError("test raise error")
return input
dataloader = DataLoader(
......
......@@ -103,6 +103,7 @@ def test_Compose():
)
aug_data = t.apply_batch(generate_data())
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
print(aug_data_shape)
target_shape = [((3, 90, 70), label_shape)] * 4
assert aug_data_shape == target_shape
assert aug_data_shape == target_shape, "aug {}, target {}".format(
aug_data_shape, target_shape
)
......@@ -236,8 +236,7 @@ def test_empty_tensor(is_trace):
elif nargs == 2:
binary_func.append([op_name, op])
else:
print(nargs)
raise NotImplementedError
raise NotImplementedError("nargs {}".format(nargs))
def run_test(func, args, ref_shape, is_trace, sym=False):
args = [tensor(t, dtype="float32") for t in args]
......@@ -248,8 +247,7 @@ def test_empty_tensor(is_trace):
assert out.numpy().shape == ref_shape
else:
out = func(*args)
assert out.numpy().shape == ref_shape
print(out.numpy().shape)
assert out.numpy().shape == ref_shape, out.numpy().shape
inps = [
np.array([]).astype("float32"),
......
......@@ -922,8 +922,8 @@ def test_layer_norm():
def test_batchnorm2d_autocast():
"""check amp's result is equal to manually converted result"""
amp.enabled = True
tshape = (1, 224, 224, 3)
pshape = (1, 1, 1, 3)
tshape = (1, 3, 224, 224)
pshape = (1, 3, 1, 1)
inp = tensor(np.random.randn(*tshape), dtype=np.float32)
weight = tensor(np.ones(pshape, dtype=np.float32))
bias = tensor(np.zeros(pshape, dtype=np.float32))
......@@ -948,7 +948,6 @@ def test_conv3d():
inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
print(out.numpy().shape)
np.testing.assert_equal(
out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
)
......
......@@ -230,15 +230,21 @@ void BatchNormForward::get_output_var_shape(
for (size_t i = 0; i < 4; ++ i) {
out_shape[i] = inp_shape[1];
}
out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
if (!need_stats()) {
out_shape[0] = out_shape[1] = {0};
}
if (inp_shape[0].is_empty()) {
out_shape[4] = {0};
} else {
out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
}
}
size_t BatchNormForward::get_workspace_size_bytes(
const TensorShapeArray &input_shapes,
const TensorShapeArray &output_shapes) const {
if (input_shapes[0].is_empty())
return 0;
#define in(x) {input_shapes[x], input(x)->dtype()}
#define out(x) {output_shapes[x], output(x)->dtype()}
return megdnn_opr()->get_workspace_in_bytes(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册