提交 6bf12a2e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1484 fix bug for mobilenet in model_zoo

Merge pull request !1484 from SanjayChan/mobilenet
......@@ -21,7 +21,6 @@ import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
"""
create a train or eval dataset
......@@ -44,7 +43,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU":
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
raise ValueError("Unsupport platform.")
......
......@@ -32,7 +32,7 @@ from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size
import mindspore.dataset.engine as de
from src.dataset import create_dataset
from src.lr_generator import get_lr
......@@ -157,6 +157,11 @@ if __name__ == '__main__':
# train on gpu
print("train args: ", args_opt, "\ncfg: ", config_gpu)
init('nccl')
context.set_auto_parallel_context(parallel_mode="data_parallel",
mirror_mean=True,
device_num=get_group_size())
# define net
net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
# define loss
......@@ -223,7 +228,7 @@ if __name__ == '__main__':
cell.to_float(mstype.float32)
if config_ascend.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(
smooth_factor=config_ascend.label_smooth, num_classes=config.num_classes)
smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean')
......
......@@ -24,7 +24,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
from src.dataset import create_dataset
from src.config import config_ascend, config_gpu
from src.mobilenetV2 import mobilenet_v2
from src.mobilenetV3 import mobilenet_v3_large
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
......@@ -49,7 +50,7 @@ if __name__ == '__main__':
loss = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean')
net = mobilenet_v2(num_classes=config_platform.num_classes)
net = mobilenet_v3_large(num_classes=config_platform.num_classes)
if args_opt.platform == "Ascend":
net.to_float(mstype.float16)
......
......@@ -44,7 +44,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU":
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
raise ValueError("Unsupport platform.")
......
......@@ -33,7 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config_gpu, config_ascend
......@@ -157,6 +157,11 @@ if __name__ == '__main__':
# train on gpu
print("train args: ", args_opt, "\ncfg: ", config_gpu)
init('nccl')
context.set_auto_parallel_context(parallel_mode="data_parallel",
mirror_mean=True,
device_num=get_group_size())
# define net
net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
# define loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册