提交 e08e4088 编写于 作者: C chujinjin

fix model zoo error for pynative

上级 8a71db07
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
import argparse import argparse
import ast
def launch_parse_args(): def launch_parse_args():
...@@ -43,6 +43,7 @@ def train_parse_args(): ...@@ -43,6 +43,7 @@ def train_parse_args():
help='run platform, only support CPU, GPU and Ascend') help='run platform, only support CPU, GPU and Ascend')
train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
for fine tune or incremental learning') for fine tune or incremental learning')
train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \
help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \
train from initialization model") train from initialization model")
......
...@@ -59,6 +59,7 @@ def set_config(args): ...@@ -59,6 +59,7 @@ def set_config(args):
"save_checkpoint_path": "./checkpoint", "save_checkpoint_path": "./checkpoint",
"platform": args.platform, "platform": args.platform,
"ccl": "nccl", "ccl": "nccl",
"run_distribute": args.run_distribute
}) })
config_ascend = ed({ config_ascend = ed({
"num_classes": 1000, "num_classes": 1000,
......
...@@ -51,11 +51,14 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1): ...@@ -51,11 +51,14 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1):
num_shards=rank_size, shard_id=rank_id) num_shards=rank_size, shard_id=rank_id)
elif config.platform == "GPU": elif config.platform == "GPU":
if do_train: if do_train:
if config.run_distribute:
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank()) num_shards=get_group_size(), shard_id=get_rank())
else: else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
elif config.platform == "CPU": elif config.platform == "CPU":
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
......
...@@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size ...@@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size
from src.models import Monitor from src.models import Monitor
def switch_precision(net, data_type, config): def switch_precision(net, data_type, config):
if config.platform == "Ascend": if config.platform == "Ascend":
net.to_float(data_type) net.to_float(data_type)
...@@ -29,13 +30,14 @@ def switch_precision(net, data_type, config): ...@@ -29,13 +30,14 @@ def switch_precision(net, data_type, config):
if isinstance(cell, nn.Dense): if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32) cell.to_float(mstype.float32)
def context_device_init(config):
def context_device_init(config):
if config.platform == "CPU": if config.platform == "CPU":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
elif config.platform == "GPU": elif config.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
if config.run_distribute:
init("nccl") init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(), context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
...@@ -53,6 +55,7 @@ def context_device_init(config): ...@@ -53,6 +55,7 @@ def context_device_init(config):
else: else:
raise ValueError("Only support CPU, GPU and Ascend.") raise ValueError("Only support CPU, GPU and Ascend.")
def set_context(config): def set_context(config):
if config.platform == "CPU": if config.platform == "CPU":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
...@@ -64,6 +67,7 @@ def set_context(config): ...@@ -64,6 +67,7 @@ def set_context(config):
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target=config.platform, save_graphs=False) device_target=config.platform, save_graphs=False)
def config_ckpoint(config, lr, step_size): def config_ckpoint(config, lr, step_size):
cb = None cb = None
if config.platform in ("CPU", "GPU") or config.rank_id == 0: if config.platform in ("CPU", "GPU") or config.rank_id == 0:
...@@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size): ...@@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size):
ckpt_save_dir = config.save_checkpoint_path ckpt_save_dir = config.save_checkpoint_path
if config.platform == "GPU": if config.platform == "GPU":
if config.run_distribute:
ckpt_save_dir += "ckpt_" + str(get_rank()) + "/" ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
else:
ckpt_save_dir += "ckpt_" + "/"
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb] cb += [ckpt_cb]
......
...@@ -21,7 +21,7 @@ import mindspore.dataset.vision.c_transforms as C ...@@ -21,7 +21,7 @@ import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32, run_distribute=False):
""" """
create a train or eval dataset create a train or eval dataset
...@@ -36,11 +36,14 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, ...@@ -36,11 +36,14 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
""" """
if device_target == "GPU": if device_target == "GPU":
if do_train: if do_train:
if run_distribute:
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank()) num_shards=get_group_size(), shard_id=get_rank())
else: else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else: else:
raise ValueError("Unsupported device_target.") raise ValueError("Unsupported device_target.")
...@@ -56,7 +59,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, ...@@ -56,7 +59,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
resize_op = C.Resize(256) resize_op = C.Resize(256)
center_crop = C.CenterCrop(resize_width) center_crop = C.CenterCrop(resize_width)
rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
change_swap_op = C.HWC2CHW() change_swap_op = C.HWC2CHW()
if do_train: if do_train:
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import time import time
import argparse import argparse
import ast
import numpy as np import numpy as np
from mindspore import context from mindspore import context
...@@ -46,12 +47,14 @@ parser = argparse.ArgumentParser(description='Image classification') ...@@ -46,12 +47,14 @@ parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--device_target', type=str, default="GPU", help='run device_target') parser.add_argument('--device_target', type=str, default="GPU", help='run device_target')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
args_opt = parser.parse_args() args_opt = parser.parse_args()
if args_opt.device_target == "GPU": if args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", device_target="GPU",
save_graphs=False) save_graphs=False)
if args_opt.run_distribute:
init() init()
context.set_auto_parallel_context(device_num=get_group_size(), context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
...@@ -168,7 +171,8 @@ if __name__ == '__main__': ...@@ -168,7 +171,8 @@ if __name__ == '__main__':
config=config_gpu, config=config_gpu,
device_target=args_opt.device_target, device_target=args_opt.device_target,
repeat_num=1, repeat_num=1,
batch_size=config_gpu.batch_size) batch_size=config_gpu.batch_size,
run_distribute=args_opt.run_distribute)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# resume # resume
if args_opt.pre_trained: if args_opt.pre_trained:
...@@ -191,7 +195,10 @@ if __name__ == '__main__': ...@@ -191,7 +195,10 @@ if __name__ == '__main__':
loss_scale_manager=loss_scale) loss_scale_manager=loss_scale)
cb = [Monitor(lr_init=lr.asnumpy())] cb = [Monitor(lr_init=lr.asnumpy())]
if args_opt.run_distribute:
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
else:
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
if config_gpu.save_checkpoint: if config_gpu.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config_gpu.keep_checkpoint_max) keep_checkpoint_max=config_gpu.keep_checkpoint_max)
......
...@@ -399,6 +399,6 @@ class PredictWithSigmoid(nn.Cell): ...@@ -399,6 +399,6 @@ class PredictWithSigmoid(nn.Cell):
self.sigmoid = P.Sigmoid() self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels): def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts) logits, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits) pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels return logits, pred_probs, labels
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册