未验证 提交 e281b530 编写于 作者: D Divano 提交者: GitHub

Add muti cards support for se_resnext (#2787)

* Update mnist_dygraph.py

fix bug

* add muti card support for se_resnext

* add some description to readme.md
上级 e32bb3f5
...@@ -27,6 +27,22 @@ env CUDA_VISIBLE_DEVICES=0 python train.py ...@@ -27,6 +27,22 @@ env CUDA_VISIBLE_DEVICES=0 python train.py
这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。 这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。
亦可以使用多卡进行训练:
```
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train.py --use_data_parallel 1
```
这里`--selected_gpus=0,1,2,3`表示使用0,1,2,3号设备卡,共计4卡进行多卡训练,请根据自身情况修改这个参数。
此时,程序会将每个进程的输出log导入到`./mylog`路径下:
```
.
├── mylog
│ ├── workerlog.0
│ ├── workerlog.1
│ ├── workerlog.2
│ └── workerlog.3
├── README.md
└── train.py
```
## 输出 ## 输出
......
...@@ -26,10 +26,16 @@ from paddle.fluid.dygraph.base import to_variable ...@@ -26,10 +26,16 @@ from paddle.fluid.dygraph.base import to_variable
import sys import sys
import math import math
import argparse import argparse
import ast
parser = argparse.ArgumentParser("Training for Se-ResNeXt.") parser = argparse.ArgumentParser("Training for Se-ResNeXt.")
parser.add_argument("-e", "--epoch", default=200, type=int, help="set epoch") parser.add_argument("-e", "--epoch", default=200, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce") parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to shuffle instances in each pass.")
args = parser.parse_args() args = parser.parse_args()
batch_size = 64 batch_size = 64
train_parameters = { train_parameters = {
...@@ -361,22 +367,30 @@ def train(): ...@@ -361,22 +367,30 @@ def train():
epoch_num = args.epoch epoch_num = args.epoch
batch_size = train_parameters["batch_size"] batch_size = train_parameters["batch_size"]
with fluid.dygraph.guard(): trainer_count = fluid.dygraph.parallel.Env().nranks
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
if args.ce: if args.ce:
print("ce mode") print("ce mode")
seed = 90 seed = 90
np.random.seed(seed) np.random.seed(seed)
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
se_resnext = SeResNeXt("se_resnext") se_resnext = SeResNeXt("se_resnext")
optimizer = optimizer_setting(train_parameters) optimizer = optimizer_setting(train_parameters)
if args.use_data_parallel:
se_resnext = fluid.dygraph.parallel.DataParallel(se_resnext, strategy)
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False), paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size, batch_size=batch_size,
drop_last=True drop_last=True
) )
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False), batch_size=32) paddle.dataset.flowers.test(use_xmap=False), batch_size=32)
...@@ -407,6 +421,11 @@ def train(): ...@@ -407,6 +421,11 @@ def train():
acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5) acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)
dy_out = avg_loss.numpy() dy_out = avg_loss.numpy()
if args.use_data_parallel:
avg_loss = se_resnext.scale_loss(avg_loss)
avg_loss.backward()
se_resnext.apply_collective_grads()
else:
avg_loss.backward() avg_loss.backward()
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
...@@ -418,7 +437,6 @@ def train(): ...@@ -418,7 +437,6 @@ def train():
total_acc5 += acc_top5.numpy() total_acc5 += acc_top5.numpy()
total_sample += 1 total_sample += 1
if batch_id % 10 == 0: if batch_id % 10 == 0:
print(fluid.dygraph.base._print_debug_msg())
print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f lr %0.5f" % \ print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f lr %0.5f" % \
( epoch_id, batch_id, total_loss / total_sample, \ ( epoch_id, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample, lr)) total_acc1 / total_sample, total_acc5 / total_sample, lr))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册