提交 77b1f718 编写于 作者: V VectorSL

gpu fix resnet script

上级 0b4de001
......@@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import get_rank, get_group_size
from mindspore.communication.management import init, get_rank, get_group_size
from config import config
......
......@@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import get_rank, get_group_size
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
......@@ -40,6 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID"))
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册