提交 d6278c2b 编写于 作者: P Payne

Train with Ascend, modify API usage, and debug

上级 b9345d1d
...@@ -32,7 +32,7 @@ if __name__ == '__main__': ...@@ -32,7 +32,7 @@ if __name__ == '__main__':
backbone_net = MobileNetV2Backbone(platform=args_opt.platform) backbone_net = MobileNetV2Backbone(platform=args_opt.platform)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(feature_net, head_net) net = mobilenet_v2(backbone_net, head_net)
#load the trained checkpoint file to the net for evaluation #load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt: if args_opt.head_ckpt:
...@@ -51,7 +51,7 @@ if __name__ == '__main__': ...@@ -51,7 +51,7 @@ if __name__ == '__main__':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, metrics={'acc'}) model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset) res = model.eval(dataset, dataset_sink_mode=False)
print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}") print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}")
if args_opt.head_ckpt: if args_opt.head_ckpt:
print(f"head_ckpt={args_opt.head_ckpt}") print(f"head_ckpt={args_opt.head_ckpt}")
...@@ -84,9 +84,9 @@ run_gpu() ...@@ -84,9 +84,9 @@ run_gpu()
run_cpu() run_cpu()
{ {
if [ ! -d $4 ] if [ ! -d $2 ]
then then
echo "error: DATASET_PATH=$4 is not a directory" echo "error: DATASET_PATH=$2 is not a directory"
exit 1 exit 1
fi fi
......
...@@ -22,7 +22,7 @@ def launch_parse_args(): ...@@ -22,7 +22,7 @@ def launch_parse_args():
that will spawn up multiple distributed processes") that will spawn up multiple distributed processes")
launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend') help='run platform, only support GPU, CPU and Ascend')
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(0, 1, 2, 3, 4, 5, 6, 7), \ launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(1, 2, 3, 4, 5, 6, 7, 8), \
help="The number of processes to launch on each node, for D training, this is recommended to be set \ help="The number of processes to launch on each node, for D training, this is recommended to be set \
to the number of D in your system so that each process can be bound to a single D.") to the number of D in your system so that each process can be bound to a single D.")
launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \ launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \
...@@ -32,7 +32,7 @@ def launch_parse_args(): ...@@ -32,7 +32,7 @@ def launch_parse_args():
the training script") the training script")
launch_args, unknown = launch_parser.parse_known_args() launch_args, unknown = launch_parser.parse_known_args()
launch_args.train_script_args = unknown launch_args.training_script_args = unknown
launch_args.training_script_args += ["--platform", launch_args.platform] launch_args.training_script_args += ["--platform", launch_args.platform]
return launch_args return launch_args
......
...@@ -46,8 +46,8 @@ def main(): ...@@ -46,8 +46,8 @@ def main():
os.mkdir(device_dir) os.mkdir(device_dir)
os.chdir(device_dir) os.chdir(device_dir)
cmd = [sys.executable, '-u'] cmd = [sys.executable, '-u']
cmd.append(args.train_script) cmd.append(args.training_script)
cmd.extend(args.train_script_args) cmd.extend(args.training_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process) processes.append(process)
......
...@@ -17,8 +17,9 @@ from mindspore import context ...@@ -17,8 +17,9 @@ from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode from mindspore.train.model import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init from mindspore.communication.management import get_rank, init, get_group_size
from src.models import Monitor from src.models import Monitor
......
...@@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import _exec_save_checkpoint from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed from mindspore.common import set_seed
from src.dataset import create_dataset, extract_features from src.dataset import create_dataset, extract_features
...@@ -88,7 +88,7 @@ if __name__ == '__main__': ...@@ -88,7 +88,7 @@ if __name__ == '__main__':
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
network = WithLossCell(net, loss) network = WithLossCell(net, loss)
network = TrainOneStepCell(net, opt) network = TrainOneStepCell(network, opt)
network.set_train() network.set_train()
features_path = args_opt.dataset_path + '_features' features_path = args_opt.dataset_path + '_features'
...@@ -116,7 +116,7 @@ if __name__ == '__main__': ...@@ -116,7 +116,7 @@ if __name__ == '__main__':
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
end="") end="")
if (epoch + 1) % config.save_checkpoint_epochs == 0: if (epoch + 1) % config.save_checkpoint_epochs == 0:
_exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
f"mobilenetv2_head_{epoch+1}.ckpt")) f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start)) print("total cost {:5.4f} s".format(time.time() - start))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册