未验证 提交 a71ea009 编写于 作者: Y yukavio 提交者: GitHub

add unit test (#29228)

上级 46b73e6c
...@@ -273,6 +273,7 @@ from . import onnx ...@@ -273,6 +273,7 @@ from . import onnx
from .hapi import Model from .hapi import Model
from .hapi import callbacks from .hapi import callbacks
from .hapi import summary from .hapi import summary
from .hapi import flops
import paddle.text import paddle.text
import paddle.vision import paddle.vision
......
...@@ -19,7 +19,9 @@ from . import model_summary ...@@ -19,7 +19,9 @@ from . import model_summary
from . import model from . import model
from .model import * from .model import *
from .model_summary import summary from .model_summary import summary
from .dynamic_flops import flops
logger.setup_logger() logger.setup_logger()
__all__ = ['callbacks'] + model.__all__ + ['summary'] __all__ = ['callbacks'] + model.__all__ + ['summary']
__all__ = model.__all__ + ['flops']
...@@ -16,7 +16,7 @@ import paddle ...@@ -16,7 +16,7 @@ import paddle
import warnings import warnings
import paddle.nn as nn import paddle.nn as nn
import numpy as np import numpy as np
from .static_flops import static_flops, _verify_dependent_package from .static_flops import static_flops
__all__ = ['flops'] __all__ = ['flops']
...@@ -264,7 +264,13 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): ...@@ -264,7 +264,13 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
model.train() model.train()
for handler in handler_collection: for handler in handler_collection:
handler.remove() handler.remove()
_verify_dependent_package()
try:
from prettytable import PrettyTable
except ImportError:
raise ImportError(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
table = PrettyTable( table = PrettyTable(
["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"]) ["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"])
......
...@@ -166,22 +166,15 @@ def count_element_op(op): ...@@ -166,22 +166,15 @@ def count_element_op(op):
return total_ops return total_ops
def _verify_dependent_package(): def _graph_flops(graph, detail=False):
""" assert isinstance(graph, GraphWrapper)
Verify whether `prettytable` is installed. flops = 0
"""
try: try:
from prettytable import PrettyTable from prettytable import PrettyTable
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. " "paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
) )
def _graph_flops(graph, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
_verify_dependent_package()
table = PrettyTable(["OP Type", 'Param name', "Flops"]) table = PrettyTable(["OP Type", 'Param name', "Flops"])
for op in graph.ops(): for op in graph.ops():
param_name = '' param_name = ''
......
...@@ -33,6 +33,8 @@ from paddle.nn.layer.loss import CrossEntropyLoss ...@@ -33,6 +33,8 @@ from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.metric import Accuracy from paddle.metric import Accuracy
from paddle.vision.datasets import MNIST from paddle.vision.datasets import MNIST
from paddle.vision.models import LeNet from paddle.vision.models import LeNet
import paddle.vision.models as models
import paddle.fluid.dygraph.jit as jit
from paddle.io import DistributedBatchSampler, Dataset from paddle.io import DistributedBatchSampler, Dataset
from paddle.hapi.model import prepare_distributed_context from paddle.hapi.model import prepare_distributed_context
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
...@@ -564,6 +566,24 @@ class TestModelFunction(unittest.TestCase): ...@@ -564,6 +566,24 @@ class TestModelFunction(unittest.TestCase):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, 1, 2)) paddle.summary(nlp_net, (1, 1, 2))
def test_static_flops(self):
paddle.disable_static()
net = models.__dict__['mobilenet_v2'](pretrained=False)
inputs = paddle.randn([1, 3, 224, 224])
static_program = jit._trace(net, inputs=[inputs])[1]
paddle.flops(static_program, [1, 3, 224, 224], print_detail=True)
def test_dynamic_flops(self):
net = models.__dict__['mobilenet_v2'](pretrained=False)
def customize_dropout(m, x, y):
m.total_ops += 0
paddle.flops(
net, [1, 3, 224, 224],
custom_ops={paddle.nn.Dropout: customize_dropout},
print_detail=True)
def test_export_deploy_model(self): def test_export_deploy_model(self):
self.set_seed() self.set_seed()
np.random.seed(201) np.random.seed(201)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册