提交 61ac727c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1550 bug fix while evaluation

Merge pull request !1550 from SanjayChan/mobilenet
...@@ -28,8 +28,8 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ...@@ -28,8 +28,8 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
Args: Args:
dataset_path(string): the path of dataset. dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval. do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1 repeat_num(int): the repeat times of dataset. Default: 1.
batch_size(int): the batch size of dataset. Default: 32 batch_size(int): the batch size of dataset. Default: 32.
Returns: Returns:
dataset dataset
...@@ -43,9 +43,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ...@@ -43,9 +43,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id) num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU": elif platform == "GPU":
from mindspore.communication.management import get_rank, get_group_size if do_train:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, from mindspore.communication.management import get_rank, get_group_size
num_shards=get_group_size(), shard_id=get_rank()) ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else: else:
raise ValueError("Unsupport platform.") raise ValueError("Unsupport platform.")
......
...@@ -44,9 +44,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch ...@@ -44,9 +44,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id) num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU": elif platform == "GPU":
from mindspore.communication.management import get_rank, get_group_size if do_train:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, from mindspore.communication.management import get_rank, get_group_size
num_shards=get_group_size(), shard_id=get_rank()) ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else: else:
raise ValueError("Unsupport platform.") raise ValueError("Unsupport platform.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册