Unverified commit e5fc68b2, authored by Ming-Xu Huang, committed by GitHub

Dynamic graph support to Automatic SParsity. (#41177)

* Dynamic graph support to Automatic SParsity.

1. Added dynamic support to ASP module (paddle.fluid.contrib.sparsity).
2. Added ASP-related unit tests for the above changes.
3. Put ASP module under paddle.static for now, pending API confirmation from Paddle.

* Modified documents of functions to have correct examples.

* Update in_dygraph_mode to paddle.in_dynamic_mode()

* Modified documents of functions and added comments

* Minor changes.

* Fix example errors in asp API.

* Code Change for Review

1. Added more examples in documents.
2. Changed test_asp_pruning_static.

* Minor changes

* Update ASP function documents.

* Update ASP function documents.

* Reduce test case size of ASP pruning due to CI time limit.

* Update time limits for some ASP UTs.

* Fix sample code errors.

* Fix sample code errors.

* Fix sample code errors.

* Update time limits for parts of ASP UTs.

* Update UTs to fit with CI.

* Reduce problem size in python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py

* Added paddle.asp

* Fixed type casting error of OpRole.Optimize in new dygraph mode.

* Made set_excluded_layers compatible with 2.2

* Fix example code of calculate_density.

* Update code examples.

* Move paddle.asp to paddle.incubate.asp

* Fixed an example error of calculate_density
Parent 4218957b
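At a glance, this commit lets the ASP recipe (decorate the optimizer, prune, train) run without paddle.enable_static(). Below is a minimal sketch of that dynamic-graph flow; TinyNet and all shapes here are illustrative assumptions, not code from this diff:

import paddle

# Hypothetical toy network; any layer with ASP-supported weights works.
class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.linear = paddle.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)

net = TinyNet()
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=net.parameters())

# In dynamic mode, decorate() creates the ASP mask variables up front,
# and prune_model() applies the n:m sparse pattern to supported weights.
opt = paddle.incubate.asp.decorate(opt)
paddle.incubate.asp.prune_model(net)

x = paddle.rand([8, 32])
loss = paddle.mean(net(x))
loss.backward()
opt.step()        # the wrapped step() re-applies masks after the update
opt.clear_grad()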
@@ -20,12 +20,13 @@ import os
import copy
import numpy as np
import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import global_scope, program_guard, layers
from paddle.fluid.initializer import ConstantInitializer
from paddle.fluid.contrib import sparsity
from paddle.fluid.contrib.sparsity.supported_layer_list import supported_layers_and_prune_func_map
from paddle.fluid.contrib.sparsity.supported_layer_list import _default_pruning
from paddle.fluid import core

OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
@@ -35,45 +36,90 @@ __all__ = [
]


def set_excluded_layers(param_names, main_program=None):
    r"""
    Set parameter names of layers which would not be pruned as sparse weights.

    Args:
        param_names (list of string): A list containing names of parameters.
        main_program (Program, optional): Program with model definition and its parameters.
                                          If None is given, then it would be set as `paddle.static.default_main_program()`.
                                          Default is None.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Need to set excluded layers before calling decorate()
                paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()])

                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    # Set up excluded layers out of the ASP workflow.
                    # Please note, excluded_layers must be set before calling optimizer.minimize().
                    paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program)

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    optimizer = paddle.static.amp.decorate(optimizer)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    if main_program is None:
        main_program = paddle.static.default_main_program()
    ASPHelper.set_excluded_layers(
        param_names=param_names, main_program=main_program)
def reset_excluded_layers(main_program=None):
@@ -83,153 +129,310 @@ def reset_excluded_layers(main_program=None):
    Args:
        main_program (Program, optional): Program with model definition and its parameters.
                                          If None is given, then this function would reset all excluded_layers.
                                          Default is None.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Need to set excluded layers before calling decorate()
                paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()])
                # Reset excluded_layers, so all supported layers would be included in Automatic SParsity's workflow.
                # Please note, reset_excluded_layers also must be called before calling paddle.incubate.asp.decorate().
                paddle.incubate.asp.reset_excluded_layers()

                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    # Set up excluded layers out of the ASP workflow.
                    # Please note, excluded_layers must be set before calling optimizer.minimize().
                    paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program)
                    # Reset excluded_layers, so all supported layers would be included in Automatic SParsity's workflow.
                    # Please note, reset_excluded_layers also must be called before calling optimizer.minimize().
                    paddle.incubate.asp.reset_excluded_layers(main_program)

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    optimizer = paddle.static.amp.decorate(optimizer)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    ASPHelper.reset_excluded_layers(main_program=main_program)
def decorate(optimizer):
    r"""
    Wrap the given optimizer as an OptimizerWithSparsityGuarantee.
    If running in dynamic graph mode, ASP would create mask variables for supported parameters.
    Else, in static graph mode, ASP would create mask variables and insert necessary ops
    when calling minimize().

    Args:
        optimizer (Optimizer): An Optimizer used for training.
    Returns:
        OptimizerWithSparsityGuarantee: A wrapper for ASP to decorate the `minimize` function of the given optimizer.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which
                # will apply necessary masking operations for ASP workflow.
                # In dynamic graph mode, ASP would create related mask variables during decoration.
                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    # In static graph mode, ASP creates related mask variables
                    # during minimize().
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    return ASPHelper.decorate(optimizer)
def prune_model(model, n=2, m=4, mask_algo='mask_1d', with_mask=True):
    r"""
    Prune parameters of supported layers in :attr:`model` via the
    specified mask generation function given by :attr:`mask_algo`. This
    function supports both training and inference, controlled by :attr:`with_mask`.
    If :attr:`with_mask` is True, it would also prune the parameter-related ASP mask Variables,
    else it only prunes parameters.

    *Note*: (Static graph mode) If calling this function with :attr:`with_mask`, it should call `OptimizerWithSparsityGuarantee.minimize`
    and initialization (`exe.run(startup_program)`) beforehand (to successfully obtain mask Variables).
    Typically set `with_mask` as true for training (have called `OptimizerWithSparsityGuarantee.minimize`) and false for
    inference only. To obtain an OptimizerWithSparsityGuarantee, please see `paddle.incubate.asp.decorate()`.

    Args:
        model (Program|nn.Layer): Program with model definition and its parameters, or an object of `paddle.nn.Layer`.
        n (int, optional): n of `n:m` sparse pattern. Default is 2.
        m (int, optional): m of `n:m` sparse pattern. Default is 4.
        mask_algo (string, optional): The function name to generate sparse masks. Default is `mask_1d`.
                                      The valid inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'.
        with_mask (bool, optional): To prune mask Variables related to parameters or not. True means pruning masks as well, False means not. Default is True.
    Returns:
        dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle
                import numpy as np

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                my_layer = MyLayer()
                loss_fn = paddle.nn.MSELoss(reduction='mean')

                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which
                # will apply necessary masking operations for ASP workflow.
                # In dynamic graph mode, ASP would create related mask variables during decoration.
                optimizer = paddle.incubate.asp.decorate(optimizer)

                # Must call paddle.incubate.asp.decorate() first before calling paddle.incubate.asp.prune_model()
                paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best')

                for i in range(10):
                    imgs = paddle.to_tensor(
                        np.random.randn(64, 3, 32, 32),
                        dtype='float32', stop_gradient=False)
                    labels = paddle.to_tensor(
                        np.random.randint(10, size=(64, 1)),
                        dtype='float32', stop_gradient=False)
                    output = my_layer(imgs)
                    loss = loss_fn(output, labels)
                    loss.backward()
                    optimizer.step()
                    optimizer.clear_grad()

        2. Usage of Static Graph

            .. code-block:: python

                import paddle
                import numpy as np

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 32, 32])
                    label = paddle.static.data(name='label', shape=[None, 1])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    # In static graph mode, ASP creates related mask variables
                    # during minimize().
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)

                device = paddle.device.get_device()
                place = paddle.set_device(device)

                exe = paddle.static.Executor(place)
                exe.run(startup_program)

                # Must call exe.run(startup_program) first before calling paddle.incubate.asp.prune_model()
                paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best')
                # It is also acceptable to call
                # paddle.incubate.asp.prune_model(main_program, mask_algo='mask_2d_best')

                for i in range(10):
                    imgs = np.random.randn(64, 3, 32, 32).astype('float32')
                    labels = np.random.randint(10, size=(64, 1)).astype('float32')
                    exe.run(main_program, feed={'data':imgs, 'label':labels})
    """
    device = paddle.device.get_device()
    place = paddle.set_device(device)
    MaskAlgo_mapping = {
        'mask_1d': sparsity.MaskAlgo.MASK_1D,
@@ -237,11 +440,26 @@ def prune_model(main_program=None,
        'mask_2d_best': sparsity.MaskAlgo.MASK_2D_BEST
    }
    assert (mask_algo in MaskAlgo_mapping), \
           'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]'

    prune_func = None
    if isinstance(model, paddle.nn.Layer):
        prune_func = ASPHelper.prune_model_by_layer
    elif isinstance(model, paddle.static.Program):
        prune_func = ASPHelper.prune_model_by_program
        if hasattr(model, "distributed_info_") and \
           model.distributed_info_["sharding_degree"] > 1 and \
           paddle.fluid.is_compiled_with_cuda():
            gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
            place = paddle.CUDAPlace(gpu_id)
    else:
        raise TypeError(
            "model should be paddle.nn.Layer or paddle.static.Program, but got {}".
            format(type(model)))

    return prune_func(
        place,
        model,
        n=n,
        m=m,
        mask_algo=MaskAlgo_mapping[mask_algo],
@@ -300,7 +518,7 @@ class ASPHelper(object):
    __asp_info = {}

    @classmethod
    def set_excluded_layers(cls, param_names, main_program):
        r"""
        This is the implementation of `sparsity.set_excluded_layers`, for details please see explanation in `sparsity.set_excluded_layers`.
        """
@@ -313,8 +531,8 @@ class ASPHelper(object):
        This is the implementation of `sparsity.reset_excluded_layers`, for details please see explanation in `sparsity.reset_excluded_layers`.
        """
        if main_program is None:
            for prog in cls.__asp_info:
                cls.__asp_info[prog].reset_excluded_layers()
        else:
            cls._get_program_asp_info(main_program).reset_excluded_layers()
@@ -323,16 +541,25 @@ class ASPHelper(object):
        r"""
        This is the implementation of `sparsity.decorate`, for details please see explanation in `sparsity.decorate`.
        """
        if paddle.in_dynamic_mode():
            # main_prog and startup_prog would be used with paddle.static.program_guard
            # to create ASP masks. Moreover, main_prog is a key to map a paddle.static.Program
            # to its own ASP information, like ASP mask variables. For dynamic graph, we use
            # default_main_program as the key.
            main_prog = paddle.static.default_main_program()
            startup_prog = paddle.static.default_startup_program()
            ASPHelper._create_mask_variables(main_prog, startup_prog,
                                             optimizer._parameter_list)
        return OptimizerWithSparsityGuarantee(optimizer)
    @classmethod
    def prune_model_by_program(cls,
                               place,
                               main_program=None,
                               n=2,
                               m=4,
                               mask_algo=sparsity.MaskAlgo.MASK_1D,
                               with_mask=True):
        r"""
        This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`.
        """
@@ -366,9 +593,63 @@ class ASPHelper(object):
                        np.array(weight_mask_tensor).dtype)
                    weight_mask_tensor.set(weight_sparse_mask, place)
                asp_info.update_masks(param.name, weight_sparse_mask)

        return asp_info.masks.copy()
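Both pruning paths return `asp_info.masks.copy()`, a dict from parameter name to its pruned-mask ndarray. A small hedged sketch of consuming that dict; `report_density` is a hypothetical helper, not part of this change:

import numpy as np

def report_density(masks):
    # masks: dict of parameter name -> mask ndarray, as returned by prune_model().
    for name, mask in masks.items():
        arr = np.asarray(mask)
        density = float(np.count_nonzero(arr)) / arr.size
        print("{}: density = {:.3f}".format(name, density))  # 2:4 pruning gives 0.5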
    @classmethod
    def prune_model_by_layer(cls,
                             place,
                             layer,
                             n=2,
                             m=4,
                             mask_algo=sparsity.MaskAlgo.MASK_1D,
                             with_mask=True):
        r"""
        This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`.
        """
        if paddle.in_dynamic_mode():
            main_program = paddle.static.default_main_program()
            asp_info = cls._get_program_asp_info(main_program)

            for param in layer.parameters():
                if ASPHelper._is_supported_layer(main_program, param.name):
                    weight_nparray = param.numpy()

                    prune_func = ASPHelper._get_prune_func_by_name(param.name)

                    weight_pruned_nparray, weight_sparse_mask = \
                        prune_func(weight_nparray, m, n, mask_algo, param.name)

                    weight_pruned_nparray = weight_pruned_nparray.astype(
                        weight_nparray.dtype)
                    param.set_value(weight_pruned_nparray)

                    if with_mask:
                        weight_mask_param = asp_info.mask_vars.get(param.name,
                                                                   None)
                        assert weight_mask_param is not None, \
                            'Cannot find {} variable, please call sparsity.decorate() to' \
                            ' decorate your optimizer first!'.format(ASPHelper._get_mask_name(param.name))
                        weight_mask_param.set_value(weight_sparse_mask)

                    asp_info.update_masks(param.name, weight_sparse_mask)

            return asp_info.masks.copy()
        else:
            # This for loop is only used to obtain Block and Program from
            # the first parameter.
            target_program = None
            for param in layer.parameters():
                target_program = param.block.program
            assert target_program is not None, \
                'Cannot get paddle.static.Program from Paddle.nn.Layer.'
            return ASPHelper.prune_model_by_program(
                place,
                target_program,
                n=n,
                m=m,
                mask_algo=mask_algo,
                with_mask=with_mask)
    @staticmethod
    def _get_mask_name(param_name):
        r"""
@@ -393,13 +674,15 @@ class ASPHelper(object):
        """
        var_list = []
        for param in main_program.global_block().all_parameters():
            param_name_list = param.name.split('.')
            if ASPHelper.MASK_APPENDDED_NAME not in param_name_list:
                var_list.append(param)
        return var_list

    @classmethod
    def _get_program_asp_info(cls, main_program):
        if main_program not in cls.__asp_info:
            cls.__asp_info[main_program] = ProgramASPInfo()
        return cls.__asp_info[main_program]
@@ -508,14 +791,37 @@ class ASPHelper(object):
        optimizer_ops, params_and_grads = optimizer.minimize(
            loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        params_only = [pg[0] for pg in params_and_grads]
        cls._create_mask_variables(main_program, startup_program, params_only)
        cls._insert_sparse_mask_ops(main_program, params_only)
        return optimizer_ops, params_and_grads
    @classmethod
    @dygraph_only
    def _step(cls, optimizer):
        r"""
        This function is a decorator of the `step` function in `Optimizer`.
        There are two steps:

        1. Call :attr:`optimizer`.step()
        2. Mask parameters with sparse masks.

        *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`.
        (Due to there being an invisible graph optimization in `Fleet.minimize()` which makes the training graph
        unable to be modified anymore.)

        Args:
            optimizer (Optimizer): An Optimizer used for training.
        """
        optimizer.step()
        main_prog = paddle.static.default_main_program()
        with paddle.fluid.dygraph.no_grad():
            ASPHelper._insert_sparse_mask_ops(main_prog,
                                              optimizer._parameter_list)
    @classmethod
    def _create_mask_variables(cls, main_program, startup_program, params):
        r"""
        Create sparse mask Tensors according to supported layers in :attr:`main_program`.
        This function is called in the second step of `ASPHelper._minimize`
@@ -523,48 +829,45 @@ class ASPHelper(object):
        Args:
            main_program (Program): Program with model definition and its parameters.
            startup_program (Program): Program for initializing parameters.
            params (list): Variable parameters.
        """
        asp_info = cls._get_program_asp_info(main_program)
        with program_guard(main_program, startup_program):
            for param in params:
                if ASPHelper._is_supported_layer(main_program, param.name):
                    if param.name not in asp_info.mask_vars:
                        mask_param = layers.create_parameter(
                            name=ASPHelper._get_mask_name(param.name),
                            shape=param.shape,
                            dtype=param.dtype,
                            default_initializer=ConstantInitializer(value=1.0))
                        mask_param.stop_gradient = True
                        mask_param.trainable = False
                        asp_info.update_mask_vars(param.name, mask_param)
    @classmethod
    def _insert_sparse_mask_ops(cls, main_program, params):
        r"""
        Insert masking ops at the end of parameter updates.
        This function is called in the third step of `ASPHelper._minimize`

        Args:
            main_program (Program): Program with model definition and its parameters.
            params (list): Variable parameters.
        """
        block = main_program.global_block()
        asp_info = cls._get_program_asp_info(main_program)
        for param in params:
            if param.name in asp_info.mask_vars:
                block.append_op(
                    type='elementwise_mul',
                    inputs={"X": param,
                            'Y': asp_info.mask_vars[param.name]},
                    outputs={'Out': param},
                    attrs={
                        'axis': -1,
                        'use_mkldnn': False,
                        OP_ROLE_KEY: int(OpRole.Optimize)
                    })
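The appended op is just an in-place `param = param * mask` run with the optimizer role, so freshly updated weights are re-zeroed every step. A numpy sketch of that effect, purely illustrative and not taken from the diff:

import numpy as np

param = np.random.randn(4, 8).astype('float32')  # weights after an SGD update
mask = np.zeros_like(param)
mask.reshape(-1, 4)[:, :2] = 1.0                 # one valid 2:4 pattern per group of 4

param = param * mask                             # what the elementwise_mul op computes
assert (np.count_nonzero(param.reshape(-1, 4), axis=1) <= 2).all()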
@@ -579,8 +882,9 @@ class OptimizerWithSparsityGuarantee(object):

    def __init__(self, optimizer):
        self._optimizer = optimizer

    def __getattr__(self, item):
        return getattr(self._optimizer, item)

    def minimize(self,
                 loss,
@@ -605,3 +909,55 @@ class OptimizerWithSparsityGuarantee(object):
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set)
    @dygraph_only
    def step(self):
        r"""
        This function is a decorator of the `step` function in `Optimizer`.
        There are two steps:

        1. Call :attr:`optimizer`.step()
        2. Mask parameters with sparse masks.

        *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`.
        (Due to there being an invisible graph optimization in `Fleet.minimize()` which makes the training graph
        unable to be modified anymore.)
        """
        ASPHelper._step(self._optimizer)
    @dygraph_only
    def state_dict(self):
        r"""
        This function is a decorator of the `state_dict` function in `Optimizer`.

        Returns:
            state_dict(dict): dict containing all the Tensors used by optimizer
        """
        state_dict = self._optimizer.state_dict()
        asp_info = ASPHelper._get_program_asp_info(
            paddle.static.default_main_program())
        for param_name, var in asp_info.mask_vars.items():
            state_dict.update({ASPHelper._get_mask_name(param_name): var})
        return state_dict

    @dygraph_only
    def set_state_dict(self, state_dict):
        r"""
        This function is a decorator of the `set_state_dict` function in `Optimizer`.

        Args:
            state_dict(dict): Dict containing all the Tensors needed by optimizer
        Return:
            None
        """
        asp_info = ASPHelper._get_program_asp_info(
            paddle.static.default_main_program())
        for param_name, var in asp_info.mask_vars.items():
            param_mask_name = ASPHelper._get_mask_name(param_name)
            assert param_mask_name in state_dict, \
                "The {} is not found.".format(param_mask_name)
            var.set_value(state_dict[param_mask_name])
            asp_info.update_masks(param_name, var.numpy())
        return self._optimizer.set_state_dict(state_dict)
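Since the decorated state_dict()/set_state_dict() also carry the mask variables, a plain save/load should round-trip sparsity state in dynamic mode. A hedged sketch under that assumption; the file name and tiny layer are illustrative:

import paddle

layer = paddle.nn.Linear(32, 32)
opt = paddle.incubate.asp.decorate(
    paddle.optimizer.SGD(learning_rate=0.01, parameters=layer.parameters()))
paddle.incubate.asp.prune_model(layer)

paddle.save(opt.state_dict(), "asp_opt.pdopt")   # masks are included

opt2 = paddle.incubate.asp.decorate(
    paddle.optimizer.SGD(learning_rate=0.01, parameters=layer.parameters()))
opt2.set_state_dict(paddle.load("asp_opt.pdopt"))  # masks are restored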
@@ -94,13 +94,12 @@ def calculate_density(x):
        float: The density of :attr:`x`.
    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = np.array([[0, 1, 3, 0],
                          [1, 1, 0, 1]])
            paddle.incubate.asp.calculate_density(x) # 0.625
    """
    x_flattened = x.flatten()
    return float(np.nonzero(x_flattened)[0].size) / x_flattened.size
...
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_static")
list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_dynamic")
list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_sharding")

foreach(TEST_OP ${TEST_OPS})
@@ -10,9 +10,9 @@ foreach(TEST_OP ${TEST_OPS})
endforeach(TEST_OP)

if(WITH_DISTRIBUTE)
  if (WITH_GPU OR WITH_XPU OR WITH_ASCEND OR WITH_ASCEND_CL)
    py_test_modules(test_fleet_with_asp_dynamic MODULES test_fleet_with_asp_dynamic ENVS ${dist_ENVS})
    py_test_modules(test_fleet_with_asp_static MODULES test_fleet_with_asp_static ENVS ${dist_ENVS})
  endif()
endif()

@@ -21,3 +21,8 @@ if((WITH_DISTRIBUTE) AND (NOT WIN32) AND (NOT APPLE))
    py_test_modules(test_fleet_with_asp_sharding MODULES test_fleet_with_asp_sharding ENVS ${dist_ENVS})
  endif()
endif()

set_tests_properties(test_asp_pruning_dynamic PROPERTIES TIMEOUT 30)
set_tests_properties(test_asp_pruning_static PROPERTIES TIMEOUT 30)
set_tests_properties(test_asp_optimize_dynamic PROPERTIES TIMEOUT 30)
set_tests_properties(test_asp_optimize_static PROPERTIES TIMEOUT 30)
@@ -20,7 +20,6 @@ import threading, time
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np

@@ -60,7 +59,7 @@ class TestASPHelperPruningBase(unittest.TestCase):
            loss = fluid.layers.mean(
                fluid.layers.cross_entropy(
                    input=self.predict, label=self.label))
            optimizer = paddle.incubate.asp.decorate(
                fluid.optimizer.SGD(learning_rate=0.01))
            optimizer.minimize(loss, self.startup_program)

@@ -75,7 +74,7 @@ class TestASPHelperPruningBase(unittest.TestCase):
    def __pruning_and_checking(self, exe, place, mask_func_name,
                               check_func_name, with_mask):
        exe.run(self.startup_program)
        paddle.incubate.asp.prune_model(
            self.main_program, mask_algo=mask_func_name, with_mask=with_mask)
        for param in self.main_program.global_block().all_parameters():
            if ASPHelper._is_supported_layer(self.main_program, param.name):
...
@@ -66,6 +66,97 @@ class TestASPAddSupportedLayer(unittest.TestCase):
            my_own_layer_name in supported_layers_and_prune_func_map)


class TestASPDynamicCustomerizedPruneFunc(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()

        class CustomerLayer(paddle.nn.Layer):
            def __init__(self):
                super(CustomerLayer, self).__init__()

                self.weight = self.create_parameter(
                    shape=[32, 32], attr=None, dtype='float32', is_bias=False)
                self.linear1 = paddle.nn.Linear(32, 32)
                self.linear2 = paddle.nn.Linear(32, 10)

            def forward(self, input_):
                hidden = paddle.nn.functional.linear(
                    x=input_, weight=self.weight)
                hidden = self.linear1(hidden)
                out = self.linear2(hidden)
                return out

        sparsity.add_supported_layer(CustomerLayer, my_own_pruning)

        self.layer = CustomerLayer()
        self.customer_prefix = paddle.fluid.dygraph.layers._convert_camel_to_snake(
            CustomerLayer.__name__)
        self.supported_layer_count_ref = 3

    def test_inference_pruning(self):
        sparsity.prune_model(self.layer, mask_algo="mask_1d", with_mask=False)

        supported_layer_count = 0
        for param in self.layer.parameters():
            mat = param.numpy()
            if sparsity.asp.ASPHelper._is_supported_layer(
                    paddle.static.default_main_program(), param.name):
                supported_layer_count += 1
                if (self.customer_prefix in param.name):
                    self.assertLessEqual(
                        np.sum(mat.flatten() - static_tensor.flatten()), 1e-4)
                else:
                    self.assertTrue(
                        sparsity.check_sparsity(
                            mat.T,
                            func_name=sparsity.CheckMethod.CHECK_1D,
                            n=2,
                            m=4))
        self.assertEqual(supported_layer_count, self.supported_layer_count_ref)

    def test_training_pruning(self):
        optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                         parameters=self.layer.parameters())
        optimizer = sparsity.decorate(optimizer)

        sparsity.prune_model(self.layer, mask_algo="mask_1d", with_mask=True)

        supported_layer_count = 0
        for param in self.layer.parameters():
            mat = param.numpy()

            if sparsity.asp.ASPHelper._is_supported_layer(
                    paddle.static.default_main_program(), param.name):
                mat_mask = sparsity.asp.ASPHelper._get_program_asp_info(
                    paddle.static.default_main_program()).mask_vars[
                        param.name].numpy()
                supported_layer_count += 1
                if (self.customer_prefix in param.name):
                    self.assertLessEqual(
                        np.sum(mat.flatten() - static_tensor.flatten()), 1e-4)
                    self.assertLessEqual(
                        np.sum(mat_mask.flatten() - static_tensor_mask.flatten(
                        )), 1e-4)
                else:
                    self.assertTrue(
                        sparsity.check_sparsity(
                            mat.T,
                            func_name=sparsity.CheckMethod.CHECK_1D,
                            n=2,
                            m=4))
                    self.assertTrue(
                        sparsity.check_sparsity(
                            mat_mask.T,
                            func_name=sparsity.CheckMethod.CHECK_1D,
                            n=2,
                            m=4))
        self.assertEqual(supported_layer_count, self.supported_layer_count_ref)


class TestASPStaticCustomerizedPruneFunc(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np
class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self.conv1 = paddle.nn.Conv2D(
            in_channels=3, out_channels=2, kernel_size=3, padding=2)
        self.linear1 = paddle.nn.Linear(1352, 32)
        self.linear2 = paddle.nn.Linear(32, 32)
        self.linear3 = paddle.nn.Linear(32, 10)

    def forward(self, img):
        hidden = self.conv1(img)
        hidden = paddle.flatten(hidden, start_axis=1)
        hidden = self.linear1(hidden)
        hidden = self.linear2(hidden)
        prediction = self.linear3(hidden)
        return prediction
class TestASPDynamicOptimize(unittest.TestCase):
    def setUp(self):
        self.layer = MyLayer()

        self.place = paddle.CPUPlace()
        if core.is_compiled_with_cuda():
            self.place = paddle.CUDAPlace(0)

        self.optimizer = paddle.optimizer.SGD(
            learning_rate=0.01, parameters=self.layer.parameters())

    def test_is_supported_layers(self):
        program = paddle.static.default_main_program()

        names = [
            'embedding_0.w_0', 'fack_layer_0.w_0', 'conv2d_0.w_0',
            'conv2d_0.b_0', 'conv2d_1.w_0', 'conv2d_1.b_0', 'fc_0.w_0',
            'fc_0.b_0', 'fc_1.w_0', 'fc_1.b_0', 'linear_2.w_0', 'linear_2.b_0'
        ]
        ref = [
            False, False, True, False, True, False, True, False, True, False,
            True, False
        ]
        for i, name in enumerate(names):
            self.assertTrue(
                ref[i] == ASPHelper._is_supported_layer(program, name))

        paddle.incubate.asp.set_excluded_layers(['fc_1', 'conv2d_0'])
        ref = [
            False, False, False, False, True, False, True, False, False, False,
            True, False
        ]
        for i, name in enumerate(names):
            self.assertTrue(
                ref[i] == ASPHelper._is_supported_layer(program, name))

        paddle.incubate.asp.reset_excluded_layers()
        ref = [
            False, False, True, False, True, False, True, False, True, False,
            True, False
        ]
        for i, name in enumerate(names):
            self.assertTrue(
                ref[i] == ASPHelper._is_supported_layer(program, name))

    def test_decorate(self):
        param_names = [param.name for param in self.layer.parameters()]
        self.optimizer = paddle.incubate.asp.decorate(self.optimizer)

        program = paddle.static.default_main_program()

        for name in param_names:
            mask_var = ASPHelper._get_program_asp_info(program).mask_vars.get(
                name, None)
            if ASPHelper._is_supported_layer(program, name):
                self.assertTrue(mask_var is not None)
            else:
                self.assertTrue(mask_var is None)

    def test_asp_training(self):
        self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
        paddle.incubate.asp.prune_model(self.layer)

        imgs = paddle.to_tensor(
            np.random.randn(32, 3, 24, 24),
            dtype='float32',
            place=self.place,
            stop_gradient=False)
        labels = paddle.to_tensor(
            np.random.randint(
                10, size=(32, 1)),
            dtype='float32',
            place=self.place,
            stop_gradient=False)

        loss_fn = paddle.nn.MSELoss(reduction='mean')

        output = self.layer(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        self.optimizer.step()
        self.optimizer.clear_grad()

        for param in self.layer.parameters():
            if ASPHelper._is_supported_layer(
                    paddle.static.default_main_program(), param.name):
                mat = param.numpy()
                self.assertTrue(
                    paddle.fluid.contrib.sparsity.check_sparsity(
                        mat.T, n=2, m=4))

    def test_asp_training_with_amp(self):
        self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
        paddle.incubate.asp.prune_model(self.layer)

        imgs = paddle.to_tensor(
            np.random.randn(32, 3, 24, 24),
            dtype='float32',
            place=self.place,
            stop_gradient=False)
        labels = paddle.to_tensor(
            np.random.randint(
                10, size=(32, 1)),
            dtype='float32',
            place=self.place,
            stop_gradient=False)

        loss_fn = paddle.nn.MSELoss(reduction='mean')
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        with paddle.amp.auto_cast(enable=True):
            output = self.layer(imgs)
            loss = loss_fn(output, labels)
        scaled = scaler.scale(loss)
        scaled.backward()
        scaler.minimize(self.optimizer, scaled)
        self.optimizer.clear_grad()

        for param in self.layer.parameters():
            if ASPHelper._is_supported_layer(
                    paddle.static.default_main_program(), param.name):
                mat = param.numpy()
                self.assertTrue(
                    paddle.fluid.contrib.sparsity.check_sparsity(
                        mat.T, n=2, m=4))


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,21 +20,20 @@ import threading, time
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np

paddle.enable_static()


class TestASPStaticOptimize(unittest.TestCase):
    def setUp(self):
        self.main_program = fluid.Program()
        self.startup_program = fluid.Program()

        def build_model():
            img = fluid.data(
                name='img', shape=[None, 3, 24, 24], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            hidden = fluid.layers.conv2d(
                input=img, num_filters=4, filter_size=3, padding=2, act="relu")
@@ -87,7 +86,7 @@ class TestASPStaticOptimize(unittest.TestCase):
            self.assertTrue(
                ref[i] == ASPHelper._is_supported_layer(program, name))

        paddle.incubate.asp.set_excluded_layers(['fc_1', 'conv2d_0'], program)
        ref = [
            False, False, False, False, True, False, True, False, False, False,
            True, False
        ]
@@ -96,7 +95,7 @@ class TestASPStaticOptimize(unittest.TestCase):
            self.assertTrue(
                ref[i] == ASPHelper._is_supported_layer(program, name))

        paddle.incubate.asp.reset_excluded_layers(program)
        ref = [
            False, False, True, False, True, False, True, False, True, False,
            True, False
        ]
@@ -109,7 +108,7 @@ class TestASPStaticOptimize(unittest.TestCase):
        param_names = self.__get_param_names(self.main_program.global_block()
                                             .all_parameters())
        with fluid.program_guard(self.main_program, self.startup_program):
            self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
            self.optimizer.minimize(self.loss, self.startup_program)
        param_names_after_minimize = self.__get_param_names(
            self.main_program.global_block().all_parameters())
@@ -119,7 +118,7 @@ class TestASPStaticOptimize(unittest.TestCase):

    def test_asp_training(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
            self.optimizer.minimize(self.loss, self.startup_program)

        place = paddle.CPUPlace()
@@ -129,10 +128,10 @@ class TestASPStaticOptimize(unittest.TestCase):
        feeder = fluid.DataFeeder(feed_list=[self.img, self.label], place=place)

        exe.run(self.startup_program)
        paddle.incubate.asp.prune_model(self.main_program)

        data = (np.random.randn(32, 3, 24, 24), np.random.randint(
            10, size=(32, 1)))
        exe.run(self.main_program, feed=feeder.feed([data]))

        for param in self.main_program.global_block().all_parameters():
@@ -149,7 +148,7 @@ class TestASPStaticOptimize(unittest.TestCase):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.optimizer = fluid.contrib.mixed_precision.decorator.decorate(
                self.optimizer)
            self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
            self.optimizer.minimize(self.loss, self.startup_program)

        exe = fluid.Executor(place)
@@ -157,10 +156,10 @@ class TestASPStaticOptimize(unittest.TestCase):
            feed_list=[self.img, self.label], place=place)

        exe.run(self.startup_program)
        paddle.incubate.asp.prune_model(self.main_program)

        data = (np.random.randn(32, 3, 24, 24), np.random.randint(
            10, size=(32, 1)))
        exe.run(self.main_program, feed=feeder.feed([data]))

        for param in self.main_program.global_block().all_parameters():
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import unittest
from paddle.static import sparsity
from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase
paddle.enable_static()
class TestASPHelperPruning2DBest(TestASPHelperPruningBase):
def test_2D_best_inference_pruning(self):
self.run_inference_pruning_test(
'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
def test_2D_best_training_pruning(self):
self.run_training_pruning_test(
'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.static import sparsity
from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase
paddle.enable_static()
class TestASPHelperPruning2DGreedy(TestASPHelperPruningBase):
def test_2D_greedy_inference_pruning(self):
self.run_inference_pruning_test(
'mask_2d_greedy',
paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
def test_2D_greedy_training_pruning(self):
self.run_training_pruning_test(
'mask_2d_greedy',
paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.contrib.sparsity.asp import ASPHelper
class MyLayer(paddle.nn.Layer):
def __init__(self):
super(MyLayer, self).__init__()
self.conv1 = paddle.nn.Conv2D(
in_channels=3, out_channels=2, kernel_size=3, padding=2)
self.linear1 = paddle.nn.Linear(1352, 32)
self.linear2 = paddle.nn.Linear(32, 10)
def forward(self, img):
hidden = self.conv1(img)
hidden = paddle.flatten(hidden, start_axis=1)
hidden = self.linear1(hidden)
prediction = self.linear2(hidden)
return prediction
class TestASPDynamicPruningBase(unittest.TestCase):
def setUp(self):
self.layer = MyLayer()
place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
self.img = paddle.to_tensor(
np.random.uniform(
low=-0.5, high=0.5, size=(32, 3, 24, 24)),
dtype=np.float32,
place=place,
stop_gradient=False)
self.set_config()
def set_config(self):
self.mask_gen_func = 'mask_1d'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D
def test_inference_pruning(self):
self.__pruning_and_checking(False)
def test_training_pruning(self):
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=self.layer.parameters())
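        # decorate() wraps the optimizer so that the ASP masks are re-applied
        # to the pruned parameters after each optimization step.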
optimizer = paddle.incubate.asp.decorate(optimizer)
self.__pruning_and_checking(True)
def __pruning_and_checking(self, with_mask):
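        # Prune supported weights in-place to the n:m pattern chosen by
        # mask_algo; with_mask=True also updates the ASP mask variables so
        # that training can continue after pruning.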
paddle.incubate.asp.prune_model(
self.layer, mask_algo=self.mask_gen_func, with_mask=with_mask)
for param in self.layer.parameters():
if ASPHelper._is_supported_layer(
paddle.static.default_main_program(), param.name):
mat = param.numpy()
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, func_name=self.mask_check_func, n=2, m=4))
class TestASPDynamicPruning1D(TestASPDynamicPruningBase):
def set_config(self):
self.mask_gen_func = 'mask_1d'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D
class TestASPDynamicPruning2DBest(TestASPDynamicPruningBase):
def set_config(self):
self.mask_gen_func = 'mask_2d_best'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D
class TestASPDynamicPruning2DGreedy(TestASPDynamicPruningBase):
def set_config(self):
self.mask_gen_func = 'mask_2d_greedy'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D
if __name__ == '__main__':
unittest.main()
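For reference, the three mask-generation algorithms exercised above are selected by name. A minimal sketch of the same call outside the test harness, reusing the MyLayer definition from this file (nothing else is assumed):

import paddle

for algo in ('mask_1d', 'mask_2d_best', 'mask_2d_greedy'):
    layer = MyLayer()
    # Inference-only pruning: weights are zeroed to the default 2:4 pattern
    # and no mask variables are kept for further training.
    paddle.incubate.asp.prune_model(layer, mask_algo=algo, with_mask=False)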
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import threading, time
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np
paddle.enable_static()
class TestASPStaticPruningBase(unittest.TestCase):
def setUp(self):
self.main_program = fluid.Program()
self.startup_program = fluid.Program()
def build_model():
img = fluid.data(
name='img', shape=[None, 3, 24, 24], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
hidden = fluid.layers.conv2d(
input=img, num_filters=2, filter_size=3, padding=2, act="relu")
hidden = fluid.layers.fc(input=hidden, size=32, act='softmax')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
return img, label, prediction
with fluid.program_guard(self.main_program, self.startup_program):
self.img, self.label, self.predict = build_model()
self.set_config()
def set_config(self):
self.mask_gen_func = 'mask_1d'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D
def test_inference_pruning(self):
place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = fluid.Executor(place)
self.__pruning_and_checking(exe, place, False)
def test_training_pruning(self):
with fluid.program_guard(self.main_program, self.startup_program):
loss = fluid.layers.mean(
fluid.layers.cross_entropy(
input=self.predict, label=self.label))
optimizer = paddle.incubate.asp.decorate(
fluid.optimizer.SGD(learning_rate=0.01))
optimizer.minimize(loss, self.startup_program)
place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = fluid.Executor(place)
self.__pruning_and_checking(exe, place, True)
def __pruning_and_checking(self, exe, place, with_mask):
exe.run(self.startup_program)
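        # Prune the initialized parameters of all supported layers in
        # main_program to the selected n:m sparse pattern.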
paddle.incubate.asp.prune_model(
self.main_program,
mask_algo=self.mask_gen_func,
with_mask=with_mask)
for param in self.main_program.global_block().all_parameters():
if ASPHelper._is_supported_layer(self.main_program, param.name):
mat = np.array(fluid.global_scope().find_var(param.name)
.get_tensor())
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, func_name=self.mask_check_func, n=2, m=4))
class TestASPStaticPruning1D(TestASPStaticPruningBase):
def set_config(self):
self.mask_gen_func = 'mask_1d'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D
class TestASPStaticPruning2DBest(TestASPStaticPruningBase):
def set_config(self):
self.mask_gen_func = 'mask_2d_best'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D
class TestASPStaticPruning2DGreedy(TestASPStaticPruningBase):
def set_config(self):
self.mask_gen_func = 'mask_2d_greedy'
self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np
class MyLayer(paddle.nn.Layer):
def __init__(self):
super(MyLayer, self).__init__()
self.conv1 = paddle.nn.Conv2D(
in_channels=3, out_channels=4, kernel_size=3, padding=2)
self.linear1 = paddle.nn.Linear(4624, 32)
self.linear2 = paddle.nn.Linear(32, 32)
self.linear3 = paddle.nn.Linear(32, 10)
def forward(self, img):
hidden = self.conv1(img)
hidden = paddle.flatten(hidden, start_axis=1)
hidden = self.linear1(hidden)
hidden = self.linear2(hidden)
prediction = self.linear3(hidden)
return prediction
class TestASPDynamicOptimize(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.layer = MyLayer()
self.place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
self.place = paddle.CUDAPlace(0)
self.optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=self.layer.parameters())
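        # Dynamic-graph ASP flow: decorate the optimizer first, then prune
        # the layer so the masks match the pruned weights.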
self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
paddle.incubate.asp.prune_model(self.layer)
def test_save_and_load(self):
path = "/tmp/paddle_asp_save_dy/"
net_path = path + "asp_net.pdparams"
opt_path = path + "asp_opt.pdopt"
paddle.save(self.layer.state_dict(), net_path)
paddle.save(self.optimizer.state_dict(), opt_path)
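        # Deliberately overwrite the recorded masks with all-ones; loading
        # the saved state below has to restore them for the final sparsity
        # check to pass.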
asp_info = ASPHelper._get_program_asp_info(
paddle.static.default_main_program())
for param_name in asp_info.mask_vars:
mask = asp_info.mask_vars[param_name]
asp_info.update_mask_vars(
param_name, paddle.ones(
shape=mask.shape, dtype=mask.dtype))
asp_info.update_masks(param_name, np.ones(shape=mask.shape))
net_state_dict = paddle.load(net_path)
opt_state_dict = paddle.load(opt_path)
self.layer.set_state_dict(net_state_dict)
self.optimizer.set_state_dict(opt_state_dict)
imgs = paddle.to_tensor(
np.random.randn(64, 3, 32, 32),
dtype='float32',
place=self.place,
stop_gradient=False)
labels = paddle.to_tensor(
np.random.randint(
10, size=(64, 1)),
dtype='float32',
place=self.place,
stop_gradient=False)
loss_fn = paddle.nn.MSELoss(reduction='mean')
output = self.layer(imgs)
loss = loss_fn(output, labels)
loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()
for param in self.layer.parameters():
if ASPHelper._is_supported_layer(
paddle.static.default_main_program(), param.name):
mat = param.numpy()
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, n=2, m=4))
class TestASPStaticOptimize(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.main_program = fluid.Program()
self.startup_program = fluid.Program()
def build_model():
img = fluid.data(
name='img', shape=[None, 3, 32, 32], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
hidden = fluid.layers.conv2d(
input=img, num_filters=4, filter_size=3, padding=2, act="relu")
hidden = fluid.layers.fc(input=hidden, size=32, act='relu')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
return img, label, prediction
with fluid.program_guard(self.main_program, self.startup_program):
self.img, self.label, predict = build_model()
self.loss = fluid.layers.mean(
fluid.layers.cross_entropy(
input=predict, label=self.label))
self.optimizer = fluid.optimizer.SGD(learning_rate=0.01)
self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
self.optimizer.minimize(self.loss, self.startup_program)
self.place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
self.place = paddle.CUDAPlace(0)
self.exe = fluid.Executor(self.place)
self.exe.run(self.startup_program)
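        # Prune after running the startup program so the freshly initialized
        # weights receive the 2:4 sparse pattern.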
paddle.incubate.asp.prune_model(self.main_program)
def test_save_and_load(self):
path = "/tmp/paddle_asp_save_st/"
param_path = path + "asp.pdparams"
model_path = path + "asp.pdmodel"
paddle.save(self.main_program.state_dict(), param_path)
paddle.save(self.main_program, model_path)
prog = paddle.load(model_path)
state_dict = paddle.load(param_path)
prog.set_state_dict(state_dict)
feeder = fluid.DataFeeder(
feed_list=[self.img, self.label], place=self.place)
data = (np.random.randn(64, 3, 32, 32), np.random.randint(
10, size=(64, 1)))
self.exe.run(prog, feed=feeder.feed([data]))
for param in prog.global_block().all_parameters():
if ASPHelper._is_supported_layer(prog, param.name):
mat = np.array(fluid.global_scope().find_var(param.name)
.get_tensor())
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, n=2, m=4))
if __name__ == '__main__':
unittest.main()
@@ -18,7 +18,6 @@ from __future__ import print_function
 import unittest
 import threading, time
 import paddle
-from paddle.static import sparsity
 import numpy as np
@@ -41,9 +40,9 @@ class TestASPUtils(unittest.TestCase):
         x = np.array([[1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0],
                       [1.0, 0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0],
                       [0.0, 1.0, 0.0, 0.0, 1.0]])
-        self.assertEqual(sparsity.calculate_density(x), 0.56)
+        self.assertEqual(paddle.incubate.asp.calculate_density(x), 0.56)
         x[:, 0] = 0.0
-        self.assertEqual(sparsity.calculate_density(x), 0.4)
+        self.assertEqual(paddle.incubate.asp.calculate_density(x), 0.4)
     def test_check_mask_1d(self):
         x = np.array([[1.0, 0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0],
@@ -219,3 +218,7 @@ class TestASPUtils(unittest.TestCase):
                 func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D,
                 n=2,
                 m=4))
+
+
+if __name__ == '__main__':
+    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import os
from paddle.static import sparsity
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
if cuda_visible_devices is None or cuda_visible_devices == "":
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
else:
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0]
paddle.enable_static()
class TestFleetWithASP(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_TRAINER_ID"] = "0"
def net(self, main_prog, startup_prog):
with fluid.program_guard(main_prog, startup_prog):
input_x = paddle.static.data(
name="x", shape=[-1, 32], dtype='float32')
input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = paddle.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.asp = True
return avg_cost, strategy, input_x, input_y
def test_with_asp(self):
fleet.init(is_collective=True)
train_prog, startup_prog = fluid.Program(), fluid.Program()
avg_cost, strategy, input_x, input_y = self.net(train_prog,
startup_prog)
with fluid.program_guard(train_prog, startup_prog):
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
exe.run(startup_prog)
sparsity.prune_model(train_prog)
data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1)))
exe.run(train_prog, feed=feeder.feed([data]))
for param in train_prog.global_block().all_parameters():
if ASPHelper._is_supported_layer(train_prog, param.name):
mat = np.array(fluid.global_scope().find_var(param.name)
.get_tensor())
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, n=2, m=4))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import os
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
if cuda_visible_devices is None or cuda_visible_devices == "":
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
else:
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0]
class MyLayer(paddle.nn.Layer):
def __init__(self):
super(MyLayer, self).__init__()
self.linear1 = paddle.nn.Linear(32, 32)
self.linear2 = paddle.nn.Linear(32, 10)
def forward(self, x):
hidden = self.linear1(x)
prediction = self.linear2(hidden)
return prediction
class TestFleetWithASPDynamic(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_TRAINER_ID"] = "0"
self.layer = MyLayer()
self.place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
self.place = paddle.CUDAPlace(0)
self.optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=self.layer.parameters())
def test_with_asp(self):
fleet.init(is_collective=True)
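        # Decorate and prune before wrapping with fleet, so the distributed
        # optimizer and model are built on top of the ASP-enabled objects.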
self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
paddle.incubate.asp.prune_model(self.layer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)
self.layer = fleet.distributed_model(self.layer)
imgs = paddle.to_tensor(
np.random.randn(64, 32),
dtype='float32',
place=self.place,
stop_gradient=False)
labels = paddle.to_tensor(
np.random.randint(
10, size=(64, 1)),
dtype='float32',
place=self.place,
stop_gradient=False)
loss_fn = paddle.nn.MSELoss(reduction='mean')
output = self.layer(imgs)
loss = loss_fn(output, labels)
loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()
for param in self.layer.parameters():
if ASPHelper._is_supported_layer(
paddle.static.default_main_program(), param.name):
mat = param.numpy()
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, n=2, m=4))
class TestFleetWithASPAMPDynamic(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_TRAINER_ID"] = "0"
self.layer = MyLayer()
self.place = paddle.CPUPlace()
if core.is_compiled_with_cuda():
self.place = paddle.CUDAPlace(0)
self.optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=self.layer.parameters())
def test_with_asp(self):
fleet.init(is_collective=True)
self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
paddle.incubate.asp.prune_model(self.layer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)
self.layer = fleet.distributed_model(self.layer)
imgs = paddle.to_tensor(
np.random.randn(64, 32),
dtype='float32',
place=self.place,
stop_gradient=False)
labels = paddle.to_tensor(
np.random.randint(
10, size=(64, 1)),
dtype='float32',
place=self.place,
stop_gradient=False)
loss_fn = paddle.nn.MSELoss(reduction='mean')
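        # AMP training step: scale the loss before backward, then let
        # scaler.minimize() unscale the gradients and drive the
        # ASP-decorated optimizer.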
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
with paddle.amp.auto_cast(enable=True):
output = self.layer(imgs)
loss = loss_fn(output, labels)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(self.optimizer, scaled)
self.optimizer.clear_grad()
for param in self.layer.parameters():
if ASPHelper._is_supported_layer(
paddle.static.default_main_program(), param.name):
mat = param.numpy()
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, n=2, m=4))
if __name__ == "__main__":
unittest.main()
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,7 +32,62 @@ else:
 paddle.enable_static()
-class TestFleetWithASP(unittest.TestCase):
+class TestFleetWithASPStatic(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_TRAINERS_NUM"] = "1"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+
+    def net(self, main_prog, startup_prog):
+        with fluid.program_guard(main_prog, startup_prog):
+            input_x = paddle.static.data(
+                name="x", shape=[-1, 32], dtype='float32')
+            input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
+            fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
+            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
+            cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
+            avg_cost = paddle.mean(x=cost)
+            strategy = paddle.distributed.fleet.DistributedStrategy()
+            strategy.asp = True
+        return avg_cost, strategy, input_x, input_y
+
+    def test_with_asp(self):
+        fleet.init(is_collective=True)
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy, input_x, input_y = self.net(train_prog,
+                                                        startup_prog)
+        with fluid.program_guard(train_prog, startup_prog):
+            optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
+            optimizer = fleet.distributed_optimizer(
+                optimizer, strategy=strategy)
+            optimizer.minimize(avg_cost)
+        place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
+        exe.run(startup_prog)
+        sparsity.prune_model(train_prog)
+        data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1)))
+        exe.run(train_prog, feed=feeder.feed([data]))
+        for param in train_prog.global_block().all_parameters():
+            if ASPHelper._is_supported_layer(train_prog, param.name):
+                mat = np.array(fluid.global_scope().find_var(param.name)
+                               .get_tensor())
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, n=2, m=4))
+
+
+class TestFleetWithASPAMPStatic(unittest.TestCase):
     def setUp(self):
         os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
         os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
......
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
@@ -32,6 +32,7 @@ import paddle.incubate.autograd
 import paddle.incubate.autotune
 from . import nn  #noqa: F401
+from . import asp  #noqa: F401
 __all__ = [
     'LookAhead',
......
@@ -13,25 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import print_function
-
-import unittest
-import paddle
-from paddle.static import sparsity
-from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase
-
-paddle.enable_static()
-
-
-class TestASPHelperPruning1D(TestASPHelperPruningBase):
-    def test_1D_inference_pruning(self):
-        self.run_inference_pruning_test(
-            'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D)
-
-    def test_1D_training_pruning(self):
-        self.run_training_pruning_test(
-            'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D)
-
-
-if __name__ == '__main__':
-    unittest.main()
+from ...fluid.contrib.sparsity import calculate_density  #noqa: F401
+from ...fluid.contrib.sparsity import decorate  #noqa: F401
+from ...fluid.contrib.sparsity import prune_model  #noqa: F401
+from ...fluid.contrib.sparsity import set_excluded_layers  #noqa: F401
+from ...fluid.contrib.sparsity import reset_excluded_layers  #noqa: F401
+
+__all__ = [  #noqa
+    'calculate_density',
+    'decorate',
+    'prune_model',
+    'set_excluded_layers',
+    'reset_excluded_layers'
+]
@@ -16,8 +16,14 @@
 from ...fluid.contrib.sparsity import calculate_density  #noqa: F401
 from ...fluid.contrib.sparsity import decorate  #noqa: F401
 from ...fluid.contrib.sparsity import prune_model  #noqa: F401
-from ...fluid.contrib.sparsity import set_excluded_layers  #noqa: F401
 from ...fluid.contrib.sparsity import reset_excluded_layers  #noqa: F401
+from ...fluid.contrib import sparsity  #noqa: F401
+
+
+def set_excluded_layers(main_program, param_names):
+    sparsity.set_excluded_layers(
+        param_names=param_names, main_program=main_program)
+
 __all__ = [  #noqa
     'calculate_density',
......
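Note the argument order in the wrapper above: paddle.static.sparsity keeps the 2.2-era signature (main_program, param_names), while paddle.incubate.asp.set_excluded_layers takes param_names first with main_program as an optional second argument. A minimal sketch of the two equivalent calls; the parameter name "my_linear.w_0" is made up for illustration:

import paddle

paddle.enable_static()
main_program = paddle.static.default_main_program()

# 2.2-compatible call through the static wrapper:
paddle.static.sparsity.set_excluded_layers(main_program, ["my_linear.w_0"])

# Equivalent call through the new incubate API (main_program may also be
# omitted, in which case the default main program is used):
paddle.incubate.asp.set_excluded_layers(["my_linear.w_0"], main_program)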
@@ -281,6 +281,7 @@ packages=['paddle',
           'paddle.incubate.tensor',
           'paddle.incubate.multiprocessing',
           'paddle.incubate.nn',
+          'paddle.incubate.asp',
           'paddle.incubate.passes',
           'paddle.distribution',
           'paddle.distributed.sharding',
......