From d81fbe8d185bb830761c584b2a32de44a3154a94 Mon Sep 17 00:00:00 2001 From: SPC Date: Mon, 22 Jun 2020 10:32:56 +0800 Subject: [PATCH] Add roialign mlukernel (#103) 1add roi align mlu kernel 2add LITE_BUILD_EXTRA when building test for mlu roialign kernels 3modify lite.cmake to fix compiling error when supporting mlu kernelxx.o exists --- cmake/lite.cmake | 12 +- lite/kernels/mlu/CMakeLists.txt | 6 + .../kernels/mlu/mlu_kernel/roi_align_kernel.o | Bin 0 -> 347944 bytes lite/kernels/mlu/roi_align_compute.cc | 198 ++++++++++++++++++ lite/kernels/mlu/roi_align_compute.h | 49 +++++ lite/kernels/mlu/roi_align_compute_test.cc | 132 ++++++++++++ lite/kernels/mlu/roi_align_kernel.h | 78 +++++++ lite/kernels/x86/roi_align_compute.cc | 2 + 8 files changed, 475 insertions(+), 2 deletions(-) create mode 100644 lite/kernels/mlu/mlu_kernel/roi_align_kernel.o create mode 100644 lite/kernels/mlu/roi_align_compute.cc create mode 100644 lite/kernels/mlu/roi_align_compute.h create mode 100644 lite/kernels/mlu/roi_align_compute_test.cc create mode 100644 lite/kernels/mlu/roi_align_kernel.h diff --git a/cmake/lite.cmake b/cmake/lite.cmake index 9a633409cd..66abeab483 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -413,7 +413,9 @@ function(add_kernel TARGET device level) if ("${device}" STREQUAL "MLU") if (NOT LITE_WITH_MLU) foreach(src ${args_SRCS}) - file(APPEND ${fake_kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + if (NOT (src MATCHES ".*\\.o")) + file(APPEND ${fake_kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endif() endforeach() return() endif() @@ -446,7 +448,13 @@ function(add_kernel TARGET device level) # the source list will collect for paddle_use_kernel.h code generation. foreach(src ${args_SRCS}) - file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + if (LITE_WITH_MLU) + if (NOT (src MATCHES ".*\\.o")) + file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endif() + else() + file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endif() endforeach() lite_cc_library(${TARGET} SRCS ${args_SRCS} diff --git a/lite/kernels/mlu/CMakeLists.txt b/lite/kernels/mlu/CMakeLists.txt index 5557f86c58..eed70d7fa6 100644 --- a/lite/kernels/mlu/CMakeLists.txt +++ b/lite/kernels/mlu/CMakeLists.txt @@ -7,5 +7,11 @@ add_kernel(subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS ${lite_k add_kernel(io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} ${target_wrapper_mlu}) add_kernel(calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS ${lite_kernel_deps}) # depend on transpose function in backend/x86/math/math_function +add_kernel(roi_align_compute_mlu MLU extra SRCS roi_align_compute.cc mlu_kernel/roi_align_kernel.o DEPS ${lite_kernel_deps}) + +if(LITE_BUILD_EXTRA) + lite_cc_test(test_roi_align_compute_mlu SRCS roi_align_compute_test.cc DEPS roi_align_compute_mlu) +endif() + add_kernel(layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} ${math_function} ${target_wrapper_mlu}) add_kernel(cast_compute_mlu MLU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} ${target_wrapper_mlu}) diff --git a/lite/kernels/mlu/mlu_kernel/roi_align_kernel.o b/lite/kernels/mlu/mlu_kernel/roi_align_kernel.o new file mode 100644 index 0000000000000000000000000000000000000000..d43056577e894203a0cfdd0f612478e22b7718e1 GIT binary patch literal 347944 zcmeHwe~cf;b>H0aNS#_LJw$YM6p!MbBqvOvNatF%kGANpE{Z}cjB<{ohH_--D~VK& z7A2n~Dmm9~HIiAEg7wGM)gMI5)?YVjz&38crfNzc4En9Z$}Dg&*osIPfEB(8jFxbP zf{lq*aOtSKJM(sS-p+n^ek}LyE50wmo7vf!d2c@R-kUeGv$HS#{O)^4M@H5$`5R%M zW+QM>2rvB7`ur+pvq?N=>(iXKzp(Ja!RHS?{sg=I$!DMZ;@*8pioAZ|OV6a~+OPME z>yvVfb!BSrC)91A`{mzR-%nHH|GO=}IhMx?5q_OwDoNhBut~cfTiC2!k1u>kyDrpZ z7v(nsiz2_E&)>gZTyDQ0n)d(X^{wI>u#LwrTpL24|Jkd2Tw@fqpBbUD|EPA{f9Q7YWmx&%pJNLgvzfS%xO)YEkb^Y{mst-AN5BbyRv)(VilDAF2 zV|CU0k7oQ-VWa-!w{iRO8@7?}xbZI~pELHC@87bN5lNJw=4Vn6QD`KoXsoupUVZlu z-WMfCv-HuCYeAXyeAZW@9>_iz??iO2TVJ3TO(yoEa1FH>LomAO%_Goi{?#$t^2SOT z)A_H+q>|SJe&e$F9T)sm8*EOz~$+uvaQ z-R5q83;bsC7ovYjU85o8zY{7y_pdSeG4%hI6!57G_4sEJeJ@-`A9mLNtRDEo|Liat zUJUdha~6vGH_J4?`4?1}&c9w+WAiV&_Ww*H?dQpQ#9EtwTD715bV&OrhQ9xKK4yE` zU;d9W4C3@Jj6_+AOXds$}fL=vOAKlAB6TF`?riIJrJk8pyQvWSp0sN zmCrKj<<0$fn`9XK>(ORU`_I>l8QiSj_`e{Wy?5|2Z%ElR*ETK>CM&!HfPTpua1S{_-z+(VqbNy94PT`DHKq z8-f13H+@k<_W>{Z<3RtxK>Eji)r&(cjZD{d0)^BQ4WE zkLY(=rr$^O_qI&``-uKR%k(cG`UhI3|1P4x*fRZ#i2m~})4zo1A8eWadx-v0%k(cJ z`d?|8KCD;3{D-$pAJ%JtKFs9Xn*YLj70^H2a{FPu4(PwuGJROD1p3P@(}(q1p#P1Q z>BD+8(0{#U`mkOP^pCVmAD&kL{i7|@hvzjw|DUx?AD&kM{cg+j;rSiVf1_pkGjhEe z=pSpD{+)>aw_Bz^i|8M3nf_gf{+lh+j}iUfZ<+o*i2jL|>F+}HPqs{dH=^Hbnf^SY zf2w8r4A!~Pzt=MTWkml+ zEz^G;(ZAd>{UeC}U$#vDD5B5gdSF}YUtL6htY!LdAo}Ai(?5pjZ)}=(f@eM^!teZY|He&kLZ80W%?Ho{kvMG|1P5c>6YnV zMD*j9>0d(h=US%!9-@Cw%k(cJ`uDa>pUJZ?;Q7z4mg$cn`uDX=e;m=@-7@`+i2i)b z^d}JgUul{CCPe>%mg!F-`VY2De+tomsAc*QqW`yBrhhA<|En$2-;U_-X_@|Qi2lPZ z)1N`~Ki4w-I}!ayTBbjX=s(&r{kstTPRsOTM1OC~^zT9RpK6)@E<}H!W%|1j{V%jk ze;(04&@%l85&dUdroRW#Uu>EFBZ&U*Ma8`z>m5=C!s-9H7kzmC5Kh1EMIW9&gwub=i#|Mm2&aGEi#|MmaHAg?hn&vOGmZat=O1DEha02+ z7hdgutugwSyy!1CM*shK(SN-$`tNzsKhhZeA9>L~+8F(dUi9+&? z5q&MspbKGeyv8KKm0RU`Ek@gkaNeq+4;sCuH*X_wC`nZ-nC2r-X~12}Xs;a~SuV-H3;qT9I|Y>d)?b?Ae|7wyd_Q}z{+vBRxaW3-|Rn4RM`8;>0{~lYGbiP62fyiMMd$y?#R_?|GxUvw{)KQsPNKlo4b znokOOSGTQWRq|_Hfyh|$pUnJkGRsfGf2jPp_z#h4T<$-tw)|xN8wbX!{6mnJ_&!~J z;`@T+hizVU`N96HS$-zX@{{l%DnFP$^n@vlRATNwueSXBQabu4R|?BSk)s z@z0|9uNl9{j32vkdVApbkLl-}86uHNg9FS$@#|*z@#|nXoY22Vv?-d<=I3wTd~AK@ z@7yF=dGMI;`d3mxG_=nTs^md+P;E^5f8O&KLBMnV^3N;ec`(wF805GUl*F`hhN2b z3YFS~4e38#Kd{n2Rxuy3;T!Tu{@a;<-&$d>&g}>te?#Zrc>f*1V;SGK#t&lfG zz$V6rX8bMTy~^Wni60xdO?f=W-{$lO%fE?9NQ)rCSQnyx<8O2RL#Y4lH~#j8^|-k1*lm6-dnl{5-MX=NB&Pw_5*;{yU7< z@5K0XTSfoQo*AH?3*^Qt1d0xTzRn z&TIeN1aP)H_1}i_`EPUm=WV}5|9P&TyiuWF&@<+6P9=x=Zw%fa^pc;P3qCadTh21P zKkS9SswIi7s{h~I_z`Nq@cdoauajo~B-vk}jsMAhg3$2;)7Sb4T}0{{J|Gji{=J0# zSF`=D()Uf#PipJqia}GXIsLEV$3(RvP1rxr`NK?O`iJqa;5JJx;MawGe)EU5GyY)x zHOn8gRzBb3@II@T{m=b}Z2gD)2L-N6EXaTOZ7=;P$bXpsz~TH;FZsz);zR2{tey1( zZ2zqEpEb9BfbCD-xpfga|H0$8;Qw{1z3cU$d0Qfvr&l7CMlc>T+J{Za5=vHpnp#n;X4pU3NOic(EX z*smja|CK)ee(pZ8A4+6SwBc8@{e|qtdEfV!Mnsj%zv$Kf`S(PE`yaQz1~qkC@ZWvN zivRrkBRaqQbN7YX4@_S%SQ8WcAA8t}|8o`g;@5)t$MkEsR<{NJ$3L&}zj-c`R}R^# z_EXFIyv)BY|6u>2KDfTiJRYcQ^)YEbU7jBtQnlvK5B$zQiyPw~+i%5UIX0(~dJy${ zf3G?IL;Vlzf7!{;T3`Q!-T`^X*Dqb=@9*8ynEtrPUqOG4LXY_A&+70GJZGV>-=5E( z*82QkwbW=2_}{t8^e0(q&3>@oBN~;9uu}_x`zNsdRmhiPh4HfsckSlCcOm|JP<+ju ze*uiuhN4Aq{zq;0QU;;F=J&rj_s=l4 z+}}96b!D?)1Ssa=L<~gpq;iBaHl~+1>T98(xW8-2?f>fVUw{$LZ}qhQ&8Lo@$)(5#PF?fm1P)cJ!~`TV0oUcjo_gdYBnp;}M-U9lg;K%g>7^52a8{-<^F zCfT6*N9JGg+phBY$6D9Fe_<)ln<8yKU8!v58^b$)UhDm*@HWq9j+y-hRm>QRRBG7! zBkZ>;+t(ykX{=K3o6cUEM&H`|IIyr zb=f~qU=yDI*uB5S|C2smB>NZNe<`gkzR&ynXUXH&PYvz)Ma}h}xBU@4V!8f?^e^u} zHtpSi9_w!f{W}4&i1)A9>CfH@djhpugp&VI`-%Ctpe(Y8{#(T3r@8%A<8`ur6S#le z#ePPraXpCQ_M7>K+;TkRr*(UxZac;J(D$g>R>v=-Zwc`7H>CbUl|K&6vIs-*7bGm_ ze~O=`{ZG99r6|?J8=QR0g`Te6v!CHBKiOOOXt2n=3l@=3 zE~YS2iRU01h+oP6K)se>_w!cxaWj5A96wBd5Z{ryc^;yxjNkI(R`?w<Mh~yDHkF^3}$q{qeS+qCDS+@UNu{8KnQaAC{^--*=l==%hwbQBo!^(xFTURmY}tfTelUHTqFElnfB!`dzn#C1 z&hL9#KW}pDf_B~{FXLy2|5#)CgXPz;{bxk`FAFsB_ptb<5gy&&)bQXp9n|o!CZ1%J z{ioyaP=V;b3tO}P`@Xo!`R}fi=Q~B&`A!__?N90aUq2Pfzsvml5U{~m;PclDsJ*IU z>zKrUT9g05r3?IbA?IJ9@^})8uK)g(m;Ea6e*`%3!k_OG(i?Vu5t~Z(zYOR5_kQ>T=2XRChTwmZ@DIeF|5*tB zQuP?E|1)3v|AP?xCxq!E{1YwkT>t-J2>ub_A83ErM)~~H%m17q{6q2oPze4p;U9|s zO(FQt4?q5w69Cg+XENI_CN7-*jtC1rv&Cxv%vlK;e7_Dz^`?C3!~Uz#{Xbs&H(1;F zozwB-*GT}w@eAGmWryE+#D5zJCR`YeP{!cKLp^R;TEZCH(%9-Tt$- z<;V2$Pq>y(KUZ3yzt^dJQdax% zZ@tRLwG_9D-gH=bHj2v+mfy3J%Oi#Fweq+1Yd7YT+`X5b=qL1aeZMrN@#9RN@o&D8 z3WWNZ+Yg=G(fz#=KdPRZ7+){!_v|fb{5R8Awdwa!f2eG=v5Ay^m$9Ro{kk+Yu`cOf z4#oSjyaL*vQ(0_dHvFHF3ZmZrt7y^TD*rWKq+xu0`VZ*I zP?R3w$zA9Pl*9Itf6X>$MLxa%Dl!d7ZnnHx{uOy~|HC2cH+f}zXwUD#?~CP4vX6*v zE3MyKS0KuJo*&cP`(G~W7h2l{|BHD2z)t>}d;iO2{;lyZ%!dE`H?PcFe>)r3fBlah zHH@EeUUwcTC2f$mPP8(9=6uNWw8?%5MDTj7!l^b!_9KOe?1x;4uay@=hTDGP%io5=0EZq_40P(ANe2f{E7UJX6L_H&j(zS|3UK?n!lj+8$0{GcIPjG ze>uHp(ap=(rt7Kizb`%*+TXgcp?J+UXV(LaM9r=7Gn@7CuO@%-_AJ{syLDyrh0fT2Zp`cx(rVkwl}#T_*H3J-oASM7 znm-gASk^RDtW-XBu0ysndSCIdbfDM`~aBtCm zG4j8>-v7xJgtK8ie~|ny^1m`MCjZOu{Ehst+{D#Ju>VwB0A;}C?Z!W{Uu3^BK`PlV zAwT*2j@Hkw)BSsa53QdIGeqm>G=Cf7`CIt<(<#&X6SG)9hgmDFKhgSA(WSHV-)R3c z_-|`%|Fd2H<^EOJ^S^NWb-L1i$pK!o3i&ztq}om7#L{@2ETw ziy;Cbt$zq4@beqM$2P>}?d&H#{{jD-><8J8qORHL5BVPgGx9&kel%-8?DTKVte+a6 z|Jb}gLH>`hAlQE?+G3kHE^jA4x>3KktMXH_rtr`18&wpOl_xI?1 zTYrV^PkjkJ2QF_XKV-kier0kLV!vShW{WP7$wLjCU`VQZ|IY5pW2ruh@ipZw=f z^7%_ie>-LU?cpTwTJZi2(qG*lFylevFO9#MMum*OQYAqXv4+}QXXj_yWUR_#JpZ9n z&fUXFRpUR-SAVUy->h${7IB~a2Z0Fr59B}i`wv2fbp2DBW)ma(qkj*>1~Zk1C-bUN zT7OW0xWt6m(*Cg_-ajU~o9s_f?Ct!YCAomSisQdv+Ef2nWBhxM|6<#|<@wt^xyc{! z?yk$eKbF|;4}6)zIhELj=sOPOp-dKv{y&Vrf1%|WbRo(w zf7X}2sQbu6zU4&^#1A`_j|J5|s$pM%Pm{?cqy5g$pXE2{=g;;-&YzuRo$UNr#Y_?t z^rd()wzE5$*e9N6nU2QOYXN@pS*n}6>G`<6zwJxcR-S)5-^pt&eZ2_ojBv)T$$5oT zyHFIfwBJY3coU6E`#m3Ozg^_7VE@iSBQGHNGs+KSPC{}2FltXBgKboZz>JIh7x+I1 zWc>Jltug*x>_>tBGr+&!__N#?|DQ+w4}6)zIhA-GqJj4NSfKqr{&x#q1)g1Q$hAOT3mO&;{5UW zxfqN*5I2Op^!+#^|Ams`{#}j9f1vy=VfnM;|B=S{50u~2M)@t_trA50&dR=p`G~r2 z9TPRoJcIPdy1dZ8bI+mjX8b!9ct84dxA;9uVU4`Jk|DqND=gfHXg#WB4;wJDnA5&WOQ^EZM2vbf8n ze)GK6X`TN~Q7Hc|>n8>O=>%}(=Rdu!w+||F2_^n_7v;aEe1ZQSWX0cawUhrNjq$JP zFECXR&Z+dH5cOMsfY0}aj~`fmfiF{7;J*v^dk*~vWwKE8K#cLT>%VV%*{=fsM}QYU z`T2K}e%SfBQ=`NB{KyY~z`W4^rSCUc^*{f^5d7uR0b_(e^TVIL6@q^RazOZd;?KVq zg8vNRpCBN{A5Z!JlMwu4!aopy-Veclp70OEpZy>N{|@0Fh(G^f2>y$Ne<1$s$07K0 z!aopyzJ5$#=r#XZCj0~OXMZgO|1RMlh(F&Qg8y;CKM;TRUNS0f1mIV#Gh{u!T%!RABaEuNC^IN>50~#lV0#!f8#qs@Sl(m|7iW+7k~D1A^1mx zf1v&0zZQc34B;P&|K1S%W5Pca|7Sw*pC5kwFDC$|zs{uZ`>P{z|IHS(f5ZIz{fJ!e z#QRNpVtG-i3$xwLzH@$yOAypF-`n!ntZ`G7iw?%pX^H)Uq z_%q0UDFQVy(IZRv{Xw(+Q+0{^ZOad$&+n88y?HnLzaYxVAiF^CzbQ(CV-xA^E@MY- zkfJTKTj9rLX4Z2Yzs5zu0{^?6_?PA9fj(o-<Nm=d3zx65~*HYXrdedR$*(feQSbon=E*IZ-;BV{KZp7(z2|%{m{uB-QO$mW6c0xFYNd1Ey#YG>8pD6`=~!u zw%XW4O25n4QO$l`nwnUb2snr0ec4_x7(Kneo!g(tdTKwQx+IGAUvD4yE5Hw2zbg|^ zf2eGAu^j(xA^1DD-^7*>1k)Xhw6{#(zwDF?_8^IGh($EDW6AG0*X3WNy8PAoH}~I| z93*9b@x1<2pQrMx!I5=`)O~mXfUThZQ(0_dHvGS?&;OG5Gv9wHt3TNPcR4>}PBrNi zM6v%5>9wKgk;7H~YraU6ykVLCL-mQ$BfRwfCCXQS$-h>gXJ8J!{|06?-fXtKS^gE_ z>^?szmCkGbPF@*L&Tkpc^S99ZVR?%jBBI+$&rhst5am73|7q_1FPHTTt!;vT%kv{t zNg##St?&`Oych{|JLdh9QeM4G5M_; z|H>kOQpkQJfLzt~gI$U5$08}SxYX?*`5(HgYQ}^74>*5<{EuemzZMD$S;+sO`3uco z(E5$pUW3ZO`rq1}zX<;2auJJeUQW-i$ho?oaStBTDsasF3n z{*x7$k9EhwSWB}0vq8Gz^1fC7(fozxFG9Lz$p2{8|FF_OIKG-iH)i(9sN404E1N$0 zM~|ZCXZ5AD9LDdr*nEGM)?WmU`1^|*x)v9>yq)~e{E6mInQ%3}erqRxxqlU6zxc|= zFMoGF%zpJu&tEh9C3G*6D%fMOjL86;EByUBJ(tZD)Fkp+;Ya=#`CsIJr3y35_KT7K zMbFRm{ctwph0G)MpPql|^DnFMldawLCkF}81Odv~wV(VizzYuYhHESQ$p0e$i~O%t zWlH{+(4XAC(E54ty!#Ze^qTAk*$=WGM*EQ_y8Q%@P{wNfZfgBJeEsQ^Y5j>^Ozfxr zT2Tz`lZqpkFBzltCqM)catmqwT;FeHjUL&r6snT{V~t;)aj5Or=}P-0+84=Y#w?aG z89=Maev$nuN`~F|P492O^GkYvLtphr zoIi9_cS9_O2!v!m1QK}t6Y#MOad|uaBmV>RkNl5j{SQ0&%b!n&>_27=>_0X>|FL<0 zg6tRBuLJ=)|Ay)FUE6UB?C)Che}>F{(es~|J>^ddD)yhO<+q}|o&1ph1^Q3^m$&}~ z>o;4l%*Hr@u~Jgi`=3_wOV3~E`HOIX>G@0e^OutTcFOwO!%0>157J-VA28!V<1dZB znMO6fe+=vC;Onnnk@FuqWuiTtR5kwNeD&90i9}AbzNuQW`{X~6|3LmjsuW~@xcB@= z8WplX%DN^3WqRMKkU`uh|A+jaOeo0ytcLwzUy%>M!q0#9K7T;i1kdLe@B&vj|K96= z!h^CrzrTz8t?s()`(w$v*uqAAzF07ogQL=aaryqpAywlki1yL?ft~$+1GNWga|k8> zF5kZ^@P8U8`SE|OG5+!Q_kb@`7^%b+h@SDn|J#l6KaJ=EU#4(QC3YeDjzf7UlZB%H z5AO^szx-L>@`9No5BZiCJrF?G^V zZe7_dMMKzqrm?e;NNfjIoTNhpGW-L0CaR>f^ z=PVTR*WKq^Ui83;{iuAVAGy7aZqe;6+ooR!+24~R;{0x{Sw{Wkkfpanx4-eHg6wa! zxb8f%|3E2zE(RkHNrsS@z8`1gzfe-#zpJtNN1*&IVfnM;e{W;_2g>hhqx_ceRtchg zXJy~Qc&6@K$3zV?&mjG=E-&=&+;gbB8UH;Mct84dxA;9uVU4`Jk|DqND z=gfHXg#W<(YnZ+x{};9v`RCVf%g!U@P)mPmq3GZ7>!)@V*Z24X+c5s-{)76!{}`+2 zPnEAWCi@TTZ*4iqAo{N-^yjjl5J zpMAYKsUqC|YWfRIRfKaY{U}8J)*m)ChCh~H;L8*i`0v8~o{o&RBfyK7{O~b3=?*(TcWQK4zd!4TKVV+y|I+uH?D`)DK`;E}(g9uy}aApZQ(5d6D@e<1$sXbAqt3I9O+`EP~b-y{43@n;Jm z_@5#C1M%mJA^7(R|3LiNQV9MR3I9O+IqZn@T7Q;HPqhA=^n%~|8+$cG{wKuJAK~wf zKmSGu{t@9HXn)x6gy26z_=n>EW(fW<;U9|sABNyRKm7P#P5?}Qok`#C*N%uCBr{vk z{tff<_xB`Zc)v_f?9UR+xc!y%UG|?8_WyK&nAiRd);4~Zb^Q2slKpV}Lic}J;wO3w z_Dk_Tz&{nbf6`_DaAE(>NnpT|f5%)e&zMoMTAB!;r9p4_D@wN?zb&Jh(5nlCiLdr?Eiu&Cxh$)z5k{t z4USEux4VoTxdBE~k>3hGF7hVtuQ`ri@}rCe{&zd^FU!vZea4*2r_C+UUvQ3iS6nl~2lQKmM&(`M8$icF~&-E6+x8`N8sgc5=D+ zz5{<-zjk9j$=!R|iGD&)*Y`_P8b8kT8UN-hsX(Zox&6?|9o^q6@uTXgiShNqe$U>5 z?6;Y|s!hL-`a@-_jb%H)yNn&x?AN8KiFJv9b12@I?FEC;)BD@G{fVrn_5-R*qFDd+ z_QCrazzSJX3Lx9UlA7fPf0xeo}WN%JUPE*bSTgNM(>9~gIq$Vk|p6CBG_x_j5 z`i0gu!N2AC5vrt+zvkZma+!ZCE|QqAApY~;yfSb7h4aHp<4>H|okvPZt9pJYBl{tG z2Cv8F{Z$_!`;lTq_CvmiUoSA_99QD|v1KZCDZ2e5|3i0G&3KUi0q0MU|IzIHw@zCN z_sRdD`3uco(E5$pUMsCx>+=`EpFEy{MK>>}=hx)tdG)vI=XbG?^Al50#w*V+RFGB0 zW|=tut2F<~3e3m4V_~c%S^wD}m&W9MyZ)p3i-3#fFXVqT>wnnkUv8hmzCXL^qgUkn zEjHhurS%tKL00GcjWmCv`BNr;jj!LfXuo&^-_P!up1)@HOAhcNU_BPgm<+(VI?>AW zS4CUkOa7NYg#0h^zsUbm)bQ3w-A=fEVO=3`v)t9W8e*8e`x<;;P)f0V*iIV zu>N$)wEn~_=5MtAB($66Z!~`!Xy>bLO8=AWmo7h5a8;zAuC!mGKqSk>ES51Dko^)s zkp0r-&u;yK)}Pku`jf8Z1FvZRB%}Ags@`85C^_j(F@VtgZHVV@SK|E(yYYjbp9}q> z=jZhNT=&0f?QSUpc=DwibqST@f2UkB#z{bx_TRu0eTv(O53(Oc zO|#P<@;?M-0cxJU#{5mADj0l$o~-*1p6;VTWk}@WCQzfCX+_kUOM`8BNHY{8uy+J#`$~l`AbQEJ6KthgPw6x)%?Q_Khj@;C+RP(f9UoH>skFZXkvZNa*nTG zk@FuqWkNliR5kwNeD&9Q`^o;$`FZs+ljIO$Geh==?2nbaRAwUkQ-sFO{?hs>*`G{m z8ecyx%3^)|YW07(d;qqJ^AFKEp8GdhWBhxM|6<#X<@wu_a*zPLyX&&=k0tA33mbKR ztzKqj4(Y$Re1GJSs{9m0`(E!KPoefeZ4RO2-{t#v1^!P1B|rY7#`wqI-vho(VWbjQ zAbQ3N|63d5e;UyTzD(hqO6)@P9f$HzCJRLi4(|*rzx-L>@`9No5BZiC`X4{+R6Z6| z_lO|<`vv$knM^YJPx<+?{3iYU*?!3Rvy-edyLDx=zJE~A&(n5xN7M6fC3B2TN8{-*cjbZzDNxAUF6*3#FD;LZqV?3$ccNVN+^F-!Zs9qC&%D(&}tsQq@4 zzk>Zc3yr*hx$LJVXQS_pw0xef)6;{(pXo<#Z=+jud&{=z z*FpC8r67?a=LS{HY-O8!fIokL*8Cil2+6j6#YjFSJ#CcV65c97wC`+V_U9R-KZ;UKOz7Xa=TLbw{+$ZEAAP!8 z{GM$MN$s{kU`fb^HUL11v zrxuF-9lw5R*L4Mdk3X;t<8SUis1N*)v5Njw`D$ab|FHhnmV*qU{XL;Sm(Bgdek{J9 z%b`&Gw1x-Ux}f0!*=$0|Kh|HHqFEln{~0`g6Zk)!_{*gO#t47rhd!r~Lm$2>vnQABaDH zDFpv{!aopy_G$?J9l}2lfBuaS{1*xTK>XR!5d1meABaDHBLx3t!aopy_PZhYcM1PM z{Q3V3!T&hnABaEuVF>;`!aopy{_jKZKSTHj;?H^^`1c9_K>Yc4L-4;y_y^+8&V=AE zm!4?-Iq3zz^*4Sl1pf)K^hfx6c|hciDeZ z*#FZ7T7LUC+Qx5}jvv2HvLB9L=>9J|{B|S$`vCt`=>AEU{ljxWm9di$wc{Uh!~Vy% z`+r`?kIhJi!u!AJ{c{SbnwV(OBHq7MQ%c2hycH{;aO+|hy{J6-Q9Bbt`evONQ z1^#zC@h{8I1AWGv%cspP&|h$+pUQ2a{Qa)w)18ln@=v*zAJfY};aWc3eqNxz*QtC` zR{Qa9y~@Y66t|1sbXa*dipvj{-?Nj;#rGZf+xoQ|^GWXB%TDwYdb++}n$q}jrqB2{ zUr7Z*{mku$PVVUbUWp%7Pfd)k7xsJh7G%H8^i^&8ebgT+TWu^mL9ol%QO$l`nwnUb z2snr0ec4_x7(Kneo!g(tdTKwQx+IGAUvD4yE5Hw2zbg|^f2eGAu^j(xA^1DD-^7*> z1k)Xhw6{#(zwDF?_8^IGh(%PsMsu9&@(=m!L!Ezf|BcB(QuY_mYftrgD!&>WS$9a? zKMJGn3hFgXJ8J!{|06?-fXtKS^gE_>^{FfWc?f3h0C7}<|h(5}Sy zW9mQ8# zk5pkAU;i-6o6=!qpKi?Tle+B`TiNu{EAssooA1xk`itl){QX#`4oouR@^>(^HL2cJJ;QHcE_|I6?F zIWU!sGur=CwSO%KsVX8YDDD3VJpaAv{XfF~g6hx{M%e*^+Fe?!j? ztIU0Mx)ddup+*0PHL(74%C!E(E++O@#yE-Q4+0VF z|6`o$u5o!Q`^jkjK=X%8vKn9isFT;ieX<`kf1vrpRXl$Xwg5Z_P;c- zerkOFWApw5w3?p((DR?dm?-ua6cWt(K7HeWc0=kd;mKXZZ_o4Ece&qI+-Gc0GllJ%Ih&aDn zYp79wIb`YW(Cu&ZsUZ6sEv`F{>_1S7pNplAL5eBlrSHcX`7aa|_vcrC{vrC`Z~npa z22HzsKfIuSN1+BU|BwHDIq1a`SLgcCE|MQ9Kz{Ob;3}7&*|jb|n7^B3cYo_x>W}BY zP0_Ky^@DcSKMu*U1pJ-ce;(5HkHe@vd573XL=7{mv;Of~WBjkq`p0r({L}i!n7zyv zbBYnt$Q^>_X>5LeAPy;_otlF05aihx+~2UykeTgUVb&$v^J@K~)s^ z4_tqFvoZcd*I(W>tv?h%5b|_F=a=Uh_PESH^HOO<;_tG4GN&r&LlpB5>9wKIpPpHN zRB5w4_&qQCRp9>!P{90~=|OI}e#(EZk?+4JZ($AT_uu{SS12Yi(f_6W8&>$UKM28J zF5NRWMfm&S&;KX{{|MxO@MnJbvy~A1X9)j5`@_$M;2#tIf%vm`Lhzp_`~&gl|04wd z4&fh&Kl|Sy_%9Ovf%x+uh2YN#|3LiNI(a(2*YnS1!aopyzA*&vsKe<=P>gy0_&{-O9k9fJS-@Z*0u0Wke_ zCVlt+kI4NWTdek<92L14?B5mClh=r*eCfmf6PNv$h5bKW;NW{!{6lVd|E+EOx;lRR zI>}HtelGjB3;3M?MC|Z;1M%O7;!~mfk6iW-&H+`%PD0d5`dw>YW57KKx(W1qzNdIcfK?YIY|N9Z5eC)D+9!O`Q7zCE^{&l;k=it|$$TO{BNGj2*cFMpKdB3O_FLCgX?W_$7z^0{^?6_?PA9 zfj(o-<CVSO`KMgVkLl%~a4ny{epH~p*QtC`R{Qa9y~@Y6 z6t|1sbXa*dipvj{-?Nj;BZc>`_}lum8}mu--pfw(6MDM7Uz*bRai-7sH(yBwLjBC` zhfeP3{$7b6RZmTfuNU@v_7*gLnCYw9^!unkRJPjKL`uKQ*ip@XU7DI$m-H`(;(gg( zFc>|(zn$Bk$a-o&pt>ZA^{{Wm5DN!edKuRYc0sr+hiWZfZk|0s;QE2#fe z7TcH&|3_pdIsZ$}M|l6Gto~sC-{t(OIn|_75cT$71%QXE{MUSuCOP3C{Ri}9C`ym; z()kA{U;QQjT78~@IrRP;nALc*+45%jSA@m==jEX9wSOnCjI&{Ve*(Q9mbb_uBD$^g z{KUEjQQq(T3)bBGUoPtxTH6HwmghIAl0yDg=lw6mMG_Mh#DD&qSLUt1@coX`_!H-K z=aEvN@jLkA}eqPCc$%@X$ zx?wQZm(1Tbz)V_XFzbsdEbf#4ArK+|gY1WHKa{4}#w^sdqAKz~X#JyV{jNe}RI(5| zTK{0Wzo-DJi3$8>T=pMmHD00h54-U@_iw_!|Gw#?SLFL?Hs61z^%o&c{6^h>(`3xz z0++XwADTbW{3#Qz#@BD{nm=k_*^S?{{|D@E)9YUX|J*;K=WpnIbb$}e--OcA{Eg;sLp*;AUw=Af zT7O~}6Z@&ZRusb^2dl-fYDDW#fCwPu7Sj58*z+^8UstvLI$deMgc?Rt8)I5OH`^s_ zf6Te6Q*n8_`46o>!Tg8pSF`r3T2f1Q==}{w`=4n4ldgYu_WOhJ{@0b*|7eFFt-lCs zOzST+{_FOqG)mVdz)D1IMx3E?{O{nL$vBCVfGVwj(E3MFHtg&_J^z9EPt(tTihW@h z2TiuiUC2fDgXRx(egrx{$IkwJFy;@!7U=fFuKoG*>9F@FuF&%zoA)Q6)#U$>|5F$f znVtNT|Fc&8AG`irGxm#~|Gez^{8TC}etuKbWt%uIZzn(W{0H=()_=U$e_;N#1$jx8-ulbr^OutTcCfN02R-AYs`&@$ukH_+@u2aS#@|e% z8h`%>>*?U@uV0b#A39~CJ)BfE{^NZ0*I`6`()CYi znoW%CkN!Oj8_ZN5p3JL8Y5hR~;t~^LOXp{Oz|PNd5dp@oR{Jw97m$Z_{68lLopAoW z$A9tIY4G{m{D$ewz9)uS{Q5f=((`EH$8Wf67kWPwZkfVED)A^p&v-rmy=V3NU*P9w ze>t?CO}l)*aV{^$!tmoA`Oxud{r*Qx@*`!*Pyar5mCMgYlAjRy;Xjb2uU7l>USs3G z_xmS9J$`-f11mqWVJwywHaYzNihlnLKYs`Mwjq_zJIWR^Eng^p?5g~JT@K-G7p^bN zNAhHIq%i+@19JSa*ZgBl7LPu*Q2ZV(R9X=h=)VbB&RWsuGJEu4jUL~$J!JhYtu$2r zsQl{cKjBaKOXJ`bko=JRcnOXyMDj!OL-ON~0O=3ukH7Sg|4#Zt`a}BTj{xZp>5sqk z(D+08L;6Gd5sqk(D+08L;6Gd5sqk(D+08L;6Gd5sqk(D+08 zL;6Gd5sqk(D+08L;6Gd5sqk(D+08L;6Gd5sqk(D+08L;6Gd z5sqk(D+08L;6Gd5sqk(D+08L;6Gd5sqk(D+08L;ACp^oL)! zUSr32&A;(iM5cJ+0@L$rg=;Q2WAW(A{^ir^S51FLv&PN;>{rLK2>dbP#UM}ov4xRh zdp~rmH-3Wt!p|W3z=uu9H?ZSFi@$8(s`Y;ku;_^V(m68v(6cYRxVZSm=MP5D?0e?P z#Vf6eEh&2&+dEiHy?j~-yH|`KXJ#C2OfX+*YEh) z?K8K32luq>1AAaGdqo*F(~8Qv&KUuQ_)o}~X|G5r^W+U=-FPm+UvDSzt>HN-OQ7IY5}bbmFu zpQdd<`4;7McX!cLFgAhmU7uW==uiBTDEBDJk1w|*31xH{Ta@LOWPy}M`hQ$rPs;HB E1L(bxO8@`> literal 0 HcmV?d00001 diff --git a/lite/kernels/mlu/roi_align_compute.cc b/lite/kernels/mlu/roi_align_compute.cc new file mode 100644 index 0000000000..52bfdfea15 --- /dev/null +++ b/lite/kernels/mlu/roi_align_compute.cc @@ -0,0 +1,198 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.ddNod +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/mlu/roi_align_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace mlu { + +void RoiAlignCompute::Run() { + auto& mlu_context = this->ctx_->template As(); + auto& exec_queue = mlu_context.exec_queue(); + this->Run(exec_queue); +} + +void RoiAlignCompute::Run(const cnrtQueue_t& exec_queue) { + auto& param = this->Param(); + + auto* rois = param.ROIs; + auto rois_dims = rois->dims(); + int rois_num = rois_dims[0]; + if (rois_num == 0) { + return; + } + + auto* in = param.X; + auto* out = param.Out; + float spatial_scale = param.spatial_scale; + int pooled_height = param.pooled_height; + int pooled_width = param.pooled_width; + int sampling_ratio = param.sampling_ratio; + + half spatial_scale_half; + cnrtConvertFloatToHalf(&spatial_scale_half, spatial_scale); + + auto in_dims = in->dims(); + // int batch_size = in_dims[0]; + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + auto out_dims = out->dims(); + + std::vector roi_ind_vec(rois_num); + auto rois_lod = rois->lod().back(); + for (int n = 0, rois_batch_size = rois_lod.size() - 1; n < rois_batch_size; + ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_ind_vec[i] = n; + } + } + + auto* input_data = in->data(); + auto* output_data = out->mutable_data(); + auto* rois_data = rois->data(); + + std::vector input_tmp_vec(in_dims.production()); + std::vector rois_tmp_vec(rois_dims.production()); + std::vector output_tmp_vec(out_dims.production()); + + std::vector nchw2nhwc_dimorder{0, 2, 3, 1}; + std::vector tmp_in_dims; + for (int i = 0; i < in_dims.size(); i++) { + tmp_in_dims.emplace_back(static_cast(in_dims[i])); + } + cnrtTransOrderAndCast(const_cast(input_data), + CNRT_FLOAT32, + input_tmp_vec.data(), + CNRT_FLOAT16, + NULL, + tmp_in_dims.size(), + tmp_in_dims.data(), + nchw2nhwc_dimorder.data()); + cnrtCastDataType(const_cast(rois_data), + CNRT_FLOAT32, + const_cast(rois_tmp_vec.data()), + CNRT_FLOAT16, + rois_dims.production(), + NULL); + + void *input_mlu_data = nullptr, *rois_mlu_data = nullptr, + *roi_batch_id_mlu_data = nullptr, *output_mlu_data = nullptr; + cnrtMalloc(&input_mlu_data, + input_tmp_vec.size() * sizeof(input_tmp_vec.front())); + cnrtMemcpy(input_mlu_data, + input_tmp_vec.data(), + input_tmp_vec.size() * sizeof(input_tmp_vec.front()), + CNRT_MEM_TRANS_DIR_HOST2DEV); + cnrtMalloc(&rois_mlu_data, + rois_tmp_vec.size() * sizeof(rois_tmp_vec.front())); + cnrtMemcpy(rois_mlu_data, + rois_tmp_vec.data(), + rois_tmp_vec.size() * sizeof(rois_tmp_vec.front()), + CNRT_MEM_TRANS_DIR_HOST2DEV); + cnrtMalloc(&roi_batch_id_mlu_data, + roi_ind_vec.size() * sizeof(roi_ind_vec.front())); + cnrtMemcpy(roi_batch_id_mlu_data, + roi_ind_vec.data(), + roi_ind_vec.size() * sizeof(roi_ind_vec.front()), + CNRT_MEM_TRANS_DIR_HOST2DEV); + + // malloc output memory on device + cnrtMalloc(&output_mlu_data, + output_tmp_vec.size() * sizeof(output_tmp_vec.front())); + + // prepare kernel params + cnrtKernelParamsBuffer_t params; + cnrtGetKernelParamsBuffer(¶ms); + cnrtKernelParamsBufferAddParam( + params, &input_mlu_data, sizeof(input_mlu_data)); + cnrtKernelParamsBufferAddParam(params, &rois_mlu_data, sizeof(rois_mlu_data)); + cnrtKernelParamsBufferAddParam( + params, &roi_batch_id_mlu_data, sizeof(roi_batch_id_mlu_data)); + cnrtKernelParamsBufferAddParam( + params, &output_mlu_data, sizeof(output_mlu_data)); + cnrtKernelParamsBufferAddParam(params, &height, sizeof(height)); + cnrtKernelParamsBufferAddParam(params, &width, sizeof(width)); + cnrtKernelParamsBufferAddParam(params, &channels, sizeof(channels)); + cnrtKernelParamsBufferAddParam(params, &pooled_height, sizeof(pooled_height)); + cnrtKernelParamsBufferAddParam(params, &pooled_width, sizeof(pooled_width)); + cnrtKernelParamsBufferAddParam(params, &rois_num, sizeof(rois_num)); + cnrtKernelParamsBufferAddParam( + params, &spatial_scale_half, sizeof(spatial_scale_half)); + cnrtKernelParamsBufferAddParam( + params, &sampling_ratio, sizeof(sampling_ratio)); + + cnrtDim3_t task_dims; + task_dims.x = 1, task_dims.y = 1, task_dims.z = 1; + cnrtFunctionType_t func_type = CNRT_FUNC_TYPE_BLOCK; + + // invoke kernel and sync to compute on MLU + CNRT_CALL(cnrtInvokeKernel_V2(reinterpret_cast(&roi_align_kernel), + task_dims, + params, + func_type, + exec_queue)); + CNRT_CALL(cnrtSyncQueue(exec_queue)); + + cnrtMemcpy(output_tmp_vec.data(), + output_mlu_data, + output_tmp_vec.size() * sizeof(output_tmp_vec.front()), + CNRT_MEM_TRANS_DIR_DEV2HOST); + std::vector tmp_out_dims; + for (int i = 0; i < out_dims.size(); i++) { + // out_dims = {N, C, H, W}, tmp_out_dims = {N, H, W, C} + tmp_out_dims.emplace_back(out_dims[nchw2nhwc_dimorder[i]]); + } + std::vector nhwc2nchw_dimorder{0, 3, 1, 2}; + cnrtTransOrderAndCast(output_tmp_vec.data(), + CNRT_FLOAT16, + output_data, + CNRT_FLOAT32, + NULL, + tmp_out_dims.size(), + tmp_out_dims.data(), + nhwc2nchw_dimorder.data()); + + // realease resource + cnrtDestroyKernelParamsBuffer(params); + cnrtFree(input_mlu_data); + cnrtFree(rois_mlu_data); + cnrtFree(roi_batch_id_mlu_data); + cnrtFree(output_mlu_data); +} + +} // namespace mlu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(roi_align, + kMLU, + kFloat, + kNCHW, + paddle::lite::kernels::mlu::RoiAlignCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("ROIs", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/mlu/roi_align_compute.h b/lite/kernels/mlu/roi_align_compute.h new file mode 100644 index 0000000000..fa571efee0 --- /dev/null +++ b/lite/kernels/mlu/roi_align_compute.h @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" +#include "lite/kernels/mlu/bridges/utility.h" +#include "lite/kernels/mlu/roi_align_kernel.h" +#include "lite/operators/layout_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace mlu { + +class RoiAlignCompute + : public KernelLite { + public: + using param_t = operators::RoiAlignParam; + + void Run() override; + void Run(const cnrtQueue_t& exec_queue); + + std::string doc() const override { return "Mlu roi align"; } + + virtual ~RoiAlignCompute() = default; +}; + +} // namespace mlu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/mlu/roi_align_compute_test.cc b/lite/kernels/mlu/roi_align_compute_test.cc new file mode 100644 index 0000000000..9cbcc74136 --- /dev/null +++ b/lite/kernels/mlu/roi_align_compute_test.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.ddNod +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/mlu/roi_align_compute.h" + +#include + +#include +#include +#include + +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace mlu { + +TEST(roi_align_mlu, retrive_op) { + auto roi_align = + KernelRegistry::Global().Create( + "roi_align"); + ASSERT_FALSE(roi_align.empty()); + ASSERT_TRUE(roi_align.front()); +} + +TEST(roi_align_mlu, init) { + RoiAlignCompute roi_align; + ASSERT_EQ(roi_align.precision(), PRECISION(kFloat)); + ASSERT_EQ(roi_align.target(), TARGET(kMLU)); +} + +TEST(roi_align_mlu, run_test) { + constexpr int ROI_SIZE = 4; + + // image_height * spatial_scale == featuremap_height, width is also like this + constexpr int batch_size = 2, channels = 3, featuremap_height = 9, + featuremap_width = 16, pooled_height = 2, pooled_width = 1, + num_rois = 3, sampling_rate = 2; + constexpr float spatial_scale = 0.5; + + lite::Tensor x, rois, out; + + x.Resize( + lite::DDim({batch_size, channels, featuremap_height, featuremap_width})); + rois.Resize(lite::DDim({num_rois, ROI_SIZE})); + // here lod use offset representation: [0, 1), [1, num_rois) + rois.set_lod({{0, 1, num_rois}}); + out.Resize(lite::DDim({num_rois, channels, pooled_height, pooled_width})); + + auto x_data = x.mutable_data(); + auto rois_data = rois.mutable_data(); + auto out_data = out.mutable_data(); + + // {0.0, 1.0, ...} + std::iota(x_data, x_data + x.dims().production(), 0.0f); + std::iota(rois_data, rois_data + rois.dims().production(), 0.25f); + RoiAlignCompute roi_align_op; + + operators::RoiAlignParam param; + param.X = &x; + param.ROIs = &rois; + param.Out = &out; + param.pooled_height = pooled_height; + param.pooled_width = pooled_width; + param.spatial_scale = spatial_scale; + param.sampling_ratio = sampling_rate; + + // std::unique_ptr ctx(new KernelContext); + // ctx->As(); + // roi_align_op.SetContext(std::move(ctx)); + + CNRT_CALL(cnrtInit(0)); + // cnrtInvokeFuncParam_t forward_param; + // u32_t affinity = 1; + // int data_param = 1; + // forward_param.data_parallelism = &data_param; + // forward_param.affinity = &affinity; + // forward_param.end = CNRT_PARAM_END; + cnrtDev_t dev_handle; + CNRT_CALL(cnrtGetDeviceHandle(&dev_handle, 0)); + CNRT_CALL(cnrtSetCurrentDevice(dev_handle)); + cnrtQueue_t queue; + CNRT_CALL(cnrtCreateQueue(&queue)); + + roi_align_op.SetParam(param); + roi_align_op.Run(queue); + + CNRT_CALL(cnrtDestroyQueue(queue)); + + std::vector ref_results = {14.625, + 22.625, + 158.625, + 166.625, + 302.625, + 310.625, + + 480.625, + 488.625, + 624.625, + 632.625, + 768.625, + 776.625, + + 514.625, + 522.625, + 658.625, + 666.625, + 802.625, + 810.625}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], (4e-3f * ref_results[i])); + } +} + +} // namespace mlu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(roi_align, kMLU, kFloat, kNCHW, def); diff --git a/lite/kernels/mlu/roi_align_kernel.h b/lite/kernels/mlu/roi_align_kernel.h new file mode 100644 index 0000000000..a2298a1f89 --- /dev/null +++ b/lite/kernels/mlu/roi_align_kernel.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LITE_KERNELS_MLU_ROI_ALIGN_KERNEL_H_ +#define LITE_KERNELS_MLU_ROI_ALIGN_KERNEL_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef uint16_t half; + +/** + * @brief Region of interests align is used to implement bilinear interpolation. + * It can change the input of uneven size into a fixed size feature map. The + * operation passes pooled_width and pooled_height divides each recommended + * area into equal sized blocks. The position remains the same. In each ROI + * block, take sampling_ratio points (if - 1, all points in the frame are + * taken). Each point is directly calculated by bilinear interpolation. Then + * take the average value of the points taken in the block as the coordinate + * value of the small box. + * + * @param[in] input: 4-D sensor of shape [N, H, W, C], n is the batch size, C is + * the number of input channels, H feature height and W feature width. Datatype + * is float16 + * @param[in] rois: 2-D tensor of shape [num_rois, 4]. ROIs to be pooled + * (regions of interest). For example [[x1, Y1, X2, Y2],...], (x1, Y1) is the + * upper left point coordinate, (X2, Y2) is the lower right point coordinate. + * Data type is float16 + * @param[in] roi_ind: 1-D tensor of shape [num_boxes] with values in [0, + * batch). The value of box_ind[i] specifies the image that the i-th roi refers + * to. Data type is int + * @param[out] output: 4-D tensor of shape [num_rois, pooled_height, + * pooled_weight, C]. + * @param[in] height: The height of input + * @param[in] width: The width of input + * @param[in] channels: The channel of input + * @param[in] pooled_height: Output height after pooling + * @param[in] pooled_width: Output width after pooling + * @param[in] num_rois: The number of roi + * @param[in] spatial_scale: The scale factor of multiplicative space, when + * pooling, transforms the ROI coordinate to the scale used in the + * operation.image_height * spatial_scale == featuremap_height, width is also + * like this + * @param[in] sampling_ratio: The number of sampling points in the interpolation + * lattice. If it < = 0, they will adapt to ROI_Width and pooled_W, the same is + * true for height. + * @retval void + */ +void roi_align_kernel(half *input, + half *rois, + int *roi_ind, + half *output, + const int height, + const int width, + const int channels, + const int pooled_height, + const int pooled_width, + const int rois_num, + const half spatial_scale, + const int sampling_ratio); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // LITE_KERNELS_MLU_ROI_ALIGN_KERNEL_H_ diff --git a/lite/kernels/x86/roi_align_compute.cc b/lite/kernels/x86/roi_align_compute.cc index 26efd9160c..3c0614ebf4 100644 --- a/lite/kernels/x86/roi_align_compute.cc +++ b/lite/kernels/x86/roi_align_compute.cc @@ -13,9 +13,11 @@ // limitations under the License. #include "lite/kernels/x86/roi_align_compute.h" + #include #include #include + #include "lite/core/op_registry.h" #include "lite/core/tensor.h" #include "lite/core/type_system.h" -- GitLab