提交 6bf12a2e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1484 fix bug for mobilenet in model_zoo

Merge pull request !1484 from SanjayChan/mobilenet
...@@ -21,7 +21,6 @@ import mindspore.dataset.engine as de ...@@ -21,7 +21,6 @@ import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32): def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
""" """
create a train or eval dataset create a train or eval dataset
...@@ -44,7 +43,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ...@@ -44,7 +43,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id) num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU": elif platform == "GPU":
- ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+ from mindspore.communication.management import get_rank, get_group_size
+ ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+ num_shards=get_group_size(), shard_id=get_rank())
else: else:
raise ValueError("Unsupport platform.") raise ValueError("Unsupport platform.")
......
...@@ -32,7 +32,7 @@ from mindspore.train.model import Model, ParallelMode ...@@ -32,7 +32,7 @@ from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.communication.management import init
+ from mindspore.communication.management import init, get_group_size
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
from src.dataset import create_dataset from src.dataset import create_dataset
from src.lr_generator import get_lr from src.lr_generator import get_lr
...@@ -157,6 +157,11 @@ if __name__ == '__main__': ...@@ -157,6 +157,11 @@ if __name__ == '__main__':
# train on gpu # train on gpu
print("train args: ", args_opt, "\ncfg: ", config_gpu) print("train args: ", args_opt, "\ncfg: ", config_gpu)
init('nccl')
context.set_auto_parallel_context(parallel_mode="data_parallel",
mirror_mean=True,
device_num=get_group_size())
# define net # define net
net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
# define loss # define loss
...@@ -223,7 +228,7 @@ if __name__ == '__main__': ...@@ -223,7 +228,7 @@ if __name__ == '__main__':
cell.to_float(mstype.float32) cell.to_float(mstype.float32)
if config_ascend.label_smooth > 0: if config_ascend.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth( loss = CrossEntropyWithLabelSmooth(
- smooth_factor=config_ascend.label_smooth, num_classes=config.num_classes)
+ smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
else: else:
loss = SoftmaxCrossEntropyWithLogits( loss = SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean') is_grad=False, sparse=True, reduction='mean')
......
...@@ -24,7 +24,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net ...@@ -24,7 +24,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from src.dataset import create_dataset from src.dataset import create_dataset
from src.config import config_ascend, config_gpu from src.config import config_ascend, config_gpu
- from src.mobilenetV2 import mobilenet_v2
+ from src.mobilenetV3 import mobilenet_v3_large
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
...@@ -49,7 +50,7 @@ if __name__ == '__main__': ...@@ -49,7 +50,7 @@ if __name__ == '__main__':
loss = nn.SoftmaxCrossEntropyWithLogits( loss = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean') is_grad=False, sparse=True, reduction='mean')
- net = mobilenet_v2(num_classes=config_platform.num_classes)
+ net = mobilenet_v3_large(num_classes=config_platform.num_classes)
if args_opt.platform == "Ascend": if args_opt.platform == "Ascend":
net.to_float(mstype.float16) net.to_float(mstype.float16)
......
...@@ -44,7 +44,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ...@@ -44,7 +44,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id) num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU": elif platform == "GPU":
- ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+ from mindspore.communication.management import get_rank, get_group_size
+ ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+ num_shards=get_group_size(), shard_id=get_rank())
else: else:
raise ValueError("Unsupport platform.") raise ValueError("Unsupport platform.")
......
...@@ -33,7 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback ...@@ -33,7 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
- from mindspore.communication.management import init
+ from mindspore.communication.management import init, get_group_size
from src.dataset import create_dataset from src.dataset import create_dataset
from src.lr_generator import get_lr from src.lr_generator import get_lr
from src.config import config_gpu, config_ascend from src.config import config_gpu, config_ascend
...@@ -157,6 +157,11 @@ if __name__ == '__main__': ...@@ -157,6 +157,11 @@ if __name__ == '__main__':
# train on gpu # train on gpu
print("train args: ", args_opt, "\ncfg: ", config_gpu) print("train args: ", args_opt, "\ncfg: ", config_gpu)
init('nccl')
context.set_auto_parallel_context(parallel_mode="data_parallel",
mirror_mean=True,
device_num=get_group_size())
# define net # define net
net = mobilenet_v3_large(num_classes=config_gpu.num_classes) net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
# define loss # define loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册