From 88847c1c884cffc78bbc7aa9a040ec32f8c37056 Mon Sep 17 00:00:00 2001 From: MRXLT Date: Fri, 29 May 2020 08:04:09 +0000 Subject: [PATCH] add ocr demo --- python/examples/ocr/test_ocr_rec_client.py | 31 +++ python/examples/ocr/test_rec.jpg | Bin 0 -> 6369 bytes python/paddle_serving_app/reader/__init__.py | 5 +- .../paddle_serving_app/reader/ocr_reader.py | 200 ++++++++++++++++++ 4 files changed, 235 insertions(+), 1 deletion(-) create mode 100644 python/examples/ocr/test_ocr_rec_client.py create mode 100644 python/examples/ocr/test_rec.jpg create mode 100644 python/paddle_serving_app/reader/ocr_reader.py diff --git a/python/examples/ocr/test_ocr_rec_client.py b/python/examples/ocr/test_ocr_rec_client.py new file mode 100644 index 00000000..b61256d0 --- /dev/null +++ b/python/examples/ocr/test_ocr_rec_client.py @@ -0,0 +1,31 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 + +client = Client() +client.load_client_config("ocr_rec_client/serving_client_conf.prototxt") +client.connect(["127.0.0.1:9292"]) + +image_file_list = ["./test_rec.jpg"] +img = cv2.imread(image_file_list[0]) +ocr_reader = OCRReader() +feed = {"image": ocr_reader.preprocess([img])} +fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] +fetch_map = client.predict(feed=feed, fetch=fetch) +rec_res = ocr_reader.postprocess(fetch_map) +print(image_file_list[0]) +print(rec_res[0][0]) diff --git a/python/examples/ocr/test_rec.jpg b/python/examples/ocr/test_rec.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c34cd33eac5766a072fde041fa6c9b1d612f1db GIT binary patch literal 6369 zcmbVwcT^MIw|0;wMG%lCC3Fx(1O${41f{E35R?{?4pKu25E43y6oFR|5m2d0iFA+> zdPk&&CXmp3A_xNs=L28MqJBhy(1MkYoEhO;bZnVA0?^^BE; znf33*--rC?)-&|9wDio342=JV{I49P1;lX{bOnT?r4a+2;h>@AprLetz*L;{|H7k! z{g2U{p{1i|V5Guep$1f+r^2VDrGloXBBn-%Qs+T*9Q2%`iaHEj59}Dld@d-3Cx1FC zu3O&9{cxBdq5Rx8f{B@jmyiFVq?EMGCE05#s%qD7sNcG+r*B|rWcEhGmpeIoFwrt zwSSrY&xl3*pP2nWV*j7lG>Da!hMGKD4v;noc+V339SABAcn13lZ!xYv6<_rnT8BD5 zMRuWR$R;=m3{kPMqMj(o5n7aRyGt_pLGT%$Hz4#IO}V^l@15J(*<_byp+_mwhwr+& zwKLiKn)Xhai$6m*C?J+*Avokm-*5f9w+B{QG&DoZu6Tb4@3t(`*pkcbz_^y$w_)g- zHQ<=G8nqOas*ChZ*&nuM4jrCF{L;V%k;H+!7c%JJ3C^}KMCwWaW=v{3O3Z=?0^dkEhb8wtB$t83h!DGKF0@ z{3%pg+{-8VZksfmP{+SHQB?`N(tHC9Yd*py2>_K3qi!D_m#79reF%EVogNU%wz0M2 zKd9T?r9L=|S!XGLt2~q9ih2AT49VhbM74dNwM=RwWroCFCc-O}z0qXwHQ8bt3g~4u z{E9u0pp}dIO8DRstGPqch|dWf9C_tw-;-F__DDy#lyG=AMi=v(khdTV zKK(~LLdEIZAdgZBR(eIaQQ?+tE(JuNWXdNUi4fdRUN_&q?=|M*8n()`$KrUk{fynOM5!H_HmIvavr@Ud^%yM2PC!dhYQ-FavR-l9DkZV0#Qq^0tX2~mZV5$~v>HbR@5&=|_ zxM$ZPd_=EKOw>FVZwKc(k}K^7LJ?LN#LFR8?(DNyE@6MY@I#gHipcR=N4d&E(3n* zc^%BBo!=N(O#PWv(8PYQ4T$3urbgxshpV7$Zr*OTm#ZEV&;9CU7p&zl%+vI}Vpec{ z=0oD2rBfQB;GED;h``1{5J@jK=~KN=_Sa;S7_t#R`e)8e=D`!<>b4CIb(R8Z`#w=M z6XX}t>sbbU@aAhTYjQ+`BKD21oJ^U${#IF#Gnq@lrmn6YNH4a$)5~2Q{WXShSwqCX zOPdC&fb0?G9fjODo%}x3juvpU$Te-=dyrrKCMcZc#*xusHxBBF?1ol=yN0HU+l8-= zRwuypn(B!xsUByNXYxuaB~Ce%2snd0Ntd3LkzvE3+H8FxtUupDpjqsoF^t~BkN z$FYA@7J>zBog=qcQ;j5QAR;cfAnWocx8h}nZZd)bxVHnaCKM1;V&R3PH*>76-xyFa z0otFj@`ic)f?oY!8WvusARaFiW3I)orL=v_&PY){kz1qoj8|(;6h@#~00=*oUy~w{ z*lm=jtOe3vcGN#Nk}4l#uj(w(&dWVe3gsN}^J*Kr+&{ZLOF|1(eKEaO`SV2khM zZG6Ay54=#~qp8L+5*Z5K-?RLZxySGq`TgK zF%31zcwm6#y8l*XTmRA5+h4q&W9?_6`REWIv$Zz&cy&fbab!yyY;^UFEhmUv8q1o; zj&{ILCNp|jIPEU>(>&r1P*?zkry7qoIwzIuoEv|Yrs7f5rti4fbENlTnO^Ri@SX)^ zk;qHt>R%}M6vVw7(=4))rj}4Bq<%@Wqwa2rf!UCRFiQPMZ4ul?0r4USP+(1Ef^+BL z&v_OPq1|hPw!e%!QC=-dGMZl)m*r-Dmkv$EE;@IjxRF1z*$IjHc#PV;@kLk5v!6L~ z6fb-eSsk+op(>jW5fz5eab3{nbb*LfG-1gas8H8X`pTme9d}LHo072b6@w38>#@O6 zpNiZ0k`i!fvP4>Q;oj#)jZ+@g`jd6kZn>T>N%_$}>>%Xi4T7;0U}=GJHJYHUf7!ZD zn<-6CM=s}i4#yuZ#Zf=7F6`4LIE$@AVt-o==De3D&b>+YfeYv8puxkX3(}Q76O9%N z3UTC6Ja}482Y9R<=b5APv#KgvNE6d9@FsY%XuUIdwF5nM+P=8Eo$H8(C!S7Kq3PD% zG)yYTq&$VnXJzSz1@D&~UxG3BOi)GO3rU--#lr#2w8AVqOT>nr9|k|9uV>6@RnOciUepJ3N8%I{V$(=VWWgOq5eV3tupx{oRP1su$2IgXoX$HVR!AtqWW})7Zbh zuMbtZ??=;oba)MRRg<}8@+q9bK&TaR~>9MkVcY_qfq6(?B*9z3hKq1C?OO^ZFtU;2oHru#joU`lPrYNm!hCz?Kp(C1Jb0Za?e^f zH{#|}(L&r2tsxfeNHi~$AQ-4WU5lBLj#?aOd8miC}XX=gix@m82VN|t1>Sn!$%K0u#DN04O*f}382G)k>S$ekU6E_2VYg18_j zLy(Z1x_%14LIHiV*j6N39K1Y9Bg2StUB^w`2hH)d#3^g>q+eULeGrf!!;(b;vISgY zx_yp#xVSY>i&O-X8x3;Xo$jowG};Hv{1ZT9Fu=~wEluXclUmhK<$#XY|RbQ(e#ubl|4i0w-7+%pU{ZHFj+bG+pN^VqY#L>Cld z*b(awadut>kRXjaXkkR^7g=9DSk} zFi%IVeJ}@{_dvJlB(>4BwAB$G(r@x^!YaxkVW8}(rK&h5yQ4&s@tz|Qwf&m{Vpo7j z!V6n3##;uG#TFlwH;q)Q)blEHDzi5v)Vl1i-S3eeNL{2d!ida_XPi zrvFX>o!_*SVTH)lHR?TBfvaSM@5^)_SxDX$>JGpolgKgvElFxKho)IbZ*h?EX5-j3 zG|P;H+Ce3K`s!m}n>XZDP1sF4ld&~K$!yD91Tl7uQuc%;Ky zrH`>IrtR~>2pvLUR3kqQ#R5!hnpQPcr3Z`5wnnmYUuFVFhXzsmxSzP2Kz77|c?uKB zrjT;^vRsM#rZB!)-hR9IxT43eM>_k|1%^-CDImftg0_Ziw`7oI=9shyIIb+DZH*#tyP*jQ&=M_)5DCE2L*4zl6X{C1TW6`~4G`aoAT2hr zd3?2C^XYiC4Y=bR0#SUTvt3qNvuwQ=nTidEQWr0k^V{wO<%1;x>l_Uc9f_;?J|_Rf z77jBPt&K;F=+5QdpbrwX=+>@L+seQzOj1Cdr$SZ3yI1_7v>TS0CZ9J(y1!f9xaqb3 zv}fTeSia==In_fMR9h+meRAeGIr;*st$zfyCN_UK$Q?cWX$M^1mBo`!a)38H zOkl}R%yMNGZaE6l{!xQ*1y9{N$iSjppcjzCn6nU3%e7%lwD82w(7G@0B;ayli*ydc zP0F@v3#j;PX+X3TzKMs%AZqGS+&YDZDOYrKhr5$0piqmC@Zy6F;+1($0A-m%=(H<{ zYBH3^i`{;UiuLIchm2!Q?)Ef@02U+|9E`(kp}_6z%Bz%DoM7d*~w4$ zCO!9dZ#+rTym>&Cax^!*KF?J;M9b-EnPS@8D35e$Y#fl(f;}zxB)7#bCrpg}U85?ar5nY{7yStCed9M8KpYE;?ey{lU!1&JzZ+p=F!ePOBkq zc=yBcPnjK$n5uHvo?f#^y%A*0mwQB`1sZ%{z`u|g5<|3ztm28z+TKw!Hr34)#6JD3 zGB^}z8N@hL1nko!Zt%rapu@S9J*2}EaO26a4Mh|_Svd_!DuzT9TYWgFW zNO{26Dx`tmumAeac4^@+`DB|8e%>NF=UTVU@oN96kOEDI1g|b##W3x_^7Itcf_AB& zt@5ayBH0)dMlb)_va2N#C)r8?b$2h{2FsMiYlKQ2kzu*1&eh#f{UE;3UtjB~lf{fC z!v)O*gP4xwGlz2`t{k<_e6zlgw`(=3-)cKtB+c4p0MW2VyO)Db9}mLW@BI3Rkp1SJ z(@eVCugWfA<<&~wgt{1r?%JNA{wG9`b%8*V77*E%uCBW&EIRDZGwdgC96&1P8)K4U zjG0Mk^zSwX4%?I6c$0xcOXfRU96Q;1la~4xPO?qmT5`iXxwjU)&d=Uf&CGe_#M9cv z?PbkA3_8A8jBgT~6WA5{`jM8T%GgKn6-y<%)t7}AHz*CYByizQONp;>iLpf3Myr`wLs zzB8%MYbAdBG$h@78t%lEhBBmPLIoK3s;I7s1j9=;> zXIaD59P2nDW`=Yyg&}r#+i2jO+fmHx786kR>c{G4wKd_LB!+f=!{yxl{Y~q`NQBiQ zCaPJ0h$lmc`KOseJNYN+C_%3C={&#tm#;dmH+XvgGT>f?0aw@F0PQAYyU4L}*U$LL zS`EgOosY-KRZd^bmN{Rilpm_K4FR&dZ7A+p*Gd3$tB!!I{c0u0zx6DjI_>5SwvXJ7 zKSJ3wur*KMl?RVc^7ogZuWV@vwoNQfhJSAO0S-|Mf^|N7x92*r%d4^1HecuPkuO!& z8I~gyPA4vWxSn`Nfp5RqmhJm^qvXPgh84l79F}jRmYA)r+V@C!Z3g)u4EN9R#lrhp z2TCZ{Qdo;T5!*HG)uzpwpxr)QP$4$@{3(mrk)Bya+A7@_?F98x+>GToc=sf|1^2)w zsX;KJ453HxFDNaAqb<@{;F80CdXjl2v|e_0ID80L-k4Q?!E1R?mUwjtr)DmG(tfHE zRV}-j*15~7)ou%y`Ux`EV3JHQTs?P1@d)1>b`p(o$%&n(OSd8H#!r|y5y5eucB5z^ zBIoyhrLAZB79tigfIM;v52e-hW?oOM@3!tcr-v$(n3m7X^C>nWO4*ED=(k&{lR!F{ za|Ae*2P#7@!=-K&!9TiM1f&}|;~5%_N*!hkq`cd+C%MZR(ghI4wf__`o^sfWeQ?f7 zd4tk?fKQN2^AcL)ITXGXBs2HI_^J=9SI^;(bAC*P9T2g#R56*9#nurotyioT6p&Nl zz|0GY6!RN7v0NK-JkXdqDZ==sr5n$2(`3Yk>hz 0 and text_index[idx - 1] == text_index[idx]: + continue + char_list.append(self.character[text_index[idx]]) + text = ''.join(char_list) + return text + + def get_char_num(self): + return len(self.character) + + def get_beg_end_flag_idx(self, beg_or_end): + if self.loss_type == "attention": + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx"\ + % beg_or_end + return idx + else: + err = "error in get_beg_end_flag_idx when using the loss %s"\ + % (self.loss_type) + assert False, err + + +class OCRReader(object): + def __init__(self): + args = self.parse_args() + image_shape = [int(v) for v in args.rec_image_shape.split(",")] + self.rec_image_shape = image_shape + self.character_type = args.rec_char_type + self.rec_batch_num = args.rec_batch_num + char_ops_params = {} + char_ops_params["character_type"] = args.rec_char_type + char_ops_params["character_dict_path"] = args.rec_char_dict_path + char_ops_params['loss_type'] = 'ctc' + self.char_ops = CharacterOps(char_ops_params) + + def parse_args(self): + parser = argparse.ArgumentParser() + parser.add_argument("--rec_algorithm", type=str, default='CRNN') + parser.add_argument("--rec_model_dir", type=str) + parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") + parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_batch_num", type=int, default=1) + parser.add_argument( + "--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt") + return parser.parse_args() + + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.character_type == "ch": + imgW = int(32 * max_wh_ratio) + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def preprocess(self, img_list): + img_num = len(img_list) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(img_num): + h, w = img_list[ino].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(img_num): + norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + + return norm_img_batch[0] + + def postprocess(self, outputs): + rec_res = [] + rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"] + predict_lod = outputs["softmax_0.tmp_0.lod"] + rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"] + for rno in range(len(rec_idx_lod) - 1): + beg = rec_idx_lod[rno] + end = rec_idx_lod[rno + 1] + rec_idx_tmp = rec_idx_batch[beg:end, 0] + preds_text = self.char_ops.decode(rec_idx_tmp) + beg = predict_lod[rno] + end = predict_lod[rno + 1] + probs = outputs["softmax_0.tmp_0"][beg:end, :] + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + score = np.mean(probs[valid_ind, ind[valid_ind]]) + rec_res.append([preds_text, score]) + return rec_res -- GitLab