diff --git a/paddleslim/__init__.py b/paddleslim/__init__.py index ed7cf9f9e95f4706516970eb167e8f8c38517724..9d3123290ee0963d1f0cd5884be7909eaa521e21 100644 --- a/paddleslim/__init__.py +++ b/paddleslim/__init__.py @@ -24,3 +24,8 @@ from paddleslim import dygraph __all__ = [ 'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph' ] + +from paddleslim.dygraph import * +__all__ += dygraph.__all__ +from paddleslim.analysis import * +__all__ += analysis.__all__ diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py index 194efadec2ebc22016c8847e7c872b74a9544dc3..a55e175aa5ff6adc1949ceb1c9b0bbaacb2afc5b 100644 --- a/paddleslim/analysis/flops.py +++ b/paddleslim/analysis/flops.py @@ -18,7 +18,34 @@ from ..core import GraphWrapper, dygraph2program __all__ = ["flops", "dygraph_flops"] -def flops(program, only_conv=True, detail=False): +def flops(model, inputs=None, dtypes=None, only_conv=True, detail=False): + """ + Compute the FLOPs of nn.Layer of paddle.Program. + Args: + model(paddle.nn.Layer|paddle.static.Program): The target model. + inputs(list): It is only used when model is instance of 'paddle.nn.Layer'. The dummy inputs used for 'model.forward'. It can be: + 1. list<int>|tuple<int>: means 'model.forward' accepts + only one variable as argument and the shape of + variable is 'inputs'. + 2. list<list<list>>: means 'model.forward' accepts multiple + variables as arguments and the shapes of variables is 'inputs'. + 3. others: 'inputs' will be used as argument list by calling + 'model.forward(*inputs)'. + dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means + data type of each input. None means all the inputs is 'float32'. + Default: None. + only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true. + default: True. + detail(bool): Whether to return detail of each convolution layer. + """ + if isinstance(model, paddle.static.Program): + return _static_flops(model, only_conv=only_conv, detail=detail) + elif isinstance(model, paddle.nn.Layer): + return dygraph_flops( + model, inputs, dtypes=dtypes, only_conv=only_conv, detail=detail) + + +def _static_flops(program, only_conv=True, detail=False): """Get FLOPs of target graph. Args: diff --git a/paddleslim/dygraph/__init__.py b/paddleslim/dygraph/__init__.py index 303f7105f8eec5e4b53645d88bf8a6e79f655d27..8104f2d1f987e0fc310e96a5ac30b115cc5cd311 100644 --- a/paddleslim/dygraph/__init__.py +++ b/paddleslim/dygraph/__init__.py @@ -1,24 +1,5 @@ -from . import var_group -from .var_group import * -from . import l1norm_pruner -from .l1norm_pruner import * -from . import pruner -from .pruner import * -from . import filter_pruner -from .filter_pruner import * -from . import l2norm_pruner -from .l2norm_pruner import * -from . import fpgm_pruner -from .fpgm_pruner import * - __all__ = [] - -__all__ += var_group.__all__ -__all__ += l1norm_pruner.__all__ -__all__ += l2norm_pruner.__all__ -__all__ += fpgm_pruner.__all__ -__all__ += pruner.__all__ -__all__ += filter_pruner.__all__ - from .quant import * __all__ += quant.__all__ +from .prune import * +__all__ += prune.__all__ diff --git a/paddleslim/dygraph/prune/__init__.py b/paddleslim/dygraph/prune/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2d29e9295d71dd32520a41a63a6179078b2c5b --- /dev/null +++ b/paddleslim/dygraph/prune/__init__.py @@ -0,0 +1,21 @@ +from . import var_group +from .var_group import * +from . import l1norm_pruner +from .l1norm_pruner import * +from . import pruner +from .pruner import * +from . import filter_pruner +from .filter_pruner import * +from . import l2norm_pruner +from .l2norm_pruner import * +from . import fpgm_pruner +from .fpgm_pruner import * + +__all__ = [] + +__all__ += var_group.__all__ +__all__ += l1norm_pruner.__all__ +__all__ += l2norm_pruner.__all__ +__all__ += fpgm_pruner.__all__ +__all__ += pruner.__all__ +__all__ += filter_pruner.__all__ diff --git a/paddleslim/dygraph/filter_pruner.py b/paddleslim/dygraph/prune/filter_pruner.py similarity index 99% rename from paddleslim/dygraph/filter_pruner.py rename to paddleslim/dygraph/prune/filter_pruner.py index a9c20958b74890d701653679683b0f3a3729f2bc..6c97bbb383c09707d43f1e50c08381f7dcef9941 100644 --- a/paddleslim/dygraph/filter_pruner.py +++ b/paddleslim/dygraph/prune/filter_pruner.py @@ -4,11 +4,11 @@ import numpy as np import pickle import copy import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .pruner import Pruner -from ..analysis import dygraph_flops as flops +from paddleslim.analysis import dygraph_flops as flops from .var_group import VarGroup __all__ = ['Status', 'FilterPruner'] diff --git a/paddleslim/dygraph/fpgm_pruner.py b/paddleslim/dygraph/prune/fpgm_pruner.py similarity index 97% rename from paddleslim/dygraph/fpgm_pruner.py rename to paddleslim/dygraph/prune/fpgm_pruner.py index 1bff3424e1674d6874f443174c9b5eda39c58f5d..cb825a0523c2b44b36da281e02808fbbd64016a6 100644 --- a/paddleslim/dygraph/fpgm_pruner.py +++ b/paddleslim/dygraph/prune/fpgm_pruner.py @@ -1,7 +1,7 @@ import logging import numpy as np import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .filter_pruner import FilterPruner diff --git a/paddleslim/dygraph/l1norm_pruner.py b/paddleslim/dygraph/prune/l1norm_pruner.py similarity index 96% rename from paddleslim/dygraph/l1norm_pruner.py rename to paddleslim/dygraph/prune/l1norm_pruner.py index 9fb2bbb8a748f48380e9a69434affa8c847a9833..358d5fcf40066cd6fd5f6014bb035157e6def5ac 100644 --- a/paddleslim/dygraph/l1norm_pruner.py +++ b/paddleslim/dygraph/prune/l1norm_pruner.py @@ -1,7 +1,7 @@ import logging import numpy as np import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .filter_pruner import FilterPruner diff --git a/paddleslim/dygraph/l2norm_pruner.py b/paddleslim/dygraph/prune/l2norm_pruner.py similarity index 96% rename from paddleslim/dygraph/l2norm_pruner.py rename to paddleslim/dygraph/prune/l2norm_pruner.py index bffdf3a2ba835891c17d55adccab079818188d71..7245392383daa94ac0bb32555741c1380d77e3b8 100644 --- a/paddleslim/dygraph/l2norm_pruner.py +++ b/paddleslim/dygraph/prune/l2norm_pruner.py @@ -1,7 +1,7 @@ import logging import numpy as np import paddle -from ..common import get_logger +from paddleslim.common import get_logger from .var_group import * from .pruning_plan import * from .filter_pruner import FilterPruner diff --git a/paddleslim/dygraph/pruner.py b/paddleslim/dygraph/prune/pruner.py similarity index 97% rename from paddleslim/dygraph/pruner.py rename to paddleslim/dygraph/prune/pruner.py index fe107e1dfcfcab3c2b276f236189314c9c000780..3d5bfe20f35273ae472e9c678ff0a402ad8da57d 100644 --- a/paddleslim/dygraph/pruner.py +++ b/paddleslim/dygraph/prune/pruner.py @@ -3,7 +3,7 @@ import pickle import numpy as np import logging from .pruning_plan import PruningPlan -from ..common import get_logger +from paddleslim.common import get_logger __all__ = ["Pruner"] diff --git a/paddleslim/dygraph/pruning_plan.py b/paddleslim/dygraph/prune/pruning_plan.py similarity index 99% rename from paddleslim/dygraph/pruning_plan.py rename to paddleslim/dygraph/prune/pruning_plan.py index 185d01949c64e76e7c89d282f689330d4d51f5f9..9aa40e7629797f92bbffe437a678fad8e58a2140 100644 --- a/paddleslim/dygraph/pruning_plan.py +++ b/paddleslim/dygraph/prune/pruning_plan.py @@ -2,7 +2,7 @@ import paddle import collections import numpy as np import logging -from ..common import get_logger +from paddleslim.common import get_logger from paddle.fluid import core _logger = get_logger(__name__, level=logging.INFO) diff --git a/paddleslim/dygraph/var_group.py b/paddleslim/dygraph/prune/var_group.py similarity index 93% rename from paddleslim/dygraph/var_group.py rename to paddleslim/dygraph/prune/var_group.py index 894de66293b4ec30f4358047720092a5ec87c995..1f9a01ee7f7aad3c6a73fffaf74ff888d59ca312 100644 --- a/paddleslim/dygraph/var_group.py +++ b/paddleslim/dygraph/prune/var_group.py @@ -2,9 +2,9 @@ import numpy as np import logging import paddle from paddle.fluid.dygraph import TracedLayer -from ..core import GraphWrapper, dygraph2program -from ..prune import collect_convs -from ..common import get_logger +from paddleslim.core import GraphWrapper, dygraph2program +from paddleslim.prune import collect_convs +from paddleslim.common import get_logger __all__ = ["VarGroup"] diff --git a/tests/dygraph/test_flops.py b/tests/dygraph/test_flops.py index 01ffc451b3c44b396fe82e7a33bd55d3cebc89d8..699d95266f431f1b6d9aea30220f71a824b672f1 100644 --- a/tests/dygraph/test_flops.py +++ b/tests/dygraph/test_flops.py @@ -3,7 +3,7 @@ sys.path.append("../../") import unittest import numpy as np import paddle -from paddleslim.analysis import dygraph_flops as flops +from paddleslim import flops from paddle.vision.models import mobilenet_v1, resnet50 from paddle.nn import Conv2D, Layer @@ -16,7 +16,7 @@ class TestFlops(unittest.TestCase): def runTest(self): net = self._net(pretrained=False) - FLOPs = flops(net, (1, 3, 32, 32)) + FLOPs = flops(net, (1, 3, 32, 32), only_conv=False) self.assertTrue(FLOPs == self._gt) @@ -54,7 +54,7 @@ class TestFLOPsCase1(unittest.TestCase): "y": paddle.to_tensor(y), "z": "test" } - FLOPs = flops(net, [inputs]) + FLOPs = flops(net, [inputs], only_conv=False) self.assertTrue(FLOPs == 59184) @@ -67,9 +67,10 @@ class TestFLOPsCase2(unittest.TestCase): y = np.random.uniform(-1, 1, y_shape).astype('float32') inputs = [paddle.to_tensor(x), paddle.to_tensor(y)] - FLOPs1 = flops(net, inputs) + FLOPs1 = flops(net, inputs, only_conv=False) shapes = [x_shape, y_shape] - FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"]) + FLOPs2 = flops( + net, shapes, dtypes=["float32", "float32"], only_conv=False) self.assertTrue(FLOPs1 == FLOPs2) diff --git a/tests/dygraph/test_prune.py b/tests/dygraph/test_prune.py index 64a5b78845a8f3a29250433378f4f8be47d9e445..6f5627510aac87013bf8ca9949fadb6e4a903f94 100644 --- a/tests/dygraph/test_prune.py +++ b/tests/dygraph/test_prune.py @@ -16,7 +16,7 @@ sys.path.append("../../") import unittest import paddle import paddle.fluid as fluid -from paddleslim.dygraph import L1NormFilterPruner +from paddleslim import L1NormFilterPruner from paddle.vision.models import mobilenet_v1, resnet50 from paddleslim.prune import Pruner diff --git a/tests/test_dygraph_pruning_plan.py b/tests/test_dygraph_pruning_plan.py index 88c4d59afaaf6dd10d5baf04fde665883219b0f7..fda40b7dd2879805f4f3538c2c2e00518d8c8dbc 100644 --- a/tests/test_dygraph_pruning_plan.py +++ b/tests/test_dygraph_pruning_plan.py @@ -2,7 +2,7 @@ import sys sys.path.append("../") import unittest import numpy as np -from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask +from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask class TestPruningPlan(unittest.TestCase): diff --git a/tests/test_flops.py b/tests/test_flops.py index f9e4b189bb556b917f83389dd03c44bc30678bd0..b3eaf8bab4885ffe8925b39138323dd6f82cd0d5 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -15,7 +15,7 @@ import sys sys.path.append("../") import unittest import paddle.fluid as fluid -from paddleslim.analysis import flops +from paddleslim import flops from layers import conv_bn_layer from static_case import StaticCase