From c12b6faf537727e0413a43c860e1e2b18b3682b4 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 6 Nov 2019 20:42:16 +0800 Subject: [PATCH] Add flops function --- .gitignore | 2 + paddleslim/analysis/__init__.py | 4 ++ paddleslim/analysis/flops.py | 68 ++++++++++++++++++++++++++++++ paddleslim/core/__init__.pyc | Bin 219 -> 0 bytes paddleslim/core/graph_wrapper.pyc | Bin 14549 -> 0 bytes tests/test_flops.py | 40 ++++++++++++++++++ 6 files changed, 114 insertions(+) create mode 100644 paddleslim/analysis/flops.py delete mode 100644 paddleslim/core/__init__.pyc delete mode 100644 paddleslim/core/graph_wrapper.pyc create mode 100644 tests/test_flops.py diff --git a/.gitignore b/.gitignore index 2ea48a8b..c59240cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.egg-info build/ ./dist/ +*.pyc +dist/ diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py index 9d053150..b9fcfbfe 100644 --- a/paddleslim/analysis/__init__.py +++ b/paddleslim/analysis/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import flops as flops_module +from flops import * +__all__ = [] +__all__ += flops_module.__all__ diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py new file mode 100644 index 00000000..583c8e6e --- /dev/null +++ b/paddleslim/analysis/flops.py @@ -0,0 +1,68 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from ..core import GraphWrapper + +__all__ = ["flops"] + + +def flops(program): + """ + Get FLOPS of target graph. + Args: + program(Program): The program used to calculate FLOPS. + """ + graph = GraphWrapper(program) + return _graph_flops(graph) + + +def _graph_flops(graph, only_conv=False): + assert isinstance(graph, GraphWrapper) + flops = 0 + for op in graph.ops(): + if op.type() in ['conv2d', 'depthwise_conv2d']: + filter_shape = op.inputs("Filter")[0].shape() + input_shape = op.inputs("Input")[0].shape() + output_shape = op.outputs("Output")[0].shape() + c_out, c_in, k_h, k_w = filter_shape + _, _, h_out, w_out = output_shape + groups = op.attr("groups") + kernel_ops = k_h * k_w * (c_in / groups) + if len(op.inputs("Bias")) > 0: + with_bias = 1 + else: + with_bias = 0 + flops += 2 * h_out * w_out * c_out * (kernel_ops + with_bias) + elif op.type() == 'pool2d' and not only_conv: + input_shape = op.inputs("X")[0].shape() + output_shape = op.outputs("Out")[0].shape() + _, c_out, h_out, w_out = output_shape + k_size = op.attr("ksize") + flops += h_out * w_out * c_out * (k_size[0]**2) + + elif op.type() == 'mul' and not only_conv: + x_shape = list(op.inputs("X")[0].shape()) + y_shape = op.inputs("Y")[0].shape() + if x_shape[0] == -1: + x_shape[0] = 1 + flops += 2 * x_shape[0] * x_shape[1] * y_shape[1] + + elif op.type() in ['relu', 'sigmoid', 'batch_norm'] and not only_conv: + input_shape = list(op.inputs("X")[0].shape()) + if input_shape[0] == -1: + input_shape[0] = 1 + flops += np.product(input_shape) + + return flops diff --git a/paddleslim/core/__init__.pyc b/paddleslim/core/__init__.pyc deleted file mode 100644 index e532c8d43acacca16231998b9717c3d81d7b94d8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 219 zcmYL@Q3}E^42F|VPzFPf;G4yV?E)g6zz6kdAYrUUEA3X+ig-5PJc?KF0;VV!k}t`h z{~_7C$Y0OLv4vj6Jo(Q4UIq3 zjp_C9mD6pxm07ZI!7d?=8AOe>L^26~ggLKeS@}jOtAjPJYJ2DUt|Bs>iA1Hp#|fKO Ru2)y9*HfNDdS-?>egJd{CuINt diff --git a/paddleslim/core/graph_wrapper.pyc b/paddleslim/core/graph_wrapper.pyc deleted file mode 100644 index 591351096db4cc217efe7320165cd2f13ca1ddb2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14549 zcmds8O>`Z}Rjz(ddeYO6B}=wv95R_6&y4+MU^F2^aAM44Y-`5jz{sXOiDv{((|X;K z+UnQ++||#L5{Z+81Du>>6|w*ptXZ(m1`dlXSg>ae>|h6nB@6hzTmAQ>XV2)_;T*8s zrRwUg>bm#Vz4yDfO3nYApIQBfzq#L4`M)-PzlTfzD+(X~Efp!XUe+zO-YD9eYQ0(1 zC)E0c)Eg?ARK1p3Zz+rYn{q#;)~C$tBudUYGsyCz7XH;)it2l*$xk z#a`3Ox!d8KyG<0ub%S4c4@IWbfu%A_J#Hv{R%H$Khqy&!LuE~Ad^2yvYX?fX6MV!9 zzNsKf&UdXt`h}~Z4r85paW`VAQ6x+)Si@N z^4E2Phpzh;YMNhB6J3s1TtX^!pd{7EQNNA0a(*|141@(XCuS#`Qini9rAATVd;z^> zJvgLboVm*o@1XOEX8q++v)0>Qe`wkpUKEU)wk$At`x)o{5~lI>+T4qU#oxm57f>lV z{s3A8t+QZ+X%hL89wOv!&i-SB(4kqPuM^2Cp(pX<92c!Z= zq0X+BR{z~A_BCnZN-w`r4Iv)LIg43|VM_X1ycTBVny+c&1K?g_--v`t;MnXzrX~Gx z7%zi{F{sh+V-QWHRjrnF&RVk0SfbPv0k0dl^z#ngfhs)R6-Wu?aAfKpHiGK@{LDv1qYa<{(`4 zZqgU6tcjfxcV0wM@%IF*rPy3tn%^N%csfjY43jP}0hM16OrRiz*3jRp0K(V{Th=5v zJ7mjgp(b$MVEY>=V9QHawR-i;yCXN7D zDoku;9*l+Mw*@#$N7qZYT&kBAAv^S-6$?cDFtofemavH0Y4|QkBQ`BaBS}qMrSW74RX3i-Sc(Z16r7t-fxR@L%LV0E(6JP39cs~d zGI)Vh_XPqQn@Y1-$+l)a=_=)wn$6@CO$Z=hWNcgU8pinvuIEqwZ(!8VM}BTcpGZF) zw|VCmz<`Sw{Lg9(AY+T`2HS^%JVF{sx=|1EevY?FLW>~pkrxdlTJ^FF(jFkE0AZIm zm$XFTD>G%sYw6aJcxETq@8~Rj|T5oGc*~DG9L6OUFlkZhPOfU;(C1%Hb6EB_^hfkwVG5BAyPL)&B$8$H8 zQ8Bq2v6=J<`rlEMoGa{pNfbG47E}4oaqjYkr?9u|IWow30;{jsAJ>^{w2-h0g1%>d%Hkasv?CG&In0SBl6@6(zwjr`wuY5 zGkMB={VXFZA_V;ig{+NR8C||-=V?!3Fq0Ww*^CCEzp}}s_->N!tlUa(GSh0)+1na9 zZd><*Zs0u)%wU=!Y-U&OY6N-QB@gLl_*<;rKcj-64z3(-9rn&7dMwyKFmq>G zr5z?YA1f70_<)y?KMn;1xUHfxrS{$tSTLy`O!MW{`xn$963!FqL0hGFSs`j@Nz077 z`tYv8|9kiZ2b`s?%|G{3%8}n{Nb(68b2E*m6QvSMcmo}lJQ zk{Jk_sUTb4G*CORCFNb|JqJ7bk7fpjn5zW&vh*!K*bL(Uwj=iK?Lp6r-%0~)N+7Q| z!dr|Q%y$jN;BSZX0Wy*%E5Y1>LtWk%ZRTdIqqjFL=Etdap(N7a#s@Y}hOURtGk*I82FW`C+ zKMVNj(jn7q5T)_6dJP4Xm6_vfqTUkotloczn!EQVHPl#>)%#zh(!!=72LU!8pf1J+ z8f!hgsO|wv*D}ZqKuQ&>95>J`$g^sc{aj_ipiBWIv-$U$7Ug#M;uvc zo_TPKPhaH&Tqib)Hv0SrE~9A^E$G=IHI0m;uxN&5HNAlf7zcZ17Sv;cpR+WIP&+I$ zW9XPKSS<1doKVS@x_v8@uWo^wc-awC{~l{oD4xgipyuIo`oHkoEQ*OzQ}{`tjhh}! z_=~0(Q(h*K4PCkV{2E>yqKN3CpoUReSsxYqlO`ci^mr>uU^HAYmq*gZW}g~wpIDr< zEEZ&Ul+;p;q@YeJBnRh&K=^?{u9L&JE#5wE783nFdYAOyh=3t%&yIs#1MvXDlb+oj zq$y6O97TO2Jyu7+e(tYwn_7zzYGc>5L)b{?Yn zQ^Tc^#E4U>yLPkMhYsrlx;`z-I5)b5AQx*fp2P2Pntr z40p{hiUkGkkuZgzr^G$w5tX0g8$)a;wru-p@F+|M)!vKQjq%bg?CTcBwBpKfCyg!a za=)C*=}6D2>&VqO&VIWMTAbHeyu#uFiW+o%m2EZ(IV{r`s08Q>D>2Th%b|LJ1h0O9 zLge==Pu|KQ7~>9SCrEj zg#4i_FFsM0lw&E&KEErF4y);L?J!)fuSPkcchpxXCqfXsEtVZ{938Oe7^)ySkt=6M zJ?yA^pvB}En1Ps3Ix3FSQDJEyG~M!$gz5H55GoSkr|6GlWuqXj0Ugn!5Bb8_(3wXaVMYvgb~pCH_^*=h#E)v?tVJ#2i4|xOnum8x znEV5C7y!OiH;OMy>IUu;C>tP1ww}R#xXjCVw+!}#hsXAxB{4DRZE%Zd(+;A5=Q$!f%>fJ!{*1$hvt zuaRYxo)W0|k})}yAI;7UT>8&JE|~=n^5ceaf8fdL{Xb|hG{$no;>86p+wV=N!zpz*tq$Aja7G=p6kBHH|C~CU zm(IwRPOA@Vj&*PfHwV)0b)eC1XgLp{4B6gN1FVa6{wsDTX3~9*T4bfbZHh<&+rc@FF>gt4? z5<*UMGMGs@5gcv?%P_s=#fX4v;M7?_>NqCyN1Nl0VC|u1}s4xC8-`I`if4$(N-m|1#hm zvaxsz6y%>0U8Fsu{m_4pqO?r~4#NghUz}8d!j7a|Mxc#2+IPE2{OF2(8&&LlMBPE; zWnsJ}X-7V?BOI*%S?v6tC=0G0Jj)x%!!8e1nur5u`-A~I?soorR~X^Sm$v0acG+dk z>7pil*a^}Yc9rKeIX`ZAxEtWS@O#G-Q#zcD`*8SrZ zdV$7*CaEv#nAlkmYaYr)oTU*M)>P2!PGR2x`wg2=5*u61(yN~ixCv1kULstqmRj#M zgKOxKW$}sBy)^aA(SGb>_(X{&C(h+e6k;Y*9l5Nqj&zp^ir1FqY2F@R_8E^Q4uo+z z{t?#gaWzN)JGpFKYILc+)c7VYk?3!uviE*nqL<~E9fW6`H~?c)rSPf(fZ)2BMKE>j zkmV?Io~Q8cVW;x0Xt@3bh*FGp%7=ncH(dEm#c^!5}>Cpq~$~BAS!G;8tgR*h$Xs-jpOr(7N~?GIQ|U{SIL@K>5M1L zk9rt6vL&X?`F%X*-10EkM$K%k%Qkvxzn=r2WG4BlHfzUCHpS)rh|Cj|3Q!NDX03H3 z?|zAsq45-_t;=CPNI!UEC-(u|^i%(hH zXF)^j{1FQZ()l5Z%3w?F1Un&Q{LgS{a;`PeYAl^^&&k<=8J-@PW4VBzxpsSQr9IPb z%g(!V2XD)S0p5`Vqf<0~JnZh^&&{aiiP^U6CQ%f~Uj=E|o^xmgGO}y*4!vv|Iy4a` zVxwov*~z^k*ZB$?IPGNP0S1;#u0vtS4&>OJkeuCBb9Vh*26Or%iaE>tEW#^YXe>0F O*6>ez;@sEHeeyq`X4W47 diff --git a/tests/test_flops.py b/tests/test_flops.py new file mode 100644 index 00000000..cd16b861 --- /dev/null +++ b/tests/test_flops.py @@ -0,0 +1,40 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../") +import unittest +import paddle.fluid as fluid +from paddleslim.analysis import flops +from layers import conv_bn_layer + + +class TestPrune(unittest.TestCase): + def test_prune(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 3, 16, 16]) + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3") + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + self.assertTrue(1597440 == flops(main_program)) + + +if __name__ == '__main__': + unittest.main() -- GitLab