未验证 提交 2b780c77 编写于 作者: W whs 提交者: GitHub

Move pruning and quant API from paddleslim.dygraph to paddleslim (#633) (#642)

上级 a86edf5a
...@@ -24,3 +24,8 @@ from paddleslim import dygraph ...@@ -24,3 +24,8 @@ from paddleslim import dygraph
__all__ = [ __all__ = [
'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph' 'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph'
] ]
from paddleslim.dygraph import *
__all__ += dygraph.__all__
from paddleslim.analysis import *
__all__ += analysis.__all__
...@@ -18,7 +18,34 @@ from ..core import GraphWrapper, dygraph2program ...@@ -18,7 +18,34 @@ from ..core import GraphWrapper, dygraph2program
__all__ = ["flops", "dygraph_flops"] __all__ = ["flops", "dygraph_flops"]
def flops(program, only_conv=True, detail=False): def flops(model, inputs=None, dtypes=None, only_conv=True, detail=False):
"""
Compute the FLOPs of nn.Layer of paddle.Program.
Args:
model(paddle.nn.Layer|paddle.static.Program): The target model.
inputs(list): It is only used when model is instance of 'paddle.nn.Layer'. The dummy inputs used for 'model.forward'. It can be:
1. list<int>|tuple<int>: means 'model.forward' accepts
only one variable as argument and the shape of
variable is 'inputs'.
2. list<list<list>>: means 'model.forward' accepts multiple
variables as arguments and the shapes of variables is 'inputs'.
3. others: 'inputs' will be used as argument list by calling
'model.forward(*inputs)'.
dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means
data type of each input. None means all the inputs is 'float32'.
Default: None.
only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
default: True.
detail(bool): Whether to return detail of each convolution layer.
"""
if isinstance(model, paddle.static.Program):
return _static_flops(model, only_conv=only_conv, detail=detail)
elif isinstance(model, paddle.nn.Layer):
return dygraph_flops(
model, inputs, dtypes=dtypes, only_conv=only_conv, detail=detail)
def _static_flops(program, only_conv=True, detail=False):
"""Get FLOPs of target graph. """Get FLOPs of target graph.
Args: Args:
......
from . import var_group
from .var_group import *
from . import l1norm_pruner
from .l1norm_pruner import *
from . import pruner
from .pruner import *
from . import filter_pruner
from .filter_pruner import *
from . import l2norm_pruner
from .l2norm_pruner import *
from . import fpgm_pruner
from .fpgm_pruner import *
__all__ = [] __all__ = []
__all__ += var_group.__all__
__all__ += l1norm_pruner.__all__
__all__ += l2norm_pruner.__all__
__all__ += fpgm_pruner.__all__
__all__ += pruner.__all__
__all__ += filter_pruner.__all__
from .quant import * from .quant import *
__all__ += quant.__all__ __all__ += quant.__all__
from .prune import *
__all__ += prune.__all__
from . import var_group
from .var_group import *
from . import l1norm_pruner
from .l1norm_pruner import *
from . import pruner
from .pruner import *
from . import filter_pruner
from .filter_pruner import *
from . import l2norm_pruner
from .l2norm_pruner import *
from . import fpgm_pruner
from .fpgm_pruner import *
__all__ = []
__all__ += var_group.__all__
__all__ += l1norm_pruner.__all__
__all__ += l2norm_pruner.__all__
__all__ += fpgm_pruner.__all__
__all__ += pruner.__all__
__all__ += filter_pruner.__all__
...@@ -4,11 +4,11 @@ import numpy as np ...@@ -4,11 +4,11 @@ import numpy as np
import pickle import pickle
import copy import copy
import paddle import paddle
from ..common import get_logger from paddleslim.common import get_logger
from .var_group import * from .var_group import *
from .pruning_plan import * from .pruning_plan import *
from .pruner import Pruner from .pruner import Pruner
from ..analysis import dygraph_flops as flops from paddleslim.analysis import dygraph_flops as flops
from .var_group import VarGroup from .var_group import VarGroup
__all__ = ['Status', 'FilterPruner'] __all__ = ['Status', 'FilterPruner']
......
import logging import logging
import numpy as np import numpy as np
import paddle import paddle
from ..common import get_logger from paddleslim.common import get_logger
from .var_group import * from .var_group import *
from .pruning_plan import * from .pruning_plan import *
from .filter_pruner import FilterPruner from .filter_pruner import FilterPruner
......
import logging import logging
import numpy as np import numpy as np
import paddle import paddle
from ..common import get_logger from paddleslim.common import get_logger
from .var_group import * from .var_group import *
from .pruning_plan import * from .pruning_plan import *
from .filter_pruner import FilterPruner from .filter_pruner import FilterPruner
......
import logging import logging
import numpy as np import numpy as np
import paddle import paddle
from ..common import get_logger from paddleslim.common import get_logger
from .var_group import * from .var_group import *
from .pruning_plan import * from .pruning_plan import *
from .filter_pruner import FilterPruner from .filter_pruner import FilterPruner
......
...@@ -3,7 +3,7 @@ import pickle ...@@ -3,7 +3,7 @@ import pickle
import numpy as np import numpy as np
import logging import logging
from .pruning_plan import PruningPlan from .pruning_plan import PruningPlan
from ..common import get_logger from paddleslim.common import get_logger
__all__ = ["Pruner"] __all__ = ["Pruner"]
......
...@@ -2,7 +2,7 @@ import paddle ...@@ -2,7 +2,7 @@ import paddle
import collections import collections
import numpy as np import numpy as np
import logging import logging
from ..common import get_logger from paddleslim.common import get_logger
from paddle.fluid import core from paddle.fluid import core
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
......
...@@ -2,9 +2,9 @@ import numpy as np ...@@ -2,9 +2,9 @@ import numpy as np
import logging import logging
import paddle import paddle
from paddle.fluid.dygraph import TracedLayer from paddle.fluid.dygraph import TracedLayer
from ..core import GraphWrapper, dygraph2program from paddleslim.core import GraphWrapper, dygraph2program
from ..prune import collect_convs from paddleslim.prune import collect_convs
from ..common import get_logger from paddleslim.common import get_logger
__all__ = ["VarGroup"] __all__ = ["VarGroup"]
......
...@@ -3,7 +3,7 @@ sys.path.append("../../") ...@@ -3,7 +3,7 @@ sys.path.append("../../")
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
from paddleslim.analysis import dygraph_flops as flops from paddleslim import flops
from paddle.vision.models import mobilenet_v1, resnet50 from paddle.vision.models import mobilenet_v1, resnet50
from paddle.nn import Conv2D, Layer from paddle.nn import Conv2D, Layer
...@@ -16,7 +16,7 @@ class TestFlops(unittest.TestCase): ...@@ -16,7 +16,7 @@ class TestFlops(unittest.TestCase):
def runTest(self): def runTest(self):
net = self._net(pretrained=False) net = self._net(pretrained=False)
FLOPs = flops(net, (1, 3, 32, 32)) FLOPs = flops(net, (1, 3, 32, 32), only_conv=False)
self.assertTrue(FLOPs == self._gt) self.assertTrue(FLOPs == self._gt)
...@@ -54,7 +54,7 @@ class TestFLOPsCase1(unittest.TestCase): ...@@ -54,7 +54,7 @@ class TestFLOPsCase1(unittest.TestCase):
"y": paddle.to_tensor(y), "y": paddle.to_tensor(y),
"z": "test" "z": "test"
} }
FLOPs = flops(net, [inputs]) FLOPs = flops(net, [inputs], only_conv=False)
self.assertTrue(FLOPs == 59184) self.assertTrue(FLOPs == 59184)
...@@ -67,9 +67,10 @@ class TestFLOPsCase2(unittest.TestCase): ...@@ -67,9 +67,10 @@ class TestFLOPsCase2(unittest.TestCase):
y = np.random.uniform(-1, 1, y_shape).astype('float32') y = np.random.uniform(-1, 1, y_shape).astype('float32')
inputs = [paddle.to_tensor(x), paddle.to_tensor(y)] inputs = [paddle.to_tensor(x), paddle.to_tensor(y)]
FLOPs1 = flops(net, inputs) FLOPs1 = flops(net, inputs, only_conv=False)
shapes = [x_shape, y_shape] shapes = [x_shape, y_shape]
FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"]) FLOPs2 = flops(
net, shapes, dtypes=["float32", "float32"], only_conv=False)
self.assertTrue(FLOPs1 == FLOPs2) self.assertTrue(FLOPs1 == FLOPs2)
......
...@@ -16,7 +16,7 @@ sys.path.append("../../") ...@@ -16,7 +16,7 @@ sys.path.append("../../")
import unittest import unittest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.dygraph import L1NormFilterPruner from paddleslim import L1NormFilterPruner
from paddle.vision.models import mobilenet_v1, resnet50 from paddle.vision.models import mobilenet_v1, resnet50
from paddleslim.prune import Pruner from paddleslim.prune import Pruner
......
...@@ -2,7 +2,7 @@ import sys ...@@ -2,7 +2,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import numpy as np import numpy as np
from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase): class TestPruningPlan(unittest.TestCase):
......
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.analysis import flops from paddleslim import flops
from layers import conv_bn_layer from layers import conv_bn_layer
from static_case import StaticCase from static_case import StaticCase
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册