From 2b780c770042a674377a55b5ec8347da42651137 Mon Sep 17 00:00:00 2001
From: whs <wanghaoshuang@baidu.com>
Date: Tue, 2 Feb 2021 10:18:11 +0800
Subject: [PATCH] Move pruning and quant API from paddleslim.dygraph to
 paddleslim (#633) (#642)

---
 paddleslim/__init__.py                        |  5 ++++
 paddleslim/analysis/flops.py                  | 29 ++++++++++++++++++-
 paddleslim/dygraph/__init__.py                | 23 ++-------------
 paddleslim/dygraph/prune/__init__.py          | 21 ++++++++++++++
 .../dygraph/{ => prune}/filter_pruner.py      |  4 +--
 paddleslim/dygraph/{ => prune}/fpgm_pruner.py |  2 +-
 .../dygraph/{ => prune}/l1norm_pruner.py      |  2 +-
 .../dygraph/{ => prune}/l2norm_pruner.py      |  2 +-
 paddleslim/dygraph/{ => prune}/pruner.py      |  2 +-
 .../dygraph/{ => prune}/pruning_plan.py       |  2 +-
 paddleslim/dygraph/{ => prune}/var_group.py   |  6 ++--
 tests/dygraph/test_flops.py                   | 11 +++----
 tests/dygraph/test_prune.py                   |  2 +-
 tests/test_dygraph_pruning_plan.py            |  2 +-
 tests/test_flops.py                           |  2 +-
 15 files changed, 75 insertions(+), 40 deletions(-)
 create mode 100644 paddleslim/dygraph/prune/__init__.py
 rename paddleslim/dygraph/{ => prune}/filter_pruner.py (99%)
 rename paddleslim/dygraph/{ => prune}/fpgm_pruner.py (97%)
 rename paddleslim/dygraph/{ => prune}/l1norm_pruner.py (96%)
 rename paddleslim/dygraph/{ => prune}/l2norm_pruner.py (96%)
 rename paddleslim/dygraph/{ => prune}/pruner.py (97%)
 rename paddleslim/dygraph/{ => prune}/pruning_plan.py (99%)
 rename paddleslim/dygraph/{ => prune}/var_group.py (93%)

diff --git a/paddleslim/__init__.py b/paddleslim/__init__.py
index ed7cf9f9..9d312329 100644
--- a/paddleslim/__init__.py
+++ b/paddleslim/__init__.py
@@ -24,3 +24,8 @@ from paddleslim import dygraph
 __all__ = [
     'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph'
 ]
+
+from paddleslim.dygraph import *
+__all__ += dygraph.__all__
+from paddleslim.analysis import *
+__all__ += analysis.__all__
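
With these wildcard re-exports in place, the dygraph pruning and analysis APIs become importable from the package root, which is what the updated tests later in this patch rely on. A minimal illustrative sketch (not part of the patch; it only uses names that appear elsewhere in this diff):

    # After this change the pruning and analysis APIs resolve at the package root:
    from paddleslim import L1NormFilterPruner, flops
    # The previous path keeps working through paddleslim.dygraph's own re-exports:
    from paddleslim.dygraph import L1NormFilterPruner as DygraphL1NormFilterPruner
    assert L1NormFilterPruner is DygraphL1NormFilterPruner
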
diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
index 194efade..a55e175a 100644
--- a/paddleslim/analysis/flops.py
+++ b/paddleslim/analysis/flops.py
@@ -18,7 +18,34 @@ from ..core import GraphWrapper, dygraph2program
 __all__ = ["flops", "dygraph_flops"]
 
 
-def flops(program, only_conv=True, detail=False):
+def flops(model, inputs=None, dtypes=None, only_conv=True, detail=False):
+    """
+    Compute the FLOPs of a paddle.nn.Layer or a paddle.static.Program.
+    Args:
+      model(paddle.nn.Layer|paddle.static.Program): The target model.
+      inputs(list): Only used when 'model' is an instance of 'paddle.nn.Layer'. The dummy inputs fed to 'model.forward'. It can be:
+                      1. list<int>|tuple<int>: 'model.forward' accepts a
+                         single variable as argument and the shape of that
+                         variable is 'inputs'.
+                      2. list<list<int>>: 'model.forward' accepts multiple
+                         variables as arguments and their shapes are given by 'inputs'.
+                      3. others: 'inputs' is used as the argument list by calling
+                         'model.forward(*inputs)'.
+      dtypes(str|list<str>): Only used when 'inputs' is a shape or list of shapes;
+                      it gives the data type of each input. None means all inputs
+                      are 'float32'. Default: None.
+      only_conv(bool): If True, only count the mul-add operations in convolution
+                         and FC layers. Default: True.
+      detail(bool): Whether to return the details of each convolution layer.
+    """
+    if isinstance(model, paddle.static.Program):
+        return _static_flops(model, only_conv=only_conv, detail=detail)
+    elif isinstance(model, paddle.nn.Layer):
+        return dygraph_flops(
+            model, inputs, dtypes=dtypes, only_conv=only_conv, detail=detail)
+
+
+def _static_flops(program, only_conv=True, detail=False):
     """Get FLOPs of target graph.
 
     Args:
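
The new flops() entry point dispatches on the model type: a paddle.static.Program is handled by _static_flops(), while a paddle.nn.Layer is handled by dygraph_flops(). A minimal usage sketch for the dygraph path (not part of the patch; it mirrors the updated tests/dygraph/test_flops.py and assumes paddle with its vision models is installed):

    import paddle
    from paddle.vision.models import mobilenet_v1
    from paddleslim import flops

    net = mobilenet_v1(pretrained=False)
    # A single dummy input shape (NCHW); the dtype defaults to 'float32'.
    total = flops(net, (1, 3, 32, 32), only_conv=False)
    print("FLOPs:", total)
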
diff --git a/paddleslim/dygraph/__init__.py b/paddleslim/dygraph/__init__.py
index 303f7105..8104f2d1 100644
--- a/paddleslim/dygraph/__init__.py
+++ b/paddleslim/dygraph/__init__.py
@@ -1,24 +1,5 @@
-from . import var_group
-from .var_group import *
-from . import l1norm_pruner
-from .l1norm_pruner import *
-from . import pruner
-from .pruner import *
-from . import filter_pruner
-from .filter_pruner import *
-from . import l2norm_pruner
-from .l2norm_pruner import *
-from . import fpgm_pruner
-from .fpgm_pruner import *
-
 __all__ = []
-
-__all__ += var_group.__all__
-__all__ += l1norm_pruner.__all__
-__all__ += l2norm_pruner.__all__
-__all__ += fpgm_pruner.__all__
-__all__ += pruner.__all__
-__all__ += filter_pruner.__all__
-
 from .quant import *
 __all__ += quant.__all__
+from .prune import *
+__all__ += prune.__all__
diff --git a/paddleslim/dygraph/prune/__init__.py b/paddleslim/dygraph/prune/__init__.py
new file mode 100644
index 00000000..fa2d29e9
--- /dev/null
+++ b/paddleslim/dygraph/prune/__init__.py
@@ -0,0 +1,21 @@
+from . import var_group
+from .var_group import *
+from . import l1norm_pruner
+from .l1norm_pruner import *
+from . import pruner
+from .pruner import *
+from . import filter_pruner
+from .filter_pruner import *
+from . import l2norm_pruner
+from .l2norm_pruner import *
+from . import fpgm_pruner
+from .fpgm_pruner import *
+
+__all__ = []
+
+__all__ += var_group.__all__
+__all__ += l1norm_pruner.__all__
+__all__ += l2norm_pruner.__all__
+__all__ += fpgm_pruner.__all__
+__all__ += pruner.__all__
+__all__ += filter_pruner.__all__
diff --git a/paddleslim/dygraph/filter_pruner.py b/paddleslim/dygraph/prune/filter_pruner.py
similarity index 99%
rename from paddleslim/dygraph/filter_pruner.py
rename to paddleslim/dygraph/prune/filter_pruner.py
index a9c20958..6c97bbb3 100644
--- a/paddleslim/dygraph/filter_pruner.py
+++ b/paddleslim/dygraph/prune/filter_pruner.py
@@ -4,11 +4,11 @@ import numpy as np
 import pickle
 import copy
 import paddle
-from ..common import get_logger
+from paddleslim.common import get_logger
 from .var_group import *
 from .pruning_plan import *
 from .pruner import Pruner
-from ..analysis import dygraph_flops as flops
+from paddleslim.analysis import dygraph_flops as flops
 from .var_group import VarGroup
 
 __all__ = ['Status', 'FilterPruner']
diff --git a/paddleslim/dygraph/fpgm_pruner.py b/paddleslim/dygraph/prune/fpgm_pruner.py
similarity index 97%
rename from paddleslim/dygraph/fpgm_pruner.py
rename to paddleslim/dygraph/prune/fpgm_pruner.py
index 1bff3424..cb825a05 100644
--- a/paddleslim/dygraph/fpgm_pruner.py
+++ b/paddleslim/dygraph/prune/fpgm_pruner.py
@@ -1,7 +1,7 @@
 import logging
 import numpy as np
 import paddle
-from ..common import get_logger
+from paddleslim.common import get_logger
 from .var_group import *
 from .pruning_plan import *
 from .filter_pruner import FilterPruner
diff --git a/paddleslim/dygraph/l1norm_pruner.py b/paddleslim/dygraph/prune/l1norm_pruner.py
similarity index 96%
rename from paddleslim/dygraph/l1norm_pruner.py
rename to paddleslim/dygraph/prune/l1norm_pruner.py
index 9fb2bbb8..358d5fcf 100644
--- a/paddleslim/dygraph/l1norm_pruner.py
+++ b/paddleslim/dygraph/prune/l1norm_pruner.py
@@ -1,7 +1,7 @@
 import logging
 import numpy as np
 import paddle
-from ..common import get_logger
+from paddleslim.common import get_logger
 from .var_group import *
 from .pruning_plan import *
 from .filter_pruner import FilterPruner
diff --git a/paddleslim/dygraph/l2norm_pruner.py b/paddleslim/dygraph/prune/l2norm_pruner.py
similarity index 96%
rename from paddleslim/dygraph/l2norm_pruner.py
rename to paddleslim/dygraph/prune/l2norm_pruner.py
index bffdf3a2..72453923 100644
--- a/paddleslim/dygraph/l2norm_pruner.py
+++ b/paddleslim/dygraph/prune/l2norm_pruner.py
@@ -1,7 +1,7 @@
 import logging
 import numpy as np
 import paddle
-from ..common import get_logger
+from paddleslim.common import get_logger
 from .var_group import *
 from .pruning_plan import *
 from .filter_pruner import FilterPruner
diff --git a/paddleslim/dygraph/pruner.py b/paddleslim/dygraph/prune/pruner.py
similarity index 97%
rename from paddleslim/dygraph/pruner.py
rename to paddleslim/dygraph/prune/pruner.py
index fe107e1d..3d5bfe20 100644
--- a/paddleslim/dygraph/pruner.py
+++ b/paddleslim/dygraph/prune/pruner.py
@@ -3,7 +3,7 @@ import pickle
 import numpy as np
 import logging
 from .pruning_plan import PruningPlan
-from ..common import get_logger
+from paddleslim.common import get_logger
 
 __all__ = ["Pruner"]
 
diff --git a/paddleslim/dygraph/pruning_plan.py b/paddleslim/dygraph/prune/pruning_plan.py
similarity index 99%
rename from paddleslim/dygraph/pruning_plan.py
rename to paddleslim/dygraph/prune/pruning_plan.py
index 185d0194..9aa40e76 100644
--- a/paddleslim/dygraph/pruning_plan.py
+++ b/paddleslim/dygraph/prune/pruning_plan.py
@@ -2,7 +2,7 @@ import paddle
 import collections
 import numpy as np
 import logging
-from ..common import get_logger
+from paddleslim.common import get_logger
 from paddle.fluid import core
 _logger = get_logger(__name__, level=logging.INFO)
 
diff --git a/paddleslim/dygraph/var_group.py b/paddleslim/dygraph/prune/var_group.py
similarity index 93%
rename from paddleslim/dygraph/var_group.py
rename to paddleslim/dygraph/prune/var_group.py
index 894de662..1f9a01ee 100644
--- a/paddleslim/dygraph/var_group.py
+++ b/paddleslim/dygraph/prune/var_group.py
@@ -2,9 +2,9 @@ import numpy as np
 import logging
 import paddle
 from paddle.fluid.dygraph import TracedLayer
-from ..core import GraphWrapper, dygraph2program
-from ..prune import collect_convs
-from ..common import get_logger
+from paddleslim.core import GraphWrapper, dygraph2program
+from paddleslim.prune import collect_convs
+from paddleslim.common import get_logger
 
 __all__ = ["VarGroup"]
 
diff --git a/tests/dygraph/test_flops.py b/tests/dygraph/test_flops.py
index 01ffc451..699d9526 100644
--- a/tests/dygraph/test_flops.py
+++ b/tests/dygraph/test_flops.py
@@ -3,7 +3,7 @@ sys.path.append("../../")
 import unittest
 import numpy as np
 import paddle
-from paddleslim.analysis import dygraph_flops as flops
+from paddleslim import flops
 from paddle.vision.models import mobilenet_v1, resnet50
 from paddle.nn import Conv2D, Layer
 
@@ -16,7 +16,7 @@ class TestFlops(unittest.TestCase):
 
     def runTest(self):
         net = self._net(pretrained=False)
-        FLOPs = flops(net, (1, 3, 32, 32))
+        FLOPs = flops(net, (1, 3, 32, 32), only_conv=False)
         self.assertTrue(FLOPs == self._gt)
 
 
@@ -54,7 +54,7 @@ class TestFLOPsCase1(unittest.TestCase):
             "y": paddle.to_tensor(y),
             "z": "test"
         }
-        FLOPs = flops(net, [inputs])
+        FLOPs = flops(net, [inputs], only_conv=False)
         self.assertTrue(FLOPs == 59184)
 
 
@@ -67,9 +67,10 @@ class TestFLOPsCase2(unittest.TestCase):
         y = np.random.uniform(-1, 1, y_shape).astype('float32')
 
         inputs = [paddle.to_tensor(x), paddle.to_tensor(y)]
-        FLOPs1 = flops(net, inputs)
+        FLOPs1 = flops(net, inputs, only_conv=False)
         shapes = [x_shape, y_shape]
-        FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"])
+        FLOPs2 = flops(
+            net, shapes, dtypes=["float32", "float32"], only_conv=False)
         self.assertTrue(FLOPs1 == FLOPs2)
 
 
diff --git a/tests/dygraph/test_prune.py b/tests/dygraph/test_prune.py
index 64a5b788..6f562751 100644
--- a/tests/dygraph/test_prune.py
+++ b/tests/dygraph/test_prune.py
@@ -16,7 +16,7 @@ sys.path.append("../../")
 import unittest
 import paddle
 import paddle.fluid as fluid
-from paddleslim.dygraph import L1NormFilterPruner
+from paddleslim import L1NormFilterPruner
 from paddle.vision.models import mobilenet_v1, resnet50
 from paddleslim.prune import Pruner
 
diff --git a/tests/test_dygraph_pruning_plan.py b/tests/test_dygraph_pruning_plan.py
index 88c4d59a..fda40b7d 100644
--- a/tests/test_dygraph_pruning_plan.py
+++ b/tests/test_dygraph_pruning_plan.py
@@ -2,7 +2,7 @@ import sys
 sys.path.append("../")
 import unittest
 import numpy as np
-from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
+from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask
 
 
 class TestPruningPlan(unittest.TestCase):
diff --git a/tests/test_flops.py b/tests/test_flops.py
index f9e4b189..b3eaf8ba 100644
--- a/tests/test_flops.py
+++ b/tests/test_flops.py
@@ -15,7 +15,7 @@ import sys
 sys.path.append("../")
 import unittest
 import paddle.fluid as fluid
-from paddleslim.analysis import flops
+from paddleslim import flops
 from layers import conv_bn_layer
 from static_case import StaticCase
 
-- 
GitLab