From 668c18c42c4523efac60464e8561c9fe1f07e371 Mon Sep 17 00:00:00 2001
From: chujinjin
Date: Thu, 21 May 2020 15:42:52 +0800
Subject: [PATCH] fix cifar 1p test

---
 tests/st/tbe_networks/test_resnet_cifar_1p.py | 35 +++----------------
 1 file changed, 4 insertions(+), 31 deletions(-)

diff --git a/tests/st/tbe_networks/test_resnet_cifar_1p.py b/tests/st/tbe_networks/test_resnet_cifar_1p.py
index 319338252..92954998c 100644
--- a/tests/st/tbe_networks/test_resnet_cifar_1p.py
+++ b/tests/st/tbe_networks/test_resnet_cifar_1p.py
@@ -134,12 +134,8 @@ class LossGet(Callback):
         return self._loss
 
 
-def train_process(device_id, epoch_size, num_classes, batch_size):
-    os.system("mkdir " + str(device_id))
-    os.chdir(str(device_id))
+def train_process(epoch_size, num_classes, batch_size):
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(device_id=device_id)
-    context.set_context(mode=context.GRAPH_MODE)
     net = resnet50(batch_size, num_classes)
     loss = CrossEntropyLoss()
     opt = Momentum(filter(lambda x: x.requires_grad,
@@ -148,34 +144,15 @@ def train_process(device_id, epoch_size, num_classes, batch_size):
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
 
     dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
-    batch_num = dataset.get_dataset_size()
-    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
-    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
-                                 config=config_ck)
     loss_cb = LossGet()
-    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
+    model.train(epoch_size, dataset, callbacks=[loss_cb])
 
-
-def eval(batch_size, num_classes):
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(device_id=0)
-
-    net = resnet50(batch_size, num_classes)
-    loss = CrossEntropyLoss()
-    opt = Momentum(filter(lambda x: x.requires_grad,
-                          net.get_parameters()), 0.01, 0.9)
-
-    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
-    checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt"
-    param_dict = load_checkpoint(checkpoint_path)
-    load_param_into_net(net, param_dict)
     net.set_train(False)
     eval_dataset = create_dataset(1, training=False)
     res = model.eval(eval_dataset)
     print("result: ", res)
     return res
 
-
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -184,11 +161,7 @@ def test_resnet_cifar_1p():
     epoch_size = 1
     num_classes = 10
     batch_size = 32
-    device_id = 0
-    train_process(device_id, epoch_size, num_classes, batch_size)
-    time.sleep(3)
-    acc = eval(batch_size, num_classes)
-    os.chdir("../")
-    os.system("rm -rf " + str(device_id))
+    acc = train_process(epoch_size, num_classes, batch_size)
+    os.system("rm -rf kernel_meta")
     print("End training...")
     assert acc['acc'] > 0.35
-- 
GitLab