From f800a6e859f2e20fde4fdc1fc6f57e6ab3ad5432 Mon Sep 17 00:00:00 2001 From: wangjiawei04 Date: Tue, 30 Jun 2020 15:40:13 +0800 Subject: [PATCH] code style fix --- python/examples/ocr/README.md | 26 ++- python/examples/ocr/ocr_rpc_client.py | 193 ++++++++++++++++++ python/examples/ocr/ocr_web_client.sh | 1 + python/examples/ocr/ocr_web_server.py | 158 ++++++++++++++ python/examples/ocr/test_ocr_rec_client.py | 31 --- python/examples/ocr/test_rec.jpg | Bin 6369 -> 0 bytes .../paddle_serving_app/models/model_list.py | 2 +- .../paddle_serving_app/reader/ocr_reader.py | 26 ++- 8 files changed, 393 insertions(+), 44 deletions(-) create mode 100644 python/examples/ocr/ocr_rpc_client.py create mode 100644 python/examples/ocr/ocr_web_client.sh create mode 100644 python/examples/ocr/ocr_web_server.py delete mode 100644 python/examples/ocr/test_ocr_rec_client.py delete mode 100644 python/examples/ocr/test_rec.jpg diff --git a/python/examples/ocr/README.md b/python/examples/ocr/README.md index 04c4fd3e..3535ed80 100644 --- a/python/examples/ocr/README.md +++ b/python/examples/ocr/README.md @@ -4,18 +4,42 @@ ``` python -m paddle_serving_app.package --get_model ocr_rec tar -xzvf ocr_rec.tar.gz +python -m paddle_serving_app.package --get_model ocr_det +tar -xzvf ocr_det.tar.gz ``` ## RPC Service ### Start Service +For the following two code block, please check your devices and pick one +for GPU device +``` +python -m paddle_serving_server_gpu.serve --model ocr_rec_model --port 9292 --gpu_id 0 +python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +``` +for CPU device ``` python -m paddle_serving_server.serve --model ocr_rec_model --port 9292 +python -m paddle_serving_server.serve --model ocr_det_model --port 9293 ``` ### Client Prediction ``` -python test_ocr_rec_client.py +python ocr_rpc_client.py +``` + +## Web Service + +### Start Service + +``` +python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +python ocr_web_server.py +``` + +### Client Prediction +``` +sh ocr_web_client.sh ``` diff --git a/python/examples/ocr/ocr_rpc_client.py b/python/examples/ocr/ocr_rpc_client.py new file mode 100644 index 00000000..212d46c2 --- /dev/null +++ b/python/examples/ocr/ocr_rpc_client.py @@ -0,0 +1,193 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes +import time +import re + + +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + +def get_rotate_crop_image(img, points): + #img = cv2.imread(img) + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + img_crop_width = int(np.linalg.norm(points[0] - points[1])) + img_crop_height = int(np.linalg.norm(points[0] - points[3])) + pts_std = np.float32([[0, 0], [img_crop_width, 0], \ + [img_crop_width, img_crop_height], [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img_crop, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +def read_det_box_file(filename): + with open(filename, 'r') as f: + line = f.readline() + a, b, c = int(line.split(' ')[0]), int(line.split(' ')[1]), int( + line.split(' ')[2]) + dt_boxes = np.zeros((a, b, c)).astype(np.float32) + line = f.readline() + for i in range(a): + for j in range(b): + line = f.readline() + dt_boxes[i, j, 0], dt_boxes[i, j, 1] = float( + line.split(' ')[0]), float(line.split(' ')[1]) + line = f.readline() + + +def resize_norm_img(img, max_wh_ratio): + import math + imgC, imgH, imgW = 3, 32, 320 + imgW = int(32 * max_wh_ratio) + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + +def main(): + client1 = Client() + client1.load_client_config("ocr_det_client/serving_client_conf.prototxt") + client1.connect(["127.0.0.1:9293"]) + + client2 = Client() + client2.load_client_config("ocr_rec_client/serving_client_conf.prototxt") + client2.connect(["127.0.0.1:9292"]) + + read_image_file = File2Image() + preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + + filter_func = FilterBoxes(10, 10) + ocr_reader = OCRReader() + files = [ + "./imgs/{}".format(f) for f in os.listdir('./imgs') + if re.match(r'[0-9]+.*\.jpg|[0-9]+.*\.png', f) + ] + #files = ["2.jpg"]*30 + #files = ["rctw/rctw/train/images/image_{}.jpg".format(i) for i in range(500)] + time_all = 0 + time_det_all = 0 + time_rec_all = 0 + for name in files: + #print(name) + im = read_image_file(name) + ori_h, ori_w, _ = im.shape + time1 = time.time() + img = preprocess(im) + _, new_h, new_w = img.shape + ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] + #print(new_h, new_w, ori_h, ori_w) + time_before_det = time.time() + outputs = client1.predict(feed={"image": img}, fetch=["concat_1.tmp_0"]) + time_after_det = time.time() + time_det_all += (time_after_det - time_before_det) + #print(outputs) + dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list]) + dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) + dt_boxes = sorted_boxes(dt_boxes) + feed_list = [] + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for img in img_list: + norm_img = resize_norm_img(img, max_wh_ratio) + #norm_img = norm_img[np.newaxis, :] + feed = {"image": norm_img} + feed_list.append(feed) + #fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + fetch = ["ctc_greedy_decoder_0.tmp_0"] + time_before_rec = time.time() + if len(feed_list) == 0: + continue + fetch_map = client2.predict(feed=feed_list, fetch=fetch) + time_after_rec = time.time() + time_rec_all += (time_after_rec - time_before_rec) + rec_res = ocr_reader.postprocess(fetch_map) + #for res in rec_res: + # print(res[0].encode("utf-8")) + time2 = time.time() + time_all += (time2 - time1) + print("rpc+det time: {}".format(time_all / len(files))) + + +if __name__ == '__main__': + main() diff --git a/python/examples/ocr/ocr_web_client.sh b/python/examples/ocr/ocr_web_client.sh new file mode 100644 index 00000000..5f4f1d7d --- /dev/null +++ b/python/examples/ocr/ocr_web_client.sh @@ -0,0 +1 @@ + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/others/1.jpg"}], "fetch": ["res"]}' http://127.0.0.1:9292/ocr/prediction diff --git a/python/examples/ocr/ocr_web_server.py b/python/examples/ocr/ocr_web_server.py new file mode 100644 index 00000000..b55027d8 --- /dev/null +++ b/python/examples/ocr/ocr_web_server.py @@ -0,0 +1,158 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes +from paddle_serving_server_gpu.web_service import WebService +import time +import re + + +class OCRService(WebService): + def init_det_client(self, det_port, det_client_config): + self.det_preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.det_client = Client() + self.det_client.load_client_config(det_client_config) + self.det_client.connect(["127.0.0.1:{}".format(det_port)]) + + def preprocess(self, feed=[], fetch=[]): + img_url = feed[0]["image"] + #print(feed, img_url) + read_from_url = URL2Image() + im = read_from_url(img_url) + ori_h, ori_w, _ = im.shape + det_img = self.det_preprocess(im) + #print("det_img", det_img, det_img.shape) + det_out = self.det_client.predict( + feed={"image": det_img}, fetch=["concat_1.tmp_0"]) + + #print("det_out", det_out) + def sorted_boxes(dt_boxes): + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + def get_rotate_crop_image(img, points): + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + img_crop_width = int(np.linalg.norm(points[0] - points[1])) + img_crop_height = int(np.linalg.norm(points[0] - points[3])) + pts_std = np.float32([[0, 0], [img_crop_width, 0], \ + [img_crop_width, img_crop_height], [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img_crop, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + def resize_norm_img(img, max_wh_ratio): + import math + imgC, imgH, imgW = 3, 32, 320 + imgW = int(32 * max_wh_ratio) + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + _, new_h, new_w = det_img.shape + filter_func = FilterBoxes(10, 10) + post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] + dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) + dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) + dt_boxes = sorted_boxes(dt_boxes) + feed_list = [] + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for img in img_list: + norm_img = resize_norm_img(img, max_wh_ratio) + feed = {"image": norm_img} + feed_list.append(feed) + fetch = ["ctc_greedy_decoder_0.tmp_0"] + #print("feed_list", feed_list) + return feed_list, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + #print(fetch_map) + ocr_reader = OCRReader() + rec_res = ocr_reader.postprocess(fetch_map) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + fetch_map["res"] = res_lst + del fetch_map["ctc_greedy_decoder_0.tmp_0"] + del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"] + return fetch_map + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_rec_model") +ocr_service.prepare_server(workdir="workdir", port=9292) +ocr_service.init_det_client( + det_port=9293, + det_client_config="ocr_det_client/serving_client_conf.prototxt") +ocr_service.run_rpc_service() +ocr_service.run_web_service() diff --git a/python/examples/ocr/test_ocr_rec_client.py b/python/examples/ocr/test_ocr_rec_client.py deleted file mode 100644 index b61256d0..00000000 --- a/python/examples/ocr/test_ocr_rec_client.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle_serving_client import Client -from paddle_serving_app.reader import OCRReader -import cv2 - -client = Client() -client.load_client_config("ocr_rec_client/serving_client_conf.prototxt") -client.connect(["127.0.0.1:9292"]) - -image_file_list = ["./test_rec.jpg"] -img = cv2.imread(image_file_list[0]) -ocr_reader = OCRReader() -feed = {"image": ocr_reader.preprocess([img])} -fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] -fetch_map = client.predict(feed=feed, fetch=fetch) -rec_res = ocr_reader.postprocess(fetch_map) -print(image_file_list[0]) -print(rec_res[0][0]) diff --git a/python/examples/ocr/test_rec.jpg b/python/examples/ocr/test_rec.jpg deleted file mode 100644 index 2c34cd33eac5766a072fde041fa6c9b1d612f1db..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6369 zcmbVwcT^MIw|0;wMG%lCC3Fx(1O${41f{E35R?{?4pKu25E43y6oFR|5m2d0iFA+> zdPk&&CXmp3A_xNs=L28MqJBhy(1MkYoEhO;bZnVA0?^^BE; znf33*--rC?)-&|9wDio342=JV{I49P1;lX{bOnT?r4a+2;h>@AprLetz*L;{|H7k! z{g2U{p{1i|V5Guep$1f+r^2VDrGloXBBn-%Qs+T*9Q2%`iaHEj59}Dld@d-3Cx1FC zu3O&9{cxBdq5Rx8f{B@jmyiFVq?EMGCE05#s%qD7sNcG+r*B|rWcEhGmpeIoFwrt zwSSrY&xl3*pP2nWV*j7lG>Da!hMGKD4v;noc+V339SABAcn13lZ!xYv6<_rnT8BD5 zMRuWR$R;=m3{kPMqMj(o5n7aRyGt_pLGT%$Hz4#IO}V^l@15J(*<_byp+_mwhwr+& zwKLiKn)Xhai$6m*C?J+*Avokm-*5f9w+B{QG&DoZu6Tb4@3t(`*pkcbz_^y$w_)g- zHQ<=G8nqOas*ChZ*&nuM4jrCF{L;V%k;H+!7c%JJ3C^}KMCwWaW=v{3O3Z=?0^dkEhb8wtB$t83h!DGKF0@ z{3%pg+{-8VZksfmP{+SHQB?`N(tHC9Yd*py2>_K3qi!D_m#79reF%EVogNU%wz0M2 zKd9T?r9L=|S!XGLt2~q9ih2AT49VhbM74dNwM=RwWroCFCc-O}z0qXwHQ8bt3g~4u z{E9u0pp}dIO8DRstGPqch|dWf9C_tw-;-F__DDy#lyG=AMi=v(khdTV zKK(~LLdEIZAdgZBR(eIaQQ?+tE(JuNWXdNUi4fdRUN_&q?=|M*8n()`$KrUk{fynOM5!H_HmIvavr@Ud^%yM2PC!dhYQ-FavR-l9DkZV0#Qq^0tX2~mZV5$~v>HbR@5&=|_ zxM$ZPd_=EKOw>FVZwKc(k}K^7LJ?LN#LFR8?(DNyE@6MY@I#gHipcR=N4d&E(3n* zc^%BBo!=N(O#PWv(8PYQ4T$3urbgxshpV7$Zr*OTm#ZEV&;9CU7p&zl%+vI}Vpec{ z=0oD2rBfQB;GED;h``1{5J@jK=~KN=_Sa;S7_t#R`e)8e=D`!<>b4CIb(R8Z`#w=M z6XX}t>sbbU@aAhTYjQ+`BKD21oJ^U${#IF#Gnq@lrmn6YNH4a$)5~2Q{WXShSwqCX zOPdC&fb0?G9fjODo%}x3juvpU$Te-=dyrrKCMcZc#*xusHxBBF?1ol=yN0HU+l8-= zRwuypn(B!xsUByNXYxuaB~Ce%2snd0Ntd3LkzvE3+H8FxtUupDpjqsoF^t~BkN z$FYA@7J>zBog=qcQ;j5QAR;cfAnWocx8h}nZZd)bxVHnaCKM1;V&R3PH*>76-xyFa z0otFj@`ic)f?oY!8WvusARaFiW3I)orL=v_&PY){kz1qoj8|(;6h@#~00=*oUy~w{ z*lm=jtOe3vcGN#Nk}4l#uj(w(&dWVe3gsN}^J*Kr+&{ZLOF|1(eKEaO`SV2khM zZG6Ay54=#~qp8L+5*Z5K-?RLZxySGq`TgK zF%31zcwm6#y8l*XTmRA5+h4q&W9?_6`REWIv$Zz&cy&fbab!yyY;^UFEhmUv8q1o; zj&{ILCNp|jIPEU>(>&r1P*?zkry7qoIwzIuoEv|Yrs7f5rti4fbENlTnO^Ri@SX)^ zk;qHt>R%}M6vVw7(=4))rj}4Bq<%@Wqwa2rf!UCRFiQPMZ4ul?0r4USP+(1Ef^+BL z&v_OPq1|hPw!e%!QC=-dGMZl)m*r-Dmkv$EE;@IjxRF1z*$IjHc#PV;@kLk5v!6L~ z6fb-eSsk+op(>jW5fz5eab3{nbb*LfG-1gas8H8X`pTme9d}LHo072b6@w38>#@O6 zpNiZ0k`i!fvP4>Q;oj#)jZ+@g`jd6kZn>T>N%_$}>>%Xi4T7;0U}=GJHJYHUf7!ZD zn<-6CM=s}i4#yuZ#Zf=7F6`4LIE$@AVt-o==De3D&b>+YfeYv8puxkX3(}Q76O9%N z3UTC6Ja}482Y9R<=b5APv#KgvNE6d9@FsY%XuUIdwF5nM+P=8Eo$H8(C!S7Kq3PD% zG)yYTq&$VnXJzSz1@D&~UxG3BOi)GO3rU--#lr#2w8AVqOT>nr9|k|9uV>6@RnOciUepJ3N8%I{V$(=VWWgOq5eV3tupx{oRP1su$2IgXoX$HVR!AtqWW})7Zbh zuMbtZ??=;oba)MRRg<}8@+q9bK&TaR~>9MkVcY_qfq6(?B*9z3hKq1C?OO^ZFtU;2oHru#joU`lPrYNm!hCz?Kp(C1Jb0Za?e^f zH{#|}(L&r2tsxfeNHi~$AQ-4WU5lBLj#?aOd8miC}XX=gix@m82VN|t1>Sn!$%K0u#DN04O*f}382G)k>S$ekU6E_2VYg18_j zLy(Z1x_%14LIHiV*j6N39K1Y9Bg2StUB^w`2hH)d#3^g>q+eULeGrf!!;(b;vISgY zx_yp#xVSY>i&O-X8x3;Xo$jowG};Hv{1ZT9Fu=~wEluXclUmhK<$#XY|RbQ(e#ubl|4i0w-7+%pU{ZHFj+bG+pN^VqY#L>Cld z*b(awadut>kRXjaXkkR^7g=9DSk} zFi%IVeJ}@{_dvJlB(>4BwAB$G(r@x^!YaxkVW8}(rK&h5yQ4&s@tz|Qwf&m{Vpo7j z!V6n3##;uG#TFlwH;q)Q)blEHDzi5v)Vl1i-S3eeNL{2d!ida_XPi zrvFX>o!_*SVTH)lHR?TBfvaSM@5^)_SxDX$>JGpolgKgvElFxKho)IbZ*h?EX5-j3 zG|P;H+Ce3K`s!m}n>XZDP1sF4ld&~K$!yD91Tl7uQuc%;Ky zrH`>IrtR~>2pvLUR3kqQ#R5!hnpQPcr3Z`5wnnmYUuFVFhXzsmxSzP2Kz77|c?uKB zrjT;^vRsM#rZB!)-hR9IxT43eM>_k|1%^-CDImftg0_Ziw`7oI=9shyIIb+DZH*#tyP*jQ&=M_)5DCE2L*4zl6X{C1TW6`~4G`aoAT2hr zd3?2C^XYiC4Y=bR0#SUTvt3qNvuwQ=nTidEQWr0k^V{wO<%1;x>l_Uc9f_;?J|_Rf z77jBPt&K;F=+5QdpbrwX=+>@L+seQzOj1Cdr$SZ3yI1_7v>TS0CZ9J(y1!f9xaqb3 zv}fTeSia==In_fMR9h+meRAeGIr;*st$zfyCN_UK$Q?cWX$M^1mBo`!a)38H zOkl}R%yMNGZaE6l{!xQ*1y9{N$iSjppcjzCn6nU3%e7%lwD82w(7G@0B;ayli*ydc zP0F@v3#j;PX+X3TzKMs%AZqGS+&YDZDOYrKhr5$0piqmC@Zy6F;+1($0A-m%=(H<{ zYBH3^i`{;UiuLIchm2!Q?)Ef@02U+|9E`(kp}_6z%Bz%DoM7d*~w4$ zCO!9dZ#+rTym>&Cax^!*KF?J;M9b-EnPS@8D35e$Y#fl(f;}zxB)7#bCrpg}U85?ar5nY{7yStCed9M8KpYE;?ey{lU!1&JzZ+p=F!ePOBkq zc=yBcPnjK$n5uHvo?f#^y%A*0mwQB`1sZ%{z`u|g5<|3ztm28z+TKw!Hr34)#6JD3 zGB^}z8N@hL1nko!Zt%rapu@S9J*2}EaO26a4Mh|_Svd_!DuzT9TYWgFW zNO{26Dx`tmumAeac4^@+`DB|8e%>NF=UTVU@oN96kOEDI1g|b##W3x_^7Itcf_AB& zt@5ayBH0)dMlb)_va2N#C)r8?b$2h{2FsMiYlKQ2kzu*1&eh#f{UE;3UtjB~lf{fC z!v)O*gP4xwGlz2`t{k<_e6zlgw`(=3-)cKtB+c4p0MW2VyO)Db9}mLW@BI3Rkp1SJ z(@eVCugWfA<<&~wgt{1r?%JNA{wG9`b%8*V77*E%uCBW&EIRDZGwdgC96&1P8)K4U zjG0Mk^zSwX4%?I6c$0xcOXfRU96Q;1la~4xPO?qmT5`iXxwjU)&d=Uf&CGe_#M9cv z?PbkA3_8A8jBgT~6WA5{`jM8T%GgKn6-y<%)t7}AHz*CYByizQONp;>iLpf3Myr`wLs zzB8%MYbAdBG$h@78t%lEhBBmPLIoK3s;I7s1j9=;> zXIaD59P2nDW`=Yyg&}r#+i2jO+fmHx786kR>c{G4wKd_LB!+f=!{yxl{Y~q`NQBiQ zCaPJ0h$lmc`KOseJNYN+C_%3C={&#tm#;dmH+XvgGT>f?0aw@F0PQAYyU4L}*U$LL zS`EgOosY-KRZd^bmN{Rilpm_K4FR&dZ7A+p*Gd3$tB!!I{c0u0zx6DjI_>5SwvXJ7 zKSJ3wur*KMl?RVc^7ogZuWV@vwoNQfhJSAO0S-|Mf^|N7x92*r%d4^1HecuPkuO!& z8I~gyPA4vWxSn`Nfp5RqmhJm^qvXPgh84l79F}jRmYA)r+V@C!Z3g)u4EN9R#lrhp z2TCZ{Qdo;T5!*HG)uzpwpxr)QP$4$@{3(mrk)Bya+A7@_?F98x+>GToc=sf|1^2)w zsX;KJ453HxFDNaAqb<@{;F80CdXjl2v|e_0ID80L-k4Q?!E1R?mUwjtr)DmG(tfHE zRV}-j*15~7)ou%y`Ux`EV3JHQTs?P1@d)1>b`p(o$%&n(OSd8H#!r|y5y5eucB5z^ zBIoyhrLAZB79tigfIM;v52e-hW?oOM@3!tcr-v$(n3m7X^C>nWO4*ED=(k&{lR!F{ za|Ae*2P#7@!=-K&!9TiM1f&}|;~5%_N*!hkq`cd+C%MZR(ghI4wf__`o^sfWeQ?f7 zd4tk?fKQN2^AcL)ITXGXBs2HI_^J=9SI^;(bAC*P9T2g#R56*9#nurotyioT6p&Nl zz|0GY6!RN7v0NK-JkXdqDZ==sr5n$2(`3Yk>hz