未验证 提交 5f323e61 编写于 作者: Y yukavio 提交者: GitHub

add flops api unit test (#29228) (#29278)

上级 32c139d3
......@@ -273,6 +273,7 @@ from . import onnx
from .hapi import Model
from .hapi import callbacks
from .hapi import summary
from .hapi import flops
import paddle.text
import paddle.vision
......
......@@ -19,7 +19,9 @@ from . import model_summary
from . import model
from .model import *
from .model_summary import summary
from .dynamic_flops import flops
logger.setup_logger()
__all__ = ['callbacks'] + model.__all__ + ['summary']
__all__ = model.__all__ + ['flops']
......@@ -16,7 +16,7 @@ import paddle
import warnings
import paddle.nn as nn
import numpy as np
from .static_flops import static_flops, _verify_dependent_package
from .static_flops import static_flops
__all__ = ['flops']
......@@ -264,7 +264,13 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
model.train()
for handler in handler_collection:
handler.remove()
_verify_dependent_package()
try:
from prettytable import PrettyTable
except ImportError:
raise ImportError(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
table = PrettyTable(
["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"])
......
......@@ -166,22 +166,15 @@ def count_element_op(op):
return total_ops
def _verify_dependent_package():
"""
Verify whether `prettytable` is installed.
"""
def _graph_flops(graph, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
try:
from prettytable import PrettyTable
except ImportError:
raise ImportError(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
def _graph_flops(graph, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
_verify_dependent_package()
table = PrettyTable(["OP Type", 'Param name', "Flops"])
for op in graph.ops():
param_name = ''
......
......@@ -33,6 +33,8 @@ from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.metric import Accuracy
from paddle.vision.datasets import MNIST
from paddle.vision.models import LeNet
import paddle.vision.models as models
import paddle.fluid.dygraph.jit as jit
from paddle.io import DistributedBatchSampler, Dataset
from paddle.hapi.model import prepare_distributed_context
from paddle.fluid.dygraph.jit import declarative
......@@ -564,6 +566,24 @@ class TestModelFunction(unittest.TestCase):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, 1, 2))
def test_static_flops(self):
paddle.disable_static()
net = models.__dict__['mobilenet_v2'](pretrained=False)
inputs = paddle.randn([1, 3, 224, 224])
static_program = jit._trace(net, inputs=[inputs])[1]
paddle.flops(static_program, [1, 3, 224, 224], print_detail=True)
def test_dynamic_flops(self):
net = models.__dict__['mobilenet_v2'](pretrained=False)
def customize_dropout(m, x, y):
m.total_ops += 0
paddle.flops(
net, [1, 3, 224, 224],
custom_ops={paddle.nn.Dropout: customize_dropout},
print_detail=True)
def test_export_deploy_model(self):
self.set_seed()
np.random.seed(201)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册