diff --git a/mindspore/model_zoo/mobilenetv2/src/dataset.py b/mindspore/model_zoo/mobilenetv2/src/dataset.py index a1a77a8495a265119ebca1b7d7e1e0cde72fb373..ef86b43a604aa243864c5a74d384c476cec85424 100644 --- a/mindspore/model_zoo/mobilenetv2/src/dataset.py +++ b/mindspore/model_zoo/mobilenetv2/src/dataset.py @@ -21,7 +21,6 @@ import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 - def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32): """ create a train or eval dataset @@ -44,7 +43,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=rank_size, shard_id=rank_id) elif platform == "GPU": - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) + from mindspore.communication.management import get_rank, get_group_size + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=get_group_size(), shard_id=get_rank()) else: raise ValueError("Unsupport platform.") diff --git a/mindspore/model_zoo/mobilenetv2/train.py b/mindspore/model_zoo/mobilenetv2/train.py index 9b7b63aaca1c20196c79cacef6c86f04761fb8d4..80c51380d4dbcb6f1b01cafd45f3e51e2914508a 100644 --- a/mindspore/model_zoo/mobilenetv2/train.py +++ b/mindspore/model_zoo/mobilenetv2/train.py @@ -32,7 +32,7 @@ from mindspore.train.model import Model, ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.communication.management import init +from mindspore.communication.management import init, get_group_size import mindspore.dataset.engine as de from src.dataset import create_dataset from src.lr_generator import get_lr @@ -157,6 +157,11 @@ if __name__ == '__main__': # train on gpu print("train args: ", args_opt, "\ncfg: ", config_gpu) + init('nccl') + context.set_auto_parallel_context(parallel_mode="data_parallel", + mirror_mean=True, + device_num=get_group_size()) + # define net net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") # define loss @@ -223,7 +228,7 @@ if __name__ == '__main__': cell.to_float(mstype.float32) if config_ascend.label_smooth > 0: loss = CrossEntropyWithLabelSmooth( - smooth_factor=config_ascend.label_smooth, num_classes=config.num_classes) + smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes) else: loss = SoftmaxCrossEntropyWithLogits( is_grad=False, sparse=True, reduction='mean') diff --git a/mindspore/model_zoo/mobilenetv3/eval.py b/mindspore/model_zoo/mobilenetv3/eval.py index 7428b748f44a0fd8f38172dd794d7ff1819e114b..e82ed496d3fbd704460bd8564f5212fb423c3662 100644 --- a/mindspore/model_zoo/mobilenetv3/eval.py +++ b/mindspore/model_zoo/mobilenetv3/eval.py @@ -24,7 +24,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import dtype as mstype from src.dataset import create_dataset from src.config import config_ascend, config_gpu -from src.mobilenetV2 import mobilenet_v2 +from src.mobilenetV3 import mobilenet_v3_large + parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') @@ -49,7 +50,7 @@ if __name__ == '__main__': loss = nn.SoftmaxCrossEntropyWithLogits( is_grad=False, sparse=True, reduction='mean') - net = mobilenet_v2(num_classes=config_platform.num_classes) + net = mobilenet_v3_large(num_classes=config_platform.num_classes) if args_opt.platform == "Ascend": net.to_float(mstype.float16) diff --git a/mindspore/model_zoo/mobilenetv3/src/dataset.py b/mindspore/model_zoo/mobilenetv3/src/dataset.py index a1a77a8495a265119ebca1b7d7e1e0cde72fb373..aa62e5f4cbd735e1bfaab4fee2a838aa5eb0de16 100644 --- a/mindspore/model_zoo/mobilenetv3/src/dataset.py +++ b/mindspore/model_zoo/mobilenetv3/src/dataset.py @@ -44,7 +44,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=rank_size, shard_id=rank_id) elif platform == "GPU": - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) + from mindspore.communication.management import get_rank, get_group_size + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=get_group_size(), shard_id=get_rank()) else: raise ValueError("Unsupport platform.") diff --git a/mindspore/model_zoo/mobilenetv3/train.py b/mindspore/model_zoo/mobilenetv3/train.py index b11f1dc6e75c220ab2e63b766b270df4ed118fc8..724fed7cb84c1ec1be9c91d20f5e13eddbfecad7 100644 --- a/mindspore/model_zoo/mobilenetv3/train.py +++ b/mindspore/model_zoo/mobilenetv3/train.py @@ -33,7 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset.engine as de -from mindspore.communication.management import init +from mindspore.communication.management import init, get_group_size from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_gpu, config_ascend @@ -157,6 +157,11 @@ if __name__ == '__main__': # train on gpu print("train args: ", args_opt, "\ncfg: ", config_gpu) + init('nccl') + context.set_auto_parallel_context(parallel_mode="data_parallel", + mirror_mean=True, + device_num=get_group_size()) + # define net net = mobilenet_v3_large(num_classes=config_gpu.num_classes) # define loss