From 2b780c770042a674377a55b5ec8347da42651137 Mon Sep 17 00:00:00 2001 From: whs <wanghaoshuang@baidu.com> Date: Tue, 2 Feb 2021 10:18:11 +0800 Subject: [PATCH] Move pruning and quant API from paddleslim.dygraph to paddleslim (#633) (#642) --- paddleslim/__init__.py | 5 ++++ paddleslim/analysis/flops.py | 29 ++++++++++++++++++- paddleslim/dygraph/__init__.py | 23 ++------------- paddleslim/dygraph/prune/__init__.py | 21 ++++++++++++++ .../dygraph/{ => prune}/filter_pruner.py | 4 +-- paddleslim/dygraph/{ => prune}/fpgm_pruner.py | 2 +- .../dygraph/{ => prune}/l1norm_pruner.py | 2 +- .../dygraph/{ => prune}/l2norm_pruner.py | 2 +- paddleslim/dygraph/{ => prune}/pruner.py | 2 +- .../dygraph/{ => prune}/pruning_plan.py | 2 +- paddleslim/dygraph/{ => prune}/var_group.py | 6 ++-- tests/dygraph/test_flops.py | 11 +++---- tests/dygraph/test_prune.py | 2 +- tests/test_dygraph_pruning_plan.py | 2 +- tests/test_flops.py | 2 +- 15 files changed, 75 insertions(+), 40 deletions(-) create mode 100644 paddleslim/dygraph/prune/__init__.py rename paddleslim/dygraph/{ => prune}/filter_pruner.py (99%) rename paddleslim/dygraph/{ => prune}/fpgm_pruner.py (97%) rename paddleslim/dygraph/{ => prune}/l1norm_pruner.py (96%) rename paddleslim/dygraph/{ => prune}/l2norm_pruner.py (96%) rename paddleslim/dygraph/{ => prune}/pruner.py (97%) rename paddleslim/dygraph/{ => prune}/pruning_plan.py (99%) rename paddleslim/dygraph/{ => prune}/var_group.py (93%) diff --git a/paddleslim/__init__.py b/paddleslim/__init__.py index ed7cf9f9..9d312329 100644 --- a/paddleslim/__init__.py +++ b/paddleslim/__init__.py @@ -24,3 +24,8 @@ from paddleslim import dygraph __all__ = [ 'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph' ] + +from paddleslim.dygraph import * +__all__ += dygraph.__all__ +from paddleslim.analysis import * +__all__ += analysis.__all__ diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py index 194efade..a55e175a 100644 --- a/paddleslim/analysis/flops.py +++ b/paddleslim/analysis/flops.py @@ -18,7 +18,34 @@ from ..core import GraphWrapper, dygraph2program __all__ = ["flops", "dygraph_flops"] -def flops(program, only_conv=True, detail=False): +def flops(model, inputs=None, dtypes=None, only_conv=True, detail=False): + """ + Compute the FLOPs of nn.Layer of paddle.Program. + Args: + model(paddle.nn.Layer|paddle.static.Program): The target model. + inputs(list): It is only used when model is instance of 'paddle.nn.Layer'. The dummy inputs used for 'model.forward'. It can be: + 1. list<int>|tuple<int>: means 'model.forward' accepts + only one variable as argument and the shape of + variable is 'inputs'. + 2. list<list<list>>: means 'model.forward' accepts multiple + variables as arguments and the shapes of variables is 'inputs'. + 3. others: 'inputs' will be used as argument list by calling + 'model.forward(*inputs)'. + dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means + data type of each input. None means all the inputs is 'float32'. + Default: None. + only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true. + default: True. + detail(bool): Whether to return detail of each convolution layer. + """ + if isinstance(model, paddle.static.Program): + return _static_flops(model, only_conv=only_conv, detail=detail) + elif isinstance(model, paddle.nn.Layer): + return dygraph_flops( + model, inputs, dtypes=dtypes, only_conv=only_conv, detail=detail) + + +def _static_flops(program, only_conv=True, detail=False): """Get FLOPs of target graph. Args: diff --git a/paddleslim/dygraph/__init__.py b/paddleslim/dygraph/__init__.py index 303f7105..8104f2d1 100644 --- a/paddleslim/dygraph/__init__.py +++ b/paddleslim/dygraph/__init__.py @@ -1,24 +1,5 @@ -from . import var_group -from .var_group import * -from . import l1norm_pruner -from .l1norm_pruner import * -from . import pruner -from .pruner import * -from . import filter_pruner -from .filter_pruner import * -from . import l2norm_pruner -from .l2norm_pruner import * -from . import fpgm_pruner -from .fpgm_pruner import * - __all__ = [] - -__all__ += var_group.__all__ -__all__ += l1norm_pruner.__all__ -__all__ += l2norm_pruner.__all__ -__all__ += fpgm_pruner.__all__ -__all__ += pruner.__all__ -__all__ += filter_pruner.__all__ - from .quant import * __all__ += quant.__all__ +from .prune import * +__all__ += prune.__all__ diff --git a/paddleslim/dygraph/prune/__init__.py b/paddleslim/dygraph/prune/__init__.py new file mode 100644 index 00000000..fa2d29e9 --- /dev/null +++ b/paddleslim/dygraph/prune/__init__.py @@ -0,0 +1,21 @@ +from . import var_group +from .var_group import * +from . import l1norm_pruner +from .l1norm_pruner import * +from . import pruner +from .pruner import * +from . import filter_pruner +from .filter_pruner import * +from . import l2norm_pruner +from .l2norm_pruner import * +from . import fpgm_pruner +from .fpgm_pruner import * + +__all__ = [] + +__all__ += var_group.__all__ +__all__ += l1norm_pruner.__all__ +__all__ += l2norm_pruner.__all__ +__all__ += fpgm_pruner.__all__ +__all__ += pruner.__all__ +__all__ += filter_pruner.__all__ diff --git a/paddleslim/dygraph/filter_pruner.py b/paddleslim/dygraph/prune/filter_pruner.py similarity index 99% rename from paddleslim/dygraph/filter_pruner.py rename to paddleslim/dygraph/prune/filter_pruner.py index a9c20958..6c97bbb3 100644 --- a/paddleslim/dygraph/filter_pruner.py +++ b/paddleslim/dygraph/prune/filter_pruner.py @@ -4,11 +4,11 @@ import numpy as np import pickle import copy import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .pruner import Pruner -from ..analysis import dygraph_flops as flops +from paddleslim.analysis import dygraph_flops as flops from .var_group import VarGroup __all__ = ['Status', 'FilterPruner'] diff --git a/paddleslim/dygraph/fpgm_pruner.py b/paddleslim/dygraph/prune/fpgm_pruner.py similarity index 97% rename from paddleslim/dygraph/fpgm_pruner.py rename to paddleslim/dygraph/prune/fpgm_pruner.py index 1bff3424..cb825a05 100644 --- a/paddleslim/dygraph/fpgm_pruner.py +++ b/paddleslim/dygraph/prune/fpgm_pruner.py @@ -1,7 +1,7 @@ import logging import numpy as np import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .filter_pruner import FilterPruner diff --git a/paddleslim/dygraph/l1norm_pruner.py b/paddleslim/dygraph/prune/l1norm_pruner.py similarity index 96% rename from paddleslim/dygraph/l1norm_pruner.py rename to paddleslim/dygraph/prune/l1norm_pruner.py index 9fb2bbb8..358d5fcf 100644 --- a/paddleslim/dygraph/l1norm_pruner.py +++ b/paddleslim/dygraph/prune/l1norm_pruner.py @@ -1,7 +1,7 @@ import logging import numpy as np import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .filter_pruner import FilterPruner diff --git a/paddleslim/dygraph/l2norm_pruner.py b/paddleslim/dygraph/prune/l2norm_pruner.py similarity index 96% rename from paddleslim/dygraph/l2norm_pruner.py rename to paddleslim/dygraph/prune/l2norm_pruner.py index bffdf3a2..72453923 100644 --- a/paddleslim/dygraph/l2norm_pruner.py +++ b/paddleslim/dygraph/prune/l2norm_pruner.py @@ -1,7 +1,7 @@ import logging import numpy as np import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .filter_pruner import FilterPruner diff --git a/paddleslim/dygraph/pruner.py b/paddleslim/dygraph/prune/pruner.py similarity index 97% rename from paddleslim/dygraph/pruner.py rename to paddleslim/dygraph/prune/pruner.py index fe107e1d..3d5bfe20 100644 --- a/paddleslim/dygraph/pruner.py +++ b/paddleslim/dygraph/prune/pruner.py @@ -3,7 +3,7 @@ import pickle import numpy as np import logging from .pruning_plan import PruningPlan -from ..common import get_logger +from paddleslim.common import get_logger __all__ = ["Pruner"] diff --git a/paddleslim/dygraph/pruning_plan.py b/paddleslim/dygraph/prune/pruning_plan.py similarity index 99% rename from paddleslim/dygraph/pruning_plan.py rename to paddleslim/dygraph/prune/pruning_plan.py index 185d0194..9aa40e76 100644 --- a/paddleslim/dygraph/pruning_plan.py +++ b/paddleslim/dygraph/prune/pruning_plan.py @@ -2,7 +2,7 @@ import paddle import collections import numpy as np import logging -from ..common import get_logger +from paddleslim.common import get_logger from paddle.fluid import core _logger = get_logger(__name__, level=logging.INFO) diff --git a/paddleslim/dygraph/var_group.py b/paddleslim/dygraph/prune/var_group.py similarity index 93% rename from paddleslim/dygraph/var_group.py rename to paddleslim/dygraph/prune/var_group.py index 894de662..1f9a01ee 100644 --- a/paddleslim/dygraph/var_group.py +++ b/paddleslim/dygraph/prune/var_group.py @@ -2,9 +2,9 @@ import numpy as np import logging import paddle from paddle.fluid.dygraph import TracedLayer -from ..core import GraphWrapper, dygraph2program -from ..prune import collect_convs -from ..common import get_logger +from paddleslim.core import GraphWrapper, dygraph2program +from paddleslim.prune import collect_convs +from paddleslim.common import get_logger __all__ = ["VarGroup"] diff --git a/tests/dygraph/test_flops.py b/tests/dygraph/test_flops.py index 01ffc451..699d9526 100644 --- a/tests/dygraph/test_flops.py +++ b/tests/dygraph/test_flops.py @@ -3,7 +3,7 @@ sys.path.append("../../") import unittest import numpy as np import paddle -from paddleslim.analysis import dygraph_flops as flops +from paddleslim import flops from paddle.vision.models import mobilenet_v1, resnet50 from paddle.nn import Conv2D, Layer @@ -16,7 +16,7 @@ class TestFlops(unittest.TestCase): def runTest(self): net = self._net(pretrained=False) - FLOPs = flops(net, (1, 3, 32, 32)) + FLOPs = flops(net, (1, 3, 32, 32), only_conv=False) self.assertTrue(FLOPs == self._gt) @@ -54,7 +54,7 @@ class TestFLOPsCase1(unittest.TestCase): "y": paddle.to_tensor(y), "z": "test" } - FLOPs = flops(net, [inputs]) + FLOPs = flops(net, [inputs], only_conv=False) self.assertTrue(FLOPs == 59184) @@ -67,9 +67,10 @@ class TestFLOPsCase2(unittest.TestCase): y = np.random.uniform(-1, 1, y_shape).astype('float32') inputs = [paddle.to_tensor(x), paddle.to_tensor(y)] - FLOPs1 = flops(net, inputs) + FLOPs1 = flops(net, inputs, only_conv=False) shapes = [x_shape, y_shape] - FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"]) + FLOPs2 = flops( + net, shapes, dtypes=["float32", "float32"], only_conv=False) self.assertTrue(FLOPs1 == FLOPs2) diff --git a/tests/dygraph/test_prune.py b/tests/dygraph/test_prune.py index 64a5b788..6f562751 100644 --- a/tests/dygraph/test_prune.py +++ b/tests/dygraph/test_prune.py @@ -16,7 +16,7 @@ sys.path.append("../../") import unittest import paddle import paddle.fluid as fluid -from paddleslim.dygraph import L1NormFilterPruner +from paddleslim import L1NormFilterPruner from paddle.vision.models import mobilenet_v1, resnet50 from paddleslim.prune import Pruner diff --git a/tests/test_dygraph_pruning_plan.py b/tests/test_dygraph_pruning_plan.py index 88c4d59a..fda40b7d 100644 --- a/tests/test_dygraph_pruning_plan.py +++ b/tests/test_dygraph_pruning_plan.py @@ -2,7 +2,7 @@ import sys sys.path.append("../") import unittest import numpy as np -from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask +from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask class TestPruningPlan(unittest.TestCase): diff --git a/tests/test_flops.py b/tests/test_flops.py index f9e4b189..b3eaf8ba 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -15,7 +15,7 @@ import sys sys.path.append("../") import unittest import paddle.fluid as fluid -from paddleslim.analysis import flops +from paddleslim import flops from layers import conv_bn_layer from static_case import StaticCase -- GitLab