Commit 587561f7 authored by lichenever

fix_distributed_training_doc bug

Parent 3c7779ce
@@ -256,11 +256,13 @@ device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id) # set device_id

-def test_train_cifar(num_classes=10, epoch_size=10):
+def test_train_cifar(epoch_size=10):
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
     loss_cb = LossMonitor()
-    dataset = create_dataset(epoch_size)
-    net = resnet50(32, num_classes)
+    dataset = create_dataset(data_path, epoch_size)
+    batch_size = 32
+    num_classes = 10
+    net = resnet50(batch_size, num_classes)
     loss = SoftmaxCrossEntropyExpand(sparse=True)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
     model = Model(net, loss_fn=loss, optimizer=opt)
@@ -342,17 +344,14 @@ The running time is about 5 minutes, which is mainly occupied by operator compil
 Log files are saved in the device directory. The env.log file records environment variable information. The train.log file records the loss function information. The following is an example:

 ```
-resnet50_distributed_training.py::test_train_feed ===============ds_num 195
-global_step: 194, loss: 1.997
-global_step: 389, loss: 1.655
-global_step: 584, loss: 1.723
-global_step: 779, loss: 1.807
-global_step: 974, loss: 1.417
-global_step: 1169, loss: 1.195
-global_step: 1364, loss: 1.238
-global_step: 1559, loss: 1.456
-global_step: 1754, loss: 0.987
-global_step: 1949, loss: 1.035
-end training
-PASSED
+epoch: 1 step: 156, loss is 2.0084016
+epoch: 2 step: 156, loss is 1.6407638
+epoch: 3 step: 156, loss is 1.6164391
+epoch: 4 step: 156, loss is 1.6838071
+epoch: 5 step: 156, loss is 1.6320667
+epoch: 6 step: 156, loss is 1.3098773
+epoch: 7 step: 156, loss is 1.3515002
+epoch: 8 step: 156, loss is 1.2943741
+epoch: 9 step: 156, loss is 1.2316195
+epoch: 10 step: 156, loss is 1.1533381
 ```
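The same fix is applied again below to the Chinese version of the tutorial. For reference, after this change the training entry point reads roughly as in the sketch below. This is a minimal sketch only: the import paths follow the MindSpore version this tutorial targets and may differ in later releases, while `create_dataset`, `resnet50`, `SoftmaxCrossEntropyExpand`, `data_path`, and the final `model.train(...)` call are assumed to be defined as in the surrounding tutorial rather than in this diff.

```python
import os
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import LossMonitor
from mindspore.train.model import Model, ParallelMode

# Assumed to come from earlier sections of the tutorial (not part of this diff):
# create_dataset, resnet50, SoftmaxCrossEntropyExpand, data_path.
# Distributed initialization (e.g. init() for HCCL) is also handled earlier in the tutorial.

device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)  # bind this process to its Ascend device

def test_train_cifar(epoch_size=10):
    # enable auto parallel mode and average mirrored gradients across devices
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    loss_cb = LossMonitor()
    dataset = create_dataset(data_path, epoch_size)
    batch_size = 32
    num_classes = 10
    net = resnet50(batch_size, num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    # the train call is taken from the surrounding tutorial, not from this diff
    model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True)
```

With the `LossMonitor` callback attached, each device prints per-epoch loss lines in the `epoch: N step: S, loss is ...` format shown in the updated sample output above.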
@@ -254,11 +254,13 @@ device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id) # set device_id

-def test_train_cifar(num_classes=10, epoch_size=10):
+def test_train_cifar(epoch_size=10):
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
     loss_cb = LossMonitor()
-    dataset = create_dataset(epoch_size)
-    net = resnet50(32, num_classes)
+    dataset = create_dataset(data_path, epoch_size)
+    batch_size = 32
+    num_classes = 10
+    net = resnet50(batch_size, num_classes)
     loss = SoftmaxCrossEntropyExpand(sparse=True)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
     model = Model(net, loss_fn=loss, optimizer=opt)
@@ -340,17 +342,14 @@ cd ../
 Log files are saved under the device directory. env.log records the environment variable information, and the loss results are saved in train.log. The following is an example:

 ```
-resnet50_distributed_training.py::test_train_feed ===============ds_num 195
-global_step: 194, loss: 1.997
-global_step: 389, loss: 1.655
-global_step: 584, loss: 1.723
-global_step: 779, loss: 1.807
-global_step: 974, loss: 1.417
-global_step: 1169, loss: 1.195
-global_step: 1364, loss: 1.238
-global_step: 1559, loss: 1.456
-global_step: 1754, loss: 0.987
-global_step: 1949, loss: 1.035
-end training
-PASSED
+epoch: 1 step: 156, loss is 2.0084016
+epoch: 2 step: 156, loss is 1.6407638
+epoch: 3 step: 156, loss is 1.6164391
+epoch: 4 step: 156, loss is 1.6838071
+epoch: 5 step: 156, loss is 1.6320667
+epoch: 6 step: 156, loss is 1.3098773
+epoch: 7 step: 156, loss is 1.3515002
+epoch: 8 step: 156, loss is 1.2943741
+epoch: 9 step: 156, loss is 1.2316195
+epoch: 10 step: 156, loss is 1.1533381
 ```