未验证 提交 4ec35251 编写于 作者: M minghaoBD 提交者: GitHub

[Unstructured Pruner]add support for cpu training (#898)

上级 36a9f6f0
......@@ -22,6 +22,7 @@ _logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use GPU for training or not. Default: True")
add_arg('batch_size', int, 64, "Minibatch size. Default: 64")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation. Default: 64")
add_arg('lr', float, 0.05, "The learning rate used to fine-tune pruned model. Default: 0.05")
......@@ -105,7 +106,10 @@ def create_unstructured_pruner(model, args, configs=None):
def compress(args):
place = paddle.set_device('gpu')
if args.use_gpu:
place = paddle.set_device('gpu')
else:
place = paddle.set_device('cpu')
trainer_num = paddle.distributed.get_world_size()
use_data_parallel = trainer_num != 1
......
#!/bin/bash
CUDA_VISIBLE_DEVICES='' python \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.55 \
--lr 0.05 \
--use_gpu False
......@@ -21,6 +21,7 @@ _logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use gpu for traning or not. Defauly: True")
add_arg('batch_size', int, 64, "Minibatch size. Default: 64")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation. Default: 64")
add_arg('model', str, "MobileNet", "The target model.")
......@@ -137,7 +138,10 @@ def compress(args):
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
places = paddle.static.cuda_places()
if args.use_gpu:
places = paddle.static.cuda_places()
else:
places = paddle.static.cpu_places()
place = places[0]
exe = paddle.static.Executor(place)
......
CUDA_VISIBLE_DEVICES='' python train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.55 \
--lr 0.05 \
--model MobileNet \
--pretrained_model "MobileNetV1_pretrained" \
--use_gpu False \
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册