Unverified commit 2b780c77, authored by whs, committed by GitHub

Move pruning and quant API from paddleslim.dygraph to paddleslim (#633) (#642)

Parent: a86edf5a
......@@ -24,3 +24,8 @@ from paddleslim import dygraph
__all__ = [
'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon', 'dygraph'
]
from paddleslim.dygraph import *
__all__ += dygraph.__all__
from paddleslim.analysis import *
__all__ += analysis.__all__
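These two star-imports are what surface the dygraph and analysis entry points at the package root. A short sketch of the resulting import surface (the symbol names are taken from the updated tests further down, not newly introduced here):

# New after this commit: dygraph pruning and analysis APIs at the package root.
from paddleslim import flops, L1NormFilterPruner

# The previous path should keep resolving, because paddleslim.dygraph
# re-exports its prune and quant subpackages (see the __init__.py changes below).
from paddleslim.dygraph import L1NormFilterPruner as _SamePruner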
......@@ -18,7 +18,34 @@ from ..core import GraphWrapper, dygraph2program
__all__ = ["flops", "dygraph_flops"]
def flops(program, only_conv=True, detail=False):
def flops(model, inputs=None, dtypes=None, only_conv=True, detail=False):
"""
Compute the FLOPs of a paddle.nn.Layer or a paddle.static.Program.
Args:
model(paddle.nn.Layer|paddle.static.Program): The target model.
inputs(list): Only used when 'model' is an instance of 'paddle.nn.Layer'. The dummy inputs used for 'model.forward'. It can be:
1. list<int>|tuple<int>: 'model.forward' accepts
only one variable as argument and the shape of
the variable is 'inputs'.
2. list<list<int>>: 'model.forward' accepts multiple
variables as arguments and the shapes of the variables are given by 'inputs'.
3. others: 'inputs' will be used as the argument list by calling
'model.forward(*inputs)'.
dtypes(str|list<str>): Only used when 'inputs' is a shape or a list of shapes; it gives
the data type of each input. None means all the inputs are 'float32'.
Default: None.
only_conv(bool): If True, only count the mul-adds of convolution and FC layers.
Default: True.
detail(bool): Whether to also return the detail of each convolution layer. Default: False.
"""
if isinstance(model, paddle.static.Program):
return _static_flops(model, only_conv=only_conv, detail=detail)
elif isinstance(model, paddle.nn.Layer):
return dygraph_flops(
model, inputs, dtypes=dtypes, only_conv=only_conv, detail=detail)
def _static_flops(program, only_conv=True, detail=False):
"""Get FLOPs of target graph.
Args:
......
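With this change `flops` becomes the single entry point: a `paddle.static.Program` is forwarded to `_static_flops`, while a `paddle.nn.Layer` is measured through `dygraph_flops`. A minimal sketch of the dygraph path, mirroring the shape used in the updated tests (the model choice is illustrative):

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import flops

# Pass a Layer plus a dummy input shape; dtypes defaults to float32 for every input.
net = mobilenet_v1(pretrained=False)
print(flops(net, (1, 3, 32, 32), only_conv=False))

The multi-input form used in the tests works the same way: pass one shape per `forward` argument together with a matching list of dtypes, e.g. `flops(net, [x_shape, y_shape], dtypes=["float32", "float32"], only_conv=False)`.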
from . import var_group
from .var_group import *
from . import l1norm_pruner
from .l1norm_pruner import *
from . import pruner
from .pruner import *
from . import filter_pruner
from .filter_pruner import *
from . import l2norm_pruner
from .l2norm_pruner import *
from . import fpgm_pruner
from .fpgm_pruner import *
__all__ = []
__all__ += var_group.__all__
__all__ += l1norm_pruner.__all__
__all__ += l2norm_pruner.__all__
__all__ += fpgm_pruner.__all__
__all__ += pruner.__all__
__all__ += filter_pruner.__all__
from .quant import *
__all__ += quant.__all__
from .prune import *
__all__ += prune.__all__
from . import var_group
from .var_group import *
from . import l1norm_pruner
from .l1norm_pruner import *
from . import pruner
from .pruner import *
from . import filter_pruner
from .filter_pruner import *
from . import l2norm_pruner
from .l2norm_pruner import *
from . import fpgm_pruner
from .fpgm_pruner import *
__all__ = []
__all__ += var_group.__all__
__all__ += l1norm_pruner.__all__
__all__ += l2norm_pruner.__all__
__all__ += fpgm_pruner.__all__
__all__ += pruner.__all__
__all__ += filter_pruner.__all__
......@@ -4,11 +4,11 @@ import numpy as np
import pickle
import copy
import paddle
from ..common import get_logger
from paddleslim.common import get_logger
from .var_group import *
from .pruning_plan import *
from .pruner import Pruner
from ..analysis import dygraph_flops as flops
from paddleslim.analysis import dygraph_flops as flops
from .var_group import VarGroup
__all__ = ['Status', 'FilterPruner']
......
import logging
import numpy as np
import paddle
from ..common import get_logger
from paddleslim.common import get_logger
from .var_group import *
from .pruning_plan import *
from .filter_pruner import FilterPruner
......
import logging
import numpy as np
import paddle
from ..common import get_logger
from paddleslim.common import get_logger
from .var_group import *
from .pruning_plan import *
from .filter_pruner import FilterPruner
......
import logging
import numpy as np
import paddle
from ..common import get_logger
from paddleslim.common import get_logger
from .var_group import *
from .pruning_plan import *
from .filter_pruner import FilterPruner
......
......@@ -3,7 +3,7 @@ import pickle
import numpy as np
import logging
from .pruning_plan import PruningPlan
from ..common import get_logger
from paddleslim.common import get_logger
__all__ = ["Pruner"]
......
......@@ -2,7 +2,7 @@ import paddle
import collections
import numpy as np
import logging
from ..common import get_logger
from paddleslim.common import get_logger
from paddle.fluid import core
_logger = get_logger(__name__, level=logging.INFO)
......
......@@ -2,9 +2,9 @@ import numpy as np
import logging
import paddle
from paddle.fluid.dygraph import TracedLayer
from ..core import GraphWrapper, dygraph2program
from ..prune import collect_convs
from ..common import get_logger
from paddleslim.core import GraphWrapper, dygraph2program
from paddleslim.prune import collect_convs
from paddleslim.common import get_logger
__all__ = ["VarGroup"]
......
......@@ -3,7 +3,7 @@ sys.path.append("../../")
import unittest
import numpy as np
import paddle
from paddleslim.analysis import dygraph_flops as flops
from paddleslim import flops
from paddle.vision.models import mobilenet_v1, resnet50
from paddle.nn import Conv2D, Layer
......@@ -16,7 +16,7 @@ class TestFlops(unittest.TestCase):
def runTest(self):
net = self._net(pretrained=False)
FLOPs = flops(net, (1, 3, 32, 32))
FLOPs = flops(net, (1, 3, 32, 32), only_conv=False)
self.assertTrue(FLOPs == self._gt)
......@@ -54,7 +54,7 @@ class TestFLOPsCase1(unittest.TestCase):
"y": paddle.to_tensor(y),
"z": "test"
}
FLOPs = flops(net, [inputs])
FLOPs = flops(net, [inputs], only_conv=False)
self.assertTrue(FLOPs == 59184)
......@@ -67,9 +67,10 @@ class TestFLOPsCase2(unittest.TestCase):
y = np.random.uniform(-1, 1, y_shape).astype('float32')
inputs = [paddle.to_tensor(x), paddle.to_tensor(y)]
FLOPs1 = flops(net, inputs)
FLOPs1 = flops(net, inputs, only_conv=False)
shapes = [x_shape, y_shape]
FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"])
FLOPs2 = flops(
net, shapes, dtypes=["float32", "float32"], only_conv=False)
self.assertTrue(FLOPs1 == FLOPs2)
......
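The tests add `only_conv=False` because the unified `flops` defaults to `only_conv=True`, which counts only convolution and FC mul-adds; keeping the default would otherwise change the value compared against the hard-coded ground truths. A small, self-contained illustration of the two modes (the printed numbers are model-dependent and not asserted here):

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import flops

net = mobilenet_v1(pretrained=False)
conv_fc_only = flops(net, (1, 3, 32, 32))              # default: only_conv=True
all_ops = flops(net, (1, 3, 32, 32), only_conv=False)  # count every operator
print(conv_fc_only, all_ops)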
......@@ -16,7 +16,7 @@ sys.path.append("../../")
import unittest
import paddle
import paddle.fluid as fluid
from paddleslim.dygraph import L1NormFilterPruner
from paddleslim import L1NormFilterPruner
from paddle.vision.models import mobilenet_v1, resnet50
from paddleslim.prune import Pruner
......
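For context, a minimal sketch of driving the relocated pruner from its new import path. The `prune_vars` call and the variable name are illustrative and assume the usual FilterPruner interface from PaddleSlim's dygraph pruning docs; they are not part of this commit:

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import L1NormFilterPruner

net = mobilenet_v1(pretrained=False)
# The dummy input shape is used to trace the model and build variable groups.
pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
# Prune 30% of the output channels (axis 0) of the named conv weight;
# the variable name is hypothetical and depends on the model.
pruner.prune_vars({"conv2d_0.w_0": 0.3}, axis=0)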
......@@ -2,7 +2,7 @@ import sys
sys.path.append("../")
import unittest
import numpy as np
from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase):
......
......@@ -15,7 +15,7 @@ import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.analysis import flops
from paddleslim import flops
from layers import conv_bn_layer
from static_case import StaticCase
......