Commit e08e4088 authored by chujinjin

fix model zoo error for pynative

Parent 8a71db07
@@ -14,7 +14,7 @@
 # ============================================================================
 import argparse
+import ast

 def launch_parse_args():
@@ -43,6 +43,7 @@ def train_parse_args():
                               help='run platform, only support CPU, GPU and Ascend')
     train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
                               for fine tune or incremental learning')
+    train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
     train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \
                               help="\"fine_tune\" or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \
                               train from initialization model")
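This commit threads a `run_distribute` flag through the MobileNetV2/V3 scripts so they can run on a single device (for example when debugging in PyNative mode) instead of always assuming a distributed launch. For readers unfamiliar with the `type=ast.literal_eval` idiom used for the new flag: it parses the literal strings "True"/"False" into real booleans, whereas `type=bool` would treat any non-empty string, including "False", as truthy. A minimal, self-contained sketch:

    import argparse
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')

    # 'False' becomes the Python bool False; a plain type=bool would call
    # bool('False'), which is True for any non-empty string.
    args = parser.parse_args(['--run_distribute', 'False'])
    print(args.run_distribute)   # False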
@@ -59,6 +59,7 @@ def set_config(args):
         "save_checkpoint_path": "./checkpoint",
         "platform": args.platform,
         "ccl": "nccl",
+        "run_distribute": args.run_distribute
     })
     config_ascend = ed({
         "num_classes": 1000,
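The flag is stored in the GPU config dict so downstream helpers can branch on it. A minimal sketch of the pattern, assuming `ed` is bound to `easydict.EasyDict` as is conventional in these model zoo config files:

    from easydict import EasyDict as ed  # assumed binding for `ed`

    config_gpu = ed({
        "ccl": "nccl",
        "run_distribute": True,  # now carried in the config instead of being assumed
    })

    # EasyDict exposes keys as attributes, so later code can branch on the flag directly.
    print(config_gpu.run_distribute)  # True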
@@ -51,9 +51,12 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1):
                                    num_shards=rank_size, shard_id=rank_id)
     elif config.platform == "GPU":
         if do_train:
-            from mindspore.communication.management import get_rank, get_group_size
-            ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
-                                       num_shards=get_group_size(), shard_id=get_rank())
+            if config.run_distribute:
+                from mindspore.communication.management import get_rank, get_group_size
+                ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
+                                           num_shards=get_group_size(), shard_id=get_rank())
+            else:
+                ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
         else:
             ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
     elif config.platform == "CPU":
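Note that the hunk above keeps the `get_rank`/`get_group_size` import inside the distributed branch: those helpers only work after communication `init()` has run, which is exactly what a single-device PyNative run skips. A standalone sketch of the same branching (the helper name is illustrative, not from the source):

    import mindspore.dataset as de

    def build_imagefolder(dataset_path, do_train, run_distribute):
        """Shard across devices only for distributed training; otherwise just shuffle."""
        if do_train and run_distribute:
            # Imported lazily: these helpers require communication init() to have run.
            from mindspore.communication.management import get_rank, get_group_size
            return de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                         num_shards=get_group_size(), shard_id=get_rank())
        return de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)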
@@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size
 from src.models import Monitor

 def switch_precision(net, data_type, config):
     if config.platform == "Ascend":
         net.to_float(data_type)
@@ -29,17 +30,18 @@ def switch_precision(net, data_type, config):
         if isinstance(cell, nn.Dense):
             cell.to_float(mstype.float32)

 def context_device_init(config):
     if config.platform == "CPU":
         context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)

     elif config.platform == "GPU":
         context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
-        init("nccl")
-        context.set_auto_parallel_context(device_num=get_group_size(),
-                                          parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True)
+        if config.run_distribute:
+            init("nccl")
+            context.set_auto_parallel_context(device_num=get_group_size(),
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
     elif config.platform == "Ascend":
         context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id,
@@ -53,6 +55,7 @@ def context_device_init(config):
     else:
         raise ValueError("Only support CPU, GPU and Ascend.")

 def set_context(config):
     if config.platform == "CPU":
         context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
@@ -64,6 +67,7 @@ def set_context(config):
         context.set_context(mode=context.GRAPH_MODE,
                             device_target=config.platform, save_graphs=False)

 def config_ckpoint(config, lr, step_size):
     cb = None
     if config.platform in ("CPU", "GPU") or config.rank_id == 0:
@@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size):
         ckpt_save_dir = config.save_checkpoint_path

         if config.platform == "GPU":
-            ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
+            if config.run_distribute:
+                ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
+            else:
+                ckpt_save_dir += "ckpt_" + "/"

         ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
         cb += [ckpt_cb]
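The `context_device_init` change above is the core of the fix: `init()` and `get_group_size()` presuppose a distributed launcher, so guarding them lets a plain single-process run proceed. A minimal sketch of the same guard, with import paths as in MindSpore 1.x (verify against your version):

    from mindspore import context
    from mindspore.context import ParallelMode
    from mindspore.communication.management import init, get_group_size

    def init_gpu_context(run_distribute):
        # A plain graph-mode context is all a single device needs.
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
        if run_distribute:
            # init() and get_group_size() assume a launcher such as mpirun;
            # calling them in a bare single-process run is what used to fail.
            init("nccl")
            context.set_auto_parallel_context(device_num=get_group_size(),
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)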
@@ -21,7 +21,7 @@ import mindspore.dataset.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2

-def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
+def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32, run_distribute=False):
     """
     create a train or eval dataset
@@ -36,9 +36,12 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
     """
     if device_target == "GPU":
         if do_train:
-            from mindspore.communication.management import get_rank, get_group_size
-            ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
-                                       num_shards=get_group_size(), shard_id=get_rank())
+            if run_distribute:
+                from mindspore.communication.management import get_rank, get_group_size
+                ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
+                                           num_shards=get_group_size(), shard_id=get_rank())
+            else:
+                ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
         else:
             ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
     else:
@@ -56,7 +59,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
     resize_op = C.Resize(256)
     center_crop = C.CenterCrop(resize_width)
     rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
-    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
+    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
     change_swap_op = C.HWC2CHW()
     if do_train:
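The `Normalize` reflow above is cosmetic, but the constants are worth a note: they are the standard ImageNet channel statistics scaled by 255, because the op runs on decoded 0-255 pixel values before `HWC2CHW` reorders the layout. For reference:

    import mindspore.dataset.vision.c_transforms as C

    # Standard ImageNet channel statistics, scaled to the 0-255 range of the
    # decoded images this pipeline feeds in, applied before HWC2CHW.
    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    change_swap_op = C.HWC2CHW()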
@@ -16,6 +16,7 @@
 import time
 import argparse
+import ast
 import numpy as np

 from mindspore import context
@@ -46,16 +47,18 @@ parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
 parser.add_argument('--device_target', type=str, default="GPU", help='run device_target')
+parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
 args_opt = parser.parse_args()

 if args_opt.device_target == "GPU":
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="GPU",
                         save_graphs=False)
-    init()
-    context.set_auto_parallel_context(device_num=get_group_size(),
-                                      parallel_mode=ParallelMode.DATA_PARALLEL,
-                                      gradients_mean=True)
+    if args_opt.run_distribute:
+        init()
+        context.set_auto_parallel_context(device_num=get_group_size(),
+                                          parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          gradients_mean=True)
 else:
     raise ValueError("Unsupported device_target.")
@@ -168,7 +171,8 @@ if __name__ == '__main__':
                                  config=config_gpu,
                                  device_target=args_opt.device_target,
                                  repeat_num=1,
-                                 batch_size=config_gpu.batch_size)
+                                 batch_size=config_gpu.batch_size,
+                                 run_distribute=args_opt.run_distribute)
         step_size = dataset.get_dataset_size()

         # resume
         if args_opt.pre_trained:
@@ -191,7 +195,10 @@ if __name__ == '__main__':
                       loss_scale_manager=loss_scale)

         cb = [Monitor(lr_init=lr.asnumpy())]
-        ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+        if args_opt.run_distribute:
+            ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+        else:
+            ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
         if config_gpu.save_checkpoint:
             config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                          keep_checkpoint_max=config_gpu.keep_checkpoint_max)
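The checkpoint-directory change in both train scripts serves the same purpose: one directory per rank keeps distributed workers from overwriting each other's files, while a single fixed directory suffices for one process. A self-contained paraphrase (the scripts concatenate strings and obtain the rank via `get_rank()`; this sketch uses `os.path.join` and a plain parameter instead):

    import os

    def ckpt_save_dir(base, run_distribute, rank=0):
        # One directory per rank avoids checkpoint collisions between workers;
        # rank is a plain parameter here, not a real communication query.
        return os.path.join(base, "ckpt_%d" % rank if run_distribute else "ckpt_")

    print(ckpt_save_dir("./checkpoint", True, rank=3))  # ./checkpoint/ckpt_3
    print(ckpt_save_dir("./checkpoint", False))         # ./checkpoint/ckpt_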
@@ -399,6 +399,6 @@ class PredictWithSigmoid(nn.Cell):
         self.sigmoid = P.Sigmoid()

     def construct(self, batch_ids, batch_wts, labels):
-        logits, _, _, = self.network(batch_ids, batch_wts)
+        logits, _, = self.network(batch_ids, batch_wts)
         pred_probs = self.sigmoid(logits)
         return logits, pred_probs, labels
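This last hunk corrects a tuple-unpacking mismatch: the wrapped network evidently returns two values, so the old three-target unpack would raise at the first PyNative step, where `construct` runs eagerly instead of being compiled away. A toy illustration (the stub is made up, not the real network):

    def network_stub(batch_ids, batch_wts):
        # Stands in for the wrapped network, which returns two outputs.
        return "logits", "embedding"

    logits, _, = network_stub(None, None)      # matches a 2-tuple (trailing comma is legal)
    # logits, _, _, = network_stub(None, None) # ValueError: not enough values to unpack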