From d6278c2bc6fdfe98622e7e53b5a7eceea2dbfbdc Mon Sep 17 00:00:00 2001 From: Payne Date: Mon, 7 Sep 2020 17:07:08 +0800 Subject: [PATCH] train with ascend, modify api and debug --- model_zoo/official/cv/mobilenetv2/eval.py | 4 ++-- model_zoo/official/cv/mobilenetv2/scripts/run_train.sh | 4 ++-- model_zoo/official/cv/mobilenetv2/src/args.py | 4 ++-- model_zoo/official/cv/mobilenetv2/src/launch.py | 4 ++-- model_zoo/official/cv/mobilenetv2/src/utils.py | 3 ++- model_zoo/official/cv/mobilenetv2/train.py | 6 +++--- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/model_zoo/official/cv/mobilenetv2/eval.py b/model_zoo/official/cv/mobilenetv2/eval.py index 967eae9b9..c50947d1a 100644 --- a/model_zoo/official/cv/mobilenetv2/eval.py +++ b/model_zoo/official/cv/mobilenetv2/eval.py @@ -32,7 +32,7 @@ if __name__ == '__main__': backbone_net = MobileNetV2Backbone(platform=args_opt.platform) head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) - net = mobilenet_v2(feature_net, head_net) + net = mobilenet_v2(backbone_net, head_net) #load the trained checkpoint file to the net for evaluation if args_opt.head_ckpt: @@ -51,7 +51,7 @@ if __name__ == '__main__': loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') model = Model(net, loss_fn=loss, metrics={'acc'}) - res = model.eval(dataset) + res = model.eval(dataset, dataset_sink_mode=False) print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}") if args_opt.head_ckpt: print(f"head_ckpt={args_opt.head_ckpt}") diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh index 5d4f75445..c7eea7ef9 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh @@ -84,9 +84,9 @@ run_gpu() run_cpu() { - if [ ! -d $4 ] + if [ ! -d $2 ] then - echo "error: DATASET_PATH=$4 is not a directory" + echo "error: DATASET_PATH=$2 is not a directory" exit 1 fi diff --git a/model_zoo/official/cv/mobilenetv2/src/args.py b/model_zoo/official/cv/mobilenetv2/src/args.py index 184b65a63..7f192d681 100644 --- a/model_zoo/official/cv/mobilenetv2/src/args.py +++ b/model_zoo/official/cv/mobilenetv2/src/args.py @@ -22,7 +22,7 @@ def launch_parse_args(): that will spawn up multiple distributed processes") launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ help='run platform, only support GPU, CPU and Ascend') - launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(0, 1, 2, 3, 4, 5, 6, 7), \ + launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(1, 2, 3, 4, 5, 6, 7, 8), \ help="The number of processes to launch on each node, for D training, this is recommended to be set \ to the number of D in your system so that each process can be bound to a single D.") launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \ @@ -32,7 +32,7 @@ def launch_parse_args(): the training script") launch_args, unknown = launch_parser.parse_known_args() - launch_args.train_script_args = unknown + launch_args.training_script_args = unknown launch_args.training_script_args += ["--platform", launch_args.platform] return launch_args diff --git a/model_zoo/official/cv/mobilenetv2/src/launch.py b/model_zoo/official/cv/mobilenetv2/src/launch.py index 0b42a5d75..8785186dc 100644 --- a/model_zoo/official/cv/mobilenetv2/src/launch.py +++ b/model_zoo/official/cv/mobilenetv2/src/launch.py @@ -46,8 +46,8 @@ def main(): os.mkdir(device_dir) os.chdir(device_dir) cmd = [sys.executable, '-u'] - cmd.append(args.train_script) - cmd.extend(args.train_script_args) + cmd.append(args.training_script) + cmd.extend(args.training_script_args) log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) processes.append(process) diff --git a/model_zoo/official/cv/mobilenetv2/src/utils.py b/model_zoo/official/cv/mobilenetv2/src/utils.py index 5a05f397a..0c260a7b0 100644 --- a/model_zoo/official/cv/mobilenetv2/src/utils.py +++ b/model_zoo/official/cv/mobilenetv2/src/utils.py @@ -17,8 +17,9 @@ from mindspore import context from mindspore import nn from mindspore.common import dtype as mstype from mindspore.train.model import ParallelMode +from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.communication.management import get_rank, init +from mindspore.communication.management import get_rank, init, get_group_size from src.models import Monitor diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index f817b2b7c..b417d72ce 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.common import dtype as mstype from mindspore.train.model import Model from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.serialization import _exec_save_checkpoint +from mindspore.train.serialization import save_checkpoint from mindspore.common import set_seed from src.dataset import create_dataset, extract_features @@ -88,7 +88,7 @@ if __name__ == '__main__': opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) network = WithLossCell(net, loss) - network = TrainOneStepCell(net, opt) + network = TrainOneStepCell(network, opt) network.set_train() features_path = args_opt.dataset_path + '_features' @@ -116,7 +116,7 @@ if __name__ == '__main__': .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ end="") if (epoch + 1) % config.save_checkpoint_epochs == 0: - _exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ + save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ f"mobilenetv2_head_{epoch+1}.ckpt")) print("total cost {:5.4f} s".format(time.time() - start)) -- GitLab