提交 e69ef85d 编写于 作者: G gengdongjie

modify resnet50 for cloud train performance

上级 c232a04e
...@@ -27,6 +27,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -27,6 +27,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import Callback, LossMonitor from mindspore.train.callback import Callback, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
from dataset import create_dataset, device_id, device_num from dataset import create_dataset, device_id, device_num
...@@ -121,6 +122,7 @@ def resnet50_train(args_opt): ...@@ -121,6 +122,7 @@ def resnet50_train(args_opt):
context.set_auto_parallel_context(device_num=device_num, context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True) mirror_mean=True)
init()
local_data_path = os.path.join(local_data_path, str(device_id)) local_data_path = os.path.join(local_data_path, str(device_id))
# data download # data download
...@@ -138,12 +140,12 @@ def resnet50_train(args_opt): ...@@ -138,12 +140,12 @@ def resnet50_train(args_opt):
# create model # create model
net = resnet50(class_num = class_num) net = resnet50(class_num = class_num)
loss = SoftmaxCrossEntropyWithLogits(sparse=True) loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
lr = Tensor(get_lr(global_step=0, total_epochs=epoch_size, steps_per_epoch=train_step_size)) lr = Tensor(get_lr(global_step=0, total_epochs=epoch_size, steps_per_epoch=train_step_size))
opt = Momentum(net.trainable_params(), lr, momentum=0.9, weight_decay=1e-4, loss_scale=loss_scale_num) opt = Momentum(net.trainable_params(), lr, momentum=0.9, weight_decay=1e-4, loss_scale=loss_scale_num)
loss_scale = FixedLossScaleManager(loss_scale_num, False) loss_scale = FixedLossScaleManager(loss_scale_num, False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
# define performance callback to show ips and loss callback to show loss for every epoch # define performance callback to show ips and loss callback to show loss for every epoch
performance_cb = PerformanceCallback(batch_size) performance_cb = PerformanceCallback(batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册