未验证 提交 e8d6f633 编写于 作者: K kangguangli 提交者: GitHub

replace with_data_parallel with fleet (#1626)

上级 8bf2df5b
......@@ -13,7 +13,7 @@
默认配置:
```yaml
batch_size: 256
batch_size: 64
init_lr: 0.1
lr_strategy: piecewise_decay
l2_decay: 3e-5
......@@ -21,14 +21,14 @@ momentum_rate: 0.9
num_epochs: 120
data: imagenet
```
训练使用默认配置启动即可
训练使用默认配置启动即可。注意:这里的 batch_size 指每张卡上的 batch size,全局 batch size 等于 batch_size 乘以卡数(例如 4 卡时为 64 × 4 = 256)。
### 2. 启动训练
在配置好 ImageNet 数据集后,用以下命令启动训练即可:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python distill.py
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch distill.py
```
### 3. 训练结果
......
......@@ -15,6 +15,9 @@ import models
from utility import add_arguments, print_arguments, _download, _decompress
from paddleslim.dist import merge, l2, soft_label
from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
......@@ -76,6 +79,9 @@ def create_optimizer(args):
def compress(args):
fleet.init(is_collective=True)
if args.data == "cifar10":
train_dataset = paddle.vision.datasets.Cifar10(mode='train')
val_dataset = paddle.vision.datasets.Cifar10(mode='test')
......@@ -103,38 +109,38 @@ def compress(args):
else:
devices_num = int(os.environ.get('CPU_NUM', 1))
with paddle.static.program_guard(student_program, s_startup):
with paddle.utils.unique_name.guard():
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = paddle.static.data(
name='label', shape=[None, 1], dtype='int64')
train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
feed_list=[image, label],
drop_last=True,
batch_size=int(args.batch_size / devices_num),
return_list=False,
shuffle=True,
use_shared_memory=True,
num_workers=4)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
use_shared_memory=True,
batch_size=args.batch_size,
shuffle=False)
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = paddle.nn.functional.loss.cross_entropy(
input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
sampler = paddle.io.DistributedBatchSampler(
train_dataset,
shuffle=False,
drop_last=True,
batch_size=args.batch_size)
train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
feed_list=[image, label],
batch_sampler=sampler,
return_list=False,
use_shared_memory=False,
num_workers=4)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
use_shared_memory=False,
batch_size=args.batch_size,
shuffle=False)
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
val_program = student_program.clone(for_test=True)
exe = paddle.static.Executor(place)
......@@ -172,18 +178,19 @@ def compress(args):
data_name_map = {'image': 'image'}
merge(teacher_program, student_program, data_name_map, place)
build_strategy = paddle.static.BuildStrategy()
dist_strategy = DistributedStrategy()
dist_strategy.build_strategy = build_strategy
with paddle.static.program_guard(student_program, s_startup):
distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
student_program)
loss = avg_cost + distill_loss
lr, opt = create_optimizer(args)
opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
opt.minimize(loss)
exe.run(s_startup)
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
parallel_main = paddle.static.CompiledProgram(
student_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
parallel_main = student_program
for epoch_id in range(args.num_epochs):
for step_id, data in enumerate(train_loader):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册