Unverified · Commit e8d6f633 · authored by kangguangli, committed by GitHub

replace with_data_parallel with fleet (#1626)

Parent 8bf2df5b
@@ -13,7 +13,7 @@
 Default configuration:
 ```yaml
-batch_size: 256
+batch_size: 64
 init_lr: 0.1
 lr_strategy: piecewise_decay
 l2_decay: 3e-5
@@ -21,14 +21,14 @@ momentum_rate: 0.9
 num_epochs: 120
 data: imagenet
 ```
-Training can be launched with the default configuration.
+Training can be launched with the default configuration. Here, batch_size means the batch size on each card.
 ### 2. Launch training
 Once the ImageNet dataset is configured, start training with the following command:
 ```shell
-CUDA_VISIBLE_DEVICES=0,1,2,3 python distill.py
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch distill.py
 ```
 ### 3. Training results
......
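The README change above reflects the new launch flow: `paddle.distributed.launch` spawns one trainer process per visible GPU, and the configured `batch_size` applies to each card, so the effective global batch is `batch_size × number of cards`. The pattern this commit applies is, in outline, the usual fleet collective recipe for static-graph Paddle. The following is a minimal sketch only, not code from this repo; the toy fc model and all variable names are illustrative, assuming the Paddle 2.x API:

```python
# Minimal sketch of static-graph collective training with fleet.
# Run with: python -m paddle.distributed.launch --gpus 0,1 this_script.py
import os
import paddle
from paddle.distributed import fleet

paddle.enable_static()
fleet.init(is_collective=True)  # collective (NCCL) mode, one process per GPU

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
    logits = paddle.static.nn.fc(x, size=2)
    loss = paddle.mean(paddle.nn.functional.cross_entropy(logits, y))

    opt = paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
    # Wrapping the optimizer is what makes the program data-parallel:
    # minimize() inserts the gradient all-reduce ops into main_prog.
    opt = fleet.distributed_optimizer(opt)
    opt.minimize(loss)

# Each launched process binds to the GPU that paddle.distributed.launch
# assigned to it (FLAGS_selected_gpus); '0' is a single-process fallback.
place = paddle.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# main_prog can now be executed directly; no CompiledProgram(...)
# .with_data_parallel(...) step is required.
```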
@@ -15,6 +15,9 @@ import models
 from utility import add_arguments, print_arguments, _download, _decompress
 from paddleslim.dist import merge, l2, soft_label
+from paddle.distributed import fleet
+from paddle.distributed.fleet import DistributedStrategy
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
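The two new imports are the entry points of the migration: `fleet` supplies initialization and the distributed optimizer wrapper, while `DistributedStrategy` carries the build options that were previously handed to `with_data_parallel`. A hedged sketch of how the strategy object is typically configured (attribute names follow the Paddle 2.x fleet API; the extra toggles shown are examples, not settings this diff enables):

```python
import paddle
from paddle.distributed.fleet import DistributedStrategy

# BuildStrategy keeps its old role (graph-level build options)...
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_all_reduce_ops = True  # example knob

# ...but it now travels inside a DistributedStrategy instead of being
# passed to with_data_parallel().
dist_strategy = DistributedStrategy()
dist_strategy.build_strategy = build_strategy
# Other distributed features hang off the same object, e.g.:
# dist_strategy.amp = True        # automatic mixed precision
# dist_strategy.recompute = True  # activation recomputation
```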
@@ -76,6 +79,9 @@ def create_optimizer(args):
 def compress(args):
+    fleet.init(is_collective=True)
     if args.data == "cifar10":
         train_dataset = paddle.vision.datasets.Cifar10(mode='train')
         val_dataset = paddle.vision.datasets.Cifar10(mode='test')
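`fleet.init(is_collective=True)` selects collective (all-reduce) mode rather than parameter-server mode; each process started by `paddle.distributed.launch` becomes one trainer. A small sketch of what a process can query after initialization (assuming the Paddle 2.x fleet API):

```python
import paddle
from paddle.distributed import fleet

paddle.enable_static()
fleet.init(is_collective=True)

# Under paddle.distributed.launch, each process is one worker:
print("worker %d of %d" % (fleet.worker_index(), fleet.worker_num()))
```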
@@ -103,38 +109,38 @@ def compress(args):
     else:
         devices_num = int(os.environ.get('CPU_NUM', 1))
     with paddle.static.program_guard(student_program, s_startup):
-        with paddle.utils.unique_name.guard():
-            image = paddle.static.data(
-                name='image', shape=[None] + image_shape, dtype='float32')
-            label = paddle.static.data(
-                name='label', shape=[None, 1], dtype='int64')
-            train_loader = paddle.io.DataLoader(
-                train_dataset,
-                places=places,
-                feed_list=[image, label],
-                drop_last=True,
-                batch_size=int(args.batch_size / devices_num),
-                return_list=False,
-                shuffle=True,
-                use_shared_memory=True,
-                num_workers=4)
-            valid_loader = paddle.io.DataLoader(
-                val_dataset,
-                places=place,
-                feed_list=[image, label],
-                drop_last=False,
-                return_list=False,
-                use_shared_memory=True,
-                batch_size=args.batch_size,
-                shuffle=False)
-            # model definition
-            model = models.__dict__[args.model]()
-            out = model.net(input=image, class_dim=class_dim)
-            cost = paddle.nn.functional.loss.cross_entropy(
-                input=out, label=label)
+        image = paddle.static.data(
+            name='image', shape=[None] + image_shape, dtype='float32')
+        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+        sampler = paddle.io.DistributedBatchSampler(
+            train_dataset,
+            shuffle=False,
+            drop_last=True,
+            batch_size=args.batch_size)
+        train_loader = paddle.io.DataLoader(
+            train_dataset,
+            places=places,
+            feed_list=[image, label],
+            batch_sampler=sampler,
+            return_list=False,
+            use_shared_memory=False,
+            num_workers=4)
+        valid_loader = paddle.io.DataLoader(
+            val_dataset,
+            places=place,
+            feed_list=[image, label],
+            drop_last=False,
+            return_list=False,
+            use_shared_memory=False,
+            batch_size=args.batch_size,
+            shuffle=False)
+        # model definition
+        model = models.__dict__[args.model]()
+        out = model.net(input=image, class_dim=class_dim)
+        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
         avg_cost = paddle.mean(x=cost)
         acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
         acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
     val_program = student_program.clone(for_test=True)
     exe = paddle.static.Executor(place)
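The old loader divided the global batch across devices by hand (`int(args.batch_size / devices_num)`); the new code instead lets `DistributedBatchSampler` give each rank a disjoint shard of the dataset, which is why `batch_size` in the config is now per card. A standalone sketch of the sharding behaviour (toy dataset; on a single process it degenerates to an ordinary batch sampler):

```python
import numpy as np
import paddle
from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    """Hypothetical 8-sample dataset, just to show the index sharding."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return np.float32([idx]), np.int64([0])

# With N launched processes, rank r only ever sees ~1/N of the indices,
# so batch_size here is the per-card batch size.
sampler = DistributedBatchSampler(
    ToyDataset(), batch_size=2, shuffle=False, drop_last=True)
for batch_indices in sampler:
    print(batch_indices)  # e.g. [0, 1], [2, 3], ... on a single process
```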
@@ -172,18 +178,19 @@ def compress(args):
     data_name_map = {'image': 'image'}
     merge(teacher_program, student_program, data_name_map, place)
+    build_strategy = paddle.static.BuildStrategy()
+    dist_strategy = DistributedStrategy()
+    dist_strategy.build_strategy = build_strategy
     with paddle.static.program_guard(student_program, s_startup):
         distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                   student_program)
         loss = avg_cost + distill_loss
         lr, opt = create_optimizer(args)
+        opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
         opt.minimize(loss)
     exe.run(s_startup)
-    build_strategy = paddle.static.BuildStrategy()
-    build_strategy.fuse_all_reduce_ops = False
-    parallel_main = paddle.static.CompiledProgram(
-        student_program).with_data_parallel(
-            loss_name=loss.name, build_strategy=build_strategy)
+    parallel_main = student_program

     for epoch_id in range(args.num_epochs):
         for step_id, data in enumerate(train_loader):
......
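This last hunk is the heart of the change: wrapping the optimizer with `fleet.distributed_optimizer(opt, strategy=dist_strategy)` means `minimize(loss)` already rewrites `student_program` with the gradient all-reduce ops, so the deprecated `CompiledProgram(...).with_data_parallel(...)` wrapper becomes unnecessary and the plain program is run directly. A sketch of the resulting training step; this continues the diff's own variables (`args`, `train_loader`, `exe`, `parallel_main`, `loss`, `_logger`) rather than being a standalone script:

```python
# Sketch: running the fleet-rewritten program directly.
for epoch_id in range(args.num_epochs):
    for step_id, data in enumerate(train_loader):
        # parallel_main is just student_program; fleet inserted the
        # all-reduce ops during opt.minimize(loss).
        loss_np = exe.run(parallel_main,
                          feed=data,
                          fetch_list=[loss.name])
        if step_id % 100 == 0:
            _logger.info("epoch {} step {} loss {}".format(
                epoch_id, step_id, loss_np[0]))
```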