提交 dc28a434 编写于 作者: M Megvii Engine Team

fix(mgb/bn): fix empty tensor input problem and other minor problems

GitOrigin-RevId: 1708b76cb83e90cb91e34c58d77329a67cd4a792
上级 270f1aa2
...@@ -107,8 +107,9 @@ def test_training_converge(test_traced_module): ...@@ -107,8 +107,9 @@ def test_training_converge(test_traced_module):
optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1) optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
opt.step() opt.step()
losses.append(loss.numpy()) losses.append(loss.numpy())
print(np.mean(losses[-100:])) assert (
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" np.mean(losses[-100:]) < 0.1
), "Final training Loss must be low enough, get {}".format(np.mean(losses[-100:]))
ngrid = 10 ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid) x = np.linspace(-1.0, 1.0, ngrid)
...@@ -118,7 +119,6 @@ def test_training_converge(test_traced_module): ...@@ -118,7 +119,6 @@ def test_training_converge(test_traced_module):
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32)) data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
pred = infer(data) pred = infer(data)
precision = calculate_precision(data.numpy(), pred.numpy()) precision = calculate_precision(data.numpy(), pred.numpy())
print("precision=", precision)
assert precision == 1.0, "Test precision must be high enough, get {}".format( assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision precision
) )
...@@ -74,7 +74,6 @@ def test_save_load(): ...@@ -74,7 +74,6 @@ def test_save_load():
optim.step() optim.step()
model_name = "simple.pkl" model_name = "simple.pkl"
print("save to {}".format(model_name))
mge.save( mge.save(
{ {
...@@ -93,7 +92,6 @@ def test_save_load(): ...@@ -93,7 +92,6 @@ def test_save_load():
net.load_state_dict(checkpoint["state_dict"]) net.load_state_dict(checkpoint["state_dict"])
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
optim.load_state_dict(checkpoint["opt_state"]) optim.load_state_dict(checkpoint["opt_state"])
print("load done")
os.remove("simple.pkl") os.remove("simple.pkl")
with gm: with gm:
......
...@@ -165,7 +165,6 @@ def test_dtype_int4_ffi_handle(): ...@@ -165,7 +165,6 @@ def test_dtype_int4_ffi_handle():
device = "xpux" device = "xpux"
shape = (3, 3, 3) shape = (3, 3, 3)
data = np.random.random(shape).astype(np.float32) * 5 - 1 data = np.random.random(shape).astype(np.float32) * 5 - 1
print(data)
def identity(x): def identity(x):
return x return x
......
...@@ -25,10 +25,6 @@ def test_basic_interface(): ...@@ -25,10 +25,6 @@ def test_basic_interface():
cf.name = "megengine.core" cf.name = "megengine.core"
cf.dtype = "float32" cf.dtype = "float32"
cf.comp_node_arr = ["xpux"] cf.comp_node_arr = ["xpux"]
print(cf.name)
print(cf.dtype)
print(cf.comp_node_arr)
print(cf.comp_node)
cf.comp_node_arr = ["xpux", "xpux:1"] cf.comp_node_arr = ["xpux", "xpux:1"]
with pytest.raises(ValueError): with pytest.raises(ValueError):
cf.comp_node cf.comp_node
......
...@@ -203,7 +203,7 @@ def test_dataloader_parallel_worker_exception(): ...@@ -203,7 +203,7 @@ def test_dataloader_parallel_worker_exception():
pass pass
def apply(self, input): def apply(self, input):
y = x + 1 raise RuntimeError("test raise error")
return input return input
dataloader = DataLoader( dataloader = DataLoader(
......
...@@ -209,7 +209,6 @@ def test_dataloader_parallel_timeout(): ...@@ -209,7 +209,6 @@ def test_dataloader_parallel_timeout():
reason="dataloader do not support parallel on windows", reason="dataloader do not support parallel on windows",
) )
def test_dataloader_parallel_worker_exception(): def test_dataloader_parallel_worker_exception():
print("in target")
dataset = init_dataset() dataset = init_dataset()
class FakeErrorTransform(Transform): class FakeErrorTransform(Transform):
...@@ -217,7 +216,7 @@ def test_dataloader_parallel_worker_exception(): ...@@ -217,7 +216,7 @@ def test_dataloader_parallel_worker_exception():
pass pass
def apply(self, input): def apply(self, input):
y = x + 1 raise RuntimeError("test raise error")
return input return input
dataloader = DataLoader( dataloader = DataLoader(
......
...@@ -103,6 +103,7 @@ def test_Compose(): ...@@ -103,6 +103,7 @@ def test_Compose():
) )
aug_data = t.apply_batch(generate_data()) aug_data = t.apply_batch(generate_data())
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
print(aug_data_shape)
target_shape = [((3, 90, 70), label_shape)] * 4 target_shape = [((3, 90, 70), label_shape)] * 4
assert aug_data_shape == target_shape assert aug_data_shape == target_shape, "aug {}, target {}".format(
aug_data_shape, target_shape
)
...@@ -236,8 +236,7 @@ def test_empty_tensor(is_trace): ...@@ -236,8 +236,7 @@ def test_empty_tensor(is_trace):
elif nargs == 2: elif nargs == 2:
binary_func.append([op_name, op]) binary_func.append([op_name, op])
else: else:
print(nargs) raise NotImplementedError("nargs {}".format(nargs))
raise NotImplementedError
def run_test(func, args, ref_shape, is_trace, sym=False): def run_test(func, args, ref_shape, is_trace, sym=False):
args = [tensor(t, dtype="float32") for t in args] args = [tensor(t, dtype="float32") for t in args]
...@@ -248,8 +247,7 @@ def test_empty_tensor(is_trace): ...@@ -248,8 +247,7 @@ def test_empty_tensor(is_trace):
assert out.numpy().shape == ref_shape assert out.numpy().shape == ref_shape
else: else:
out = func(*args) out = func(*args)
assert out.numpy().shape == ref_shape assert out.numpy().shape == ref_shape, out.numpy().shape
print(out.numpy().shape)
inps = [ inps = [
np.array([]).astype("float32"), np.array([]).astype("float32"),
......
...@@ -922,8 +922,8 @@ def test_layer_norm(): ...@@ -922,8 +922,8 @@ def test_layer_norm():
def test_batchnorm2d_autocast(): def test_batchnorm2d_autocast():
"""check amp's result is equal to manually converted result""" """check amp's result is equal to manually converted result"""
amp.enabled = True amp.enabled = True
tshape = (1, 224, 224, 3) tshape = (1, 3, 224, 224)
pshape = (1, 1, 1, 3) pshape = (1, 3, 1, 1)
inp = tensor(np.random.randn(*tshape), dtype=np.float32) inp = tensor(np.random.randn(*tshape), dtype=np.float32)
weight = tensor(np.ones(pshape, dtype=np.float32)) weight = tensor(np.ones(pshape, dtype=np.float32))
bias = tensor(np.zeros(pshape, dtype=np.float32)) bias = tensor(np.zeros(pshape, dtype=np.float32))
...@@ -948,7 +948,6 @@ def test_conv3d(): ...@@ -948,7 +948,6 @@ def test_conv3d():
inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32)) inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32)) weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
out = F.conv3d(inp, weight, None, 2, 0, 1, 1) out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
print(out.numpy().shape)
np.testing.assert_equal( np.testing.assert_equal(
out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16 out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
) )
......
...@@ -230,15 +230,21 @@ void BatchNormForward::get_output_var_shape( ...@@ -230,15 +230,21 @@ void BatchNormForward::get_output_var_shape(
for (size_t i = 0; i < 4; ++ i) { for (size_t i = 0; i < 4; ++ i) {
out_shape[i] = inp_shape[1]; out_shape[i] = inp_shape[1];
} }
out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
if (!need_stats()) { if (!need_stats()) {
out_shape[0] = out_shape[1] = {0}; out_shape[0] = out_shape[1] = {0};
} }
if (inp_shape[0].is_empty()) {
out_shape[4] = {0};
} else {
out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
}
} }
size_t BatchNormForward::get_workspace_size_bytes( size_t BatchNormForward::get_workspace_size_bytes(
const TensorShapeArray &input_shapes, const TensorShapeArray &input_shapes,
const TensorShapeArray &output_shapes) const { const TensorShapeArray &output_shapes) const {
if (input_shapes[0].is_empty())
return 0;
#define in(x) {input_shapes[x], input(x)->dtype()} #define in(x) {input_shapes[x], input(x)->dtype()}
#define out(x) {output_shapes[x], output(x)->dtype()} #define out(x) {output_shapes[x], output(x)->dtype()}
return megdnn_opr()->get_workspace_in_bytes( return megdnn_opr()->get_workspace_in_bytes(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册