未验证 提交 123fcc57 编写于 作者: M minghaoBD 提交者: GitHub

Fix api docs easeof use (#740) (#743)

上级 aef70340
......@@ -33,7 +33,7 @@ def compress(args):
test_reader = None
if args.data == "imagenet":
import imagenet_reader as reader
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
elif args.data == "cifar10":
normalize = T.Normalize(
......@@ -47,13 +47,12 @@ def compress(args):
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
batch_size_per_card = int(args.batch_size / len(places))
valid_loader = paddle.io.DataLoader(
val_dataset,
places=places,
drop_last=False,
return_list=True,
batch_size=batch_size_per_card,
batch_size=args.batch_size,
shuffle=False,
use_shared_memory=True)
......@@ -70,15 +69,12 @@ def compress(args):
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
end_time = time.time()
logits = model(x_data)
loss = F.cross_entropy(logits, y_data)
acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
acc_top1_ns.append(acc_top1.numpy())
acc_top5_ns.append(acc_top5.numpy())
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
......
......@@ -23,6 +23,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
......@@ -121,7 +122,7 @@ def compress(args):
places=place,
drop_last=False,
return_list=True,
batch_size=64,
batch_size=args.batch_size_for_validation,
shuffle=False,
use_shared_memory=True)
step_per_epoch = int(
......@@ -146,15 +147,12 @@ def compress(args):
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
end_time = time.time()
logits = model(x_data)
loss = F.cross_entropy(logits, y_data)
acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
acc_top1_ns.append(acc_top1.numpy())
acc_top5_ns.append(acc_top5.numpy())
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
......
......@@ -20,7 +20,7 @@ _logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64*12, "Minibatch size.")
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pruned_model', str, "models", "Whether to use pretrained model.")
......@@ -44,8 +44,8 @@ def compress(args):
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(data_dir='/data', mode='train')
val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
train_dataset = reader.ImageNetDataset(mode='train')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
......@@ -71,7 +71,6 @@ def compress(args):
use_shared_memory=True,
batch_size=batch_size_per_card,
shuffle=False)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
# model definition
model = models.__dict__[args.model]()
......@@ -103,12 +102,7 @@ def compress(args):
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed={
"image": data[0].get('image'),
"label": data[0].get('label'),
},
fetch_list=[acc_top1.name, acc_top5.name])
program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
......
......@@ -20,6 +20,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.")
......@@ -123,7 +124,7 @@ def compress(args):
drop_last=False,
return_list=False,
use_shared_memory=True,
batch_size=batch_size_per_card,
batch_size=args.batch_size_for_validation,
shuffle=False)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
......
......@@ -7,7 +7,7 @@ UnstructuredPruner
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.3, skip_params_func=None)
`源代码 <https://github.com/minghaoBD/PaddleSlim/blob/update_unstructured_pruning_docs/paddleslim/dygraph/prune/unstructured_pruner.py>`_
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
对于神经网络中的参数进行非结构化稀疏。非结构化稀疏是指,根据某些衡量指标,将不重要的参数置0。其不按照固定结构剪裁(例如一个通道等),这是和结构化剪枝的主要区别。
......@@ -23,11 +23,16 @@ UnstructuredPruner
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
import paddle
from paddleslim import UnstructuredPruner
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
..
......@@ -38,13 +43,19 @@ UnstructuredPruner
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
print(pruner.threshold)
pruner.step()
print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。
..
......@@ -54,13 +65,23 @@ UnstructuredPruner
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='threshold', threshold=0.5)
density = UnstructuredPruner.total_sparse(model)
print(density)
model(paddle.to_tensor(
np.random.uniform(0, 1, [16, 1, 28, 28]), dtype='float32'))
pruner.update_params()
density = UnstructuredPruner.total_sparse(model)
print(density) # 可以看出,这里打印的模型稠密度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.5的权重。
..
......@@ -78,13 +99,17 @@ UnstructuredPruner
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
density = UnstructuredPruner.total_sparse(model)
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
density = UnstructuredPruner.total_sparse(model)
print(density)
..
.. py:method:: paddleslim.UnstructuredPruner.summarize_weights(model, ratio=0.1)
......@@ -102,12 +127,17 @@ UnstructuredPruner
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
threshold = pruner.summarize_weights(model, ratio=0.1)
threshold = pruner.summarize_weights(model, 0.5)
print(threshold)
..
......@@ -24,13 +24,30 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner()
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step()
......@@ -39,33 +56,71 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
pruner.step()
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
print(pruner.threshold)
pruner.step()
print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params()
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。但是,在训练过程中,由于step()函数会调用该方法,故不需要开发者在训练过程中额外调用了。
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'threshold', threshold=0.5, place=place)
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density)
pruner.step()
pruner.update_params()
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density) # 可以看出,这里打印的模型稠密度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.5的权重。
..
......@@ -83,13 +138,31 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density)
..
......@@ -108,14 +181,31 @@ UnstrucuturedPruner
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
import paddle as paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
threshold = pruner.summarize_weights(paddle.static.default_main_program(), 1.0)
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
threshold = pruner.summarize_weights(paddle.static.default_main_program(), ratio=0.5)
print(threshold)
..
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册